Source code for jacinle.config.environ

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : environ.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import contextlib
from copy import deepcopy
from typing import Any

from jacinle.utils.meta import dict_deep_keys, dict_deep_update


# Public names exported by ``from jacinle.config.environ import *``: the global
# ``env`` instance, its module-level method aliases, and the ``with_env``
# context manager. ``update_env`` is listed alongside the other aliases of
# ``env`` so that the whole alias family is exported consistently.
__all__ = ['env', 'load_env', 'update_env', 'has_env', 'get_env', 'set_env', 'with_env']


class Environ(object):
    """A global environment object.

    The environment is a (possibly nested) ``dict`` kept in ``self.envs``.
    Nested entries are addressed with dotted keys: ``'a.b'`` refers to
    ``self.envs['a']['b']``.
    """

    __env_ext__ = '.env.pkl'

    def __init__(self, envs=None):
        """Create an environment, optionally initialised from ``envs``.

        Args:
            envs: an env spec accepted by :meth:`load`, or None for an empty
                environment.
        """
        self.envs = dict()
        if envs is not None:
            self.load(envs)

    def load(self, env_spec, incremental=False):
        """Load an env spec into this object.

        Args:
            env_spec: a dict, or an object exposing an ``envs`` or
                ``__envs__`` attribute.
            incremental: if True, merge the spec into the current values with
                ``dict_deep_update``; otherwise replace the current values
                with a deep copy of the spec.

        Returns:
            self
        """
        new_envs = self.__get_envs_from_spec(env_spec)
        if incremental:
            dict_deep_update(self.envs, new_envs)
        else:
            self.envs = deepcopy(new_envs)
        return self

    def update(self, env_spec):
        """Merge ``env_spec`` into this object; same as ``load(..., incremental=True)``."""
        return self.load(env_spec, incremental=True)

    def dump(self, path, prefix=None):
        """Not supported yet; always raises ``NotImplementedError``."""
        raise NotImplementedError('Not supported yet: "Environ.dump".')

    def as_dict(self):
        """Return a deep copy of the underlying dict."""
        return deepcopy(self.envs)

    def as_dict_ref(self):
        """Return the underlying dict itself (not a copy)."""
        return self.envs

    def clone(self):
        """Return a new :class:`Environ` holding a deep copy of the values."""
        new_env = Environ()
        new_env.envs = deepcopy(self.envs)
        return new_env

    def keys(self, is_flattened=True):
        """Return the keys of the environment.

        Args:
            is_flattened: if True, return ``dict_deep_keys(self.envs)``
                (presumably the dotted keys of nested entries -- confirm
                against ``jacinle.utils.meta``); otherwise return the list of
                top-level keys only.
        """
        if is_flattened:
            return dict_deep_keys(self.envs)
        return list(self.envs.keys())

    def _find(self, key):
        """Look up a dotted key without modifying the environment.

        Returns:
            ``(True, value)`` if every component of the key resolves,
            otherwise ``(False, None)``. A non-dict value in the middle of the
            path counts as "not found" rather than raising.
        """
        current = self.envs
        for subkey in key.split('.'):
            if not isinstance(current, dict) or subkey not in current:
                return False, None
            current = current[subkey]
        return True, current

    def has(self, key: str) -> bool:
        """Check whether a key holds a value in the current env object.

        This is a pure query: it never creates intermediate dicts.

        Args:
            key: the key. Dict of dict can (should) be imploded by ``.``.

        Returns:
            True if the key resolves to a value that is not None, otherwise
            False. A key explicitly stored as None is reported as absent.
        """
        found, value = self._find(key)
        return found and value is not None

    def get(self, key: str, default: Any = None):
        """Get a value of the environment provided a key.

        When the key is missing and ``default`` is None, the environment is
        left untouched. When the key is missing and ``default`` is not None,
        the default is stored under ``key`` (creating intermediate dicts as
        needed) before being returned.

        Args:
            key: the key. Dict of dict can (should) be imploded by ``.``.
            default: the default value.

        Returns:
            The value if the env contains the given key, otherwise the default
            value provided.
        """
        found, value = self._find(key)
        if found:
            return value
        if default is not None:
            # The key is known to be missing, so this only fills the gap.
            self.set(key, default)
        return default

    def set(self, key: str, value: Any = None, do_inc: bool = False, do_replace: bool = True,
            inc_default: Any = 0):
        """Set an environment value by key-value pair.

        Args:
            key: the key, note that dict of dict can (should) be imploded by ``.``.
            value: the value.
            do_inc: whether to increase the value (``+=``) instead of assigning it.
            do_replace: whether to replace the value if the key already exists.
            inc_default: the initial value of the accumulator when ``do_inc``
                is set and the key does not exist yet.

        Returns:
            self
        """
        *parents, leaf = key.split('.')
        current = self.envs
        for subkey in parents:
            if subkey not in current:
                current[subkey] = dict()
            current = current[subkey]
        if do_inc:
            if leaf not in current:
                current[leaf] = inc_default
            current[leaf] += value
        elif do_replace or leaf not in current:
            current[leaf] = value
        return self

    def set_default(self, key: str, default: Any = None):
        """Set an environment value only if the key does not exist yet.

        Args:
            key: the key, note that dict of dict can (should) be imploded by ``.``.
            default: the default value.

        Returns:
            self
        """
        return self.set(key, default, do_replace=False)

    def inc(self, key: str, inc: Any = 1, default: Any = 0):
        """Increase the environment value provided a key.

        Args:
            key: the key.
            inc: the increment.
            default: the starting value used when the key does not exist yet.

        Returns:
            self
        """
        return self.set(key, inc, do_inc=True, inc_default=default)

    def __contains__(self, item):
        return self.has(item)

    def __getitem__(self, item):
        return self.get(item, None)

    def __setitem__(self, key, value):
        self.set(key, value)
        return value

    def __get_envs_from_spec(self, env_spec):
        """Extract the raw env dict from a spec.

        Args:
            env_spec: a dict, or an object with an ``envs`` or ``__envs__``
                attribute. Paths ending with ``__env_ext__`` are recognised
                but loading them is not implemented.

        Raises:
            NotImplementedError: for a string ending with ``__env_ext__``.
            TypeError: for any other unsupported spec.
        """
        if isinstance(env_spec, str) and env_spec.endswith(self.__env_ext__):
            raise NotImplementedError('Not implemented loading method.')
        if isinstance(env_spec, dict):
            return env_spec

        envs = getattr(env_spec, 'envs', None)
        if envs:
            return envs
        # ``envs`` is missing or empty: prefer ``__envs__`` when it exists.
        fallback = getattr(env_spec, '__envs__', None)
        if fallback is not None:
            return fallback
        # An empty (but present) ``envs`` is still a valid spec, e.g. a fresh
        # ``Environ()``.
        if envs is not None:
            return envs
        raise TypeError('unsupported env spec: {}.'.format(env_spec))
# The process-wide environment instance shared by the helpers below.
env = Environ()

# Module-level shortcuts bound to the global instance's methods.
load_env, update_env = env.load, env.update
has_env, get_env, set_env = env.has, env.get, env.set
@contextlib.contextmanager
def with_env(env_spec, incremental=True):
    """Temporarily load ``env_spec`` into the global ``env``.

    The previous contents of ``env`` are restored when the ``with`` block
    exits, including when it exits because of an exception.

    Args:
        env_spec: an env spec accepted by ``env.load``.
        incremental: if True, the spec is merged into the current values;
            otherwise it replaces them for the duration of the block.

    Yields:
        None.
    """
    if not incremental:
        # A non-incremental load rebinds ``env.envs`` to a fresh dict, so the
        # current dict is left untouched and a reference is enough.
        backup = env.as_dict_ref()
    else:
        # An incremental load updates ``env.envs`` in place, so a deep copy is
        # needed to be able to roll back.
        backup = env.as_dict()

    try:
        env.load(env_spec, incremental=incremental)
        yield
    finally:
        # Restore even if loading or the body raised.
        env.envs = backup