#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : environ.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import contextlib
from copy import deepcopy
from typing import Any
from jacinle.utils.meta import dict_deep_keys, dict_deep_update
__all__ = ['env', 'load_env', 'has_env', 'get_env', 'set_env', 'with_env']
[docs]
class Environ(object):
"""A global environment object."""
__env_ext__ = '.env.pkl'
[docs]
def __init__(self, envs=None):
self.envs = dict()
if envs is not None:
self.load(envs)
[docs]
def load(self, env_spec, incremental=False):
new_envs = self.__get_envs_from_spec(env_spec)
if incremental:
dict_deep_update(self.envs, new_envs)
else:
self.envs = deepcopy(new_envs)
return self
[docs]
def update(self, env_spec):
return self.load(env_spec, incremental=True)
[docs]
def dump(self, path, prefix=None):
raise NotImplementedError('Not supported yet: "Environ.dump".')
[docs]
def as_dict(self):
return deepcopy(self.envs)
[docs]
def as_dict_ref(self):
return self.envs
[docs]
def clone(self):
new_env = Environ()
new_env.envs = deepcopy(self.envs)
return new_env
[docs]
def keys(self, is_flattened=True):
if is_flattened:
return dict_deep_keys(self.envs)
return list(self.envs.keys())
[docs]
def has(self, key: str) -> bool:
"""Check whether a key is in current env object.
Args:
key: the key.
Returns:
True if the key is in current env object, otherwise False.
"""
return self.get(key, None) is not None
[docs]
def get(self, key, default=None):
"""Get a value of a environment provided a key. You can provide a default value, but this value will not affect the env object.
Args:
key: the key. Dict of dict can (should) be imploded by ``.``.
default: the default value.
Returns:
The value if the env contains the given key, otherwise the default value provided.
"""
subkeys = key.split('.')
current = self.envs
for subkey in subkeys[0:-1]:
if subkey not in current:
current[subkey] = dict()
current = current[subkey]
if subkeys[-1] in current:
return current[subkeys[-1]]
elif default is None:
return default
else:
current[subkeys[-1]] = default
return default
[docs]
def set(self, key: str, value: Any = None, do_inc: bool = False, do_replace: bool = True, inc_default: Any = 0):
"""Set an environment value by key-value pair.
Args:
key: the key, note that dict of dict can (should) be imploded by ``.''.
value: the value.
do_inc: whether to increase the value.
do_replace: whether to replace the value if the key already exists.
inc_default: the default value of the accumulator.
Returns:
self
"""
subkeys = key.split('.')
current = self.envs
for subkey in subkeys[0:-1]:
if subkey not in current:
current[subkey] = dict()
current = current[subkey]
if do_inc:
if subkeys[-1] not in current:
current[subkeys[-1]] = inc_default
current[subkeys[-1]] += value
elif do_replace or subkeys[-1] not in current:
current[subkeys[-1]] = value
return self
[docs]
def set_default(self, key: str, default: Any = None):
"""Set an environment value by key-value pair. If the key already exists, it will not be overwritten.
Args:
key: the key, note that dict of dict can (should) be imploded by ``.''.
default: the default value.
Returns:
self
:param key: the key, note that dict of dict can (should) be imploded by ``.''.
:param default: the ``default'' value.
:return: self
"""
self.set(key, default, do_replace=False)
[docs]
def inc(self, key: str, inc: Any = 1, default: Any = 0):
"""Increase the environment value provided a key.
Args:
key: the key.
inc: the increment.
default: the default value.
Returns:
self
"""
self.set(key, inc, do_inc=True, inc_default=default)
return self
def __contains__(self, item):
return self.has(item)
def __getitem__(self, item):
return self.get(item, None)
def __setitem__(self, key, value):
self.set(key, value)
return value
def __get_envs_from_spec(self, env_spec):
if isinstance(env_spec, str) and env_spec.endswith(self.__env_ext__):
raise NotImplementedError('Not implemented loading method.')
elif isinstance(env_spec, dict):
return env_spec
elif isinstance(env_spec, object) and (hasattr(env_spec, 'envs') or hasattr(env_spec, '__envs__')):
return getattr(env_spec, 'envs', None) or getattr(env_spec, '__envs__')
else:
raise TypeError('unsupported env spec: {}.'.format(env_spec))
env = Environ()
load_env = env.load
update_env = env.update
has_env = env.has
get_env = env.get
set_env = env.set
[docs]
@contextlib.contextmanager
def with_env(env_spec, incremental=True):
if not incremental:
backup = env.as_dict_ref()
else:
backup = env.as_dict()
env.load(env_spec, incremental=incremental)
yield
env.envs = backup