# Source code for jacinle.utils.defaults

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : defaults.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/28/2017
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import threading
import inspect
import contextlib
import functools

from .meta import decorator_with_optional_args
from .inspect import class_name_of_method

__all__ = [
    'defaults_manager', 'wrap_custom_as_default', 'gen_get_default', 'gen_set_default',
    'option_context', 'FileOptions',
    'ARGDEF', 'default_args'
]


class DefaultsManager(object):
    """Defaults manager can be used to create program or thread-level registries.

    One of the typical use cases is that you can create an instance of a specific class, set it as the
    default, and then get this instance from elsewhere. For example::

        >>> class Storage(object):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> storage = Storage(1)
        >>> set_default_storage(storage)
        >>> get_default_storage()  # now you can call this elsewhere.

    Another important feature supported by this default manager is that it allows you to have "nested"
    default registries. For example::

        >>> get_default_storage().value  # -> 1
        >>> with Storage(2).as_default():
        ...     get_default_storage().value  # -> 2
        ...     with Storage(3).as_default():
        ...         get_default_storage().value  # -> 3
        ...     get_default_storage().value  # -> 2

    Similar features have been used commonly in TensorFlow, e.g., tf.Session, tf.Graph.

    To create a class with a default registry, use the following:

    .. code-block:: python

        class Storage(object):
            def __init__(self, value):
                self.value = value

            @defaults_manager.wrap_custom_as_default(is_local=True)
            def as_default(self):  # this is a contextmanager
                yield

        get_default_storage = defaults_manager.gen_get_default(Storage)
        set_default_storage = defaults_manager.gen_set_default(Storage)
    """

    def __init__(self):
        # identifier -> bool: whether that identifier's default is stored per thread.
        self._is_local = dict()
        # Process-wide registry: identifier -> registered default.
        self._defaults_global = dict()
        # Per-thread registries; each thread lazily gets its own ``defaults`` dict.
        self._defaults_local = threading.local()

    @decorator_with_optional_args(is_method=True)
    def wrap_custom_as_default(self, *, is_local=False):
        """Wrap a generator method (conventionally ``as_default``) into a registering context manager.

        While the returned context manager is active, the instance it was called on is registered as the
        default under the identifier computed by ``class_name_of_method`` for the wrapped method. The previous
        default is restored on exit, including when the ``with`` body raises.

        Args:
            is_local: if True, the registration is stored in a per-thread registry; otherwise it is
                stored in the process-wide registry.
        """
        def wrapper(meth):
            identifier = class_name_of_method(meth)
            meth = contextlib.contextmanager(meth)
            self._is_local[identifier] = is_local

            @contextlib.contextmanager
            @functools.wraps(meth)
            def wrapped_func(slf, *args, **kwargs):
                # Resolve the registry at call time: for thread-local defaults this must be the
                # *calling* thread's registry (the one get_default reads), not the registry of the
                # thread that happened to decorate the class.
                defaults = self._get_defaults_registry(identifier)
                backup = defaults.get(identifier, None)
                defaults[identifier] = slf
                try:
                    with meth(slf, *args, **kwargs):
                        yield
                finally:
                    # Restore even if the with-body (or the wrapped method) raised.
                    defaults[identifier] = backup

            return wrapped_func
        return wrapper

    def gen_get_default(self, cls, default_getter=None):
        """Generate a ``get_default(default=None)`` accessor for ``cls``.

        The accessor returns the instance currently registered for ``cls``. When nothing (or ``None``) is
        registered it returns ``default``; if ``default`` is ``None`` and ``default_getter`` was given,
        ``default_getter()`` is called (only in that case) to produce the fallback.

        Args:
            cls: a class exposing an ``as_default`` method, typically decorated with
                :meth:`wrap_custom_as_default`.
            default_getter: optional zero-argument callable producing the fallback value.
        """
        identifier = class_name_of_method(cls.as_default)

        def get_default(default=None):
            # NB(Jiayuan Mao): cannot use .get(identifier, default), because after calling as_default, the current
            # default will be set to None.
            val = self._get_defaults_registry(identifier).get(identifier, None)
            if val is not None:
                return val
            # Only build the fallback when it is actually needed.
            if default is None and default_getter is not None:
                default = default_getter()
            return default

        return get_default

    def gen_set_default(self, cls):
        """Generate a ``set_default(default)`` function that registers ``default`` for ``cls``."""
        identifier = class_name_of_method(cls.as_default)

        def set_default(default):
            self._get_defaults_registry(identifier)[identifier] = default

        return set_default

    def set_default(self, cls, default):
        """Register ``default`` as the current default for ``cls``."""
        identifier = class_name_of_method(cls.as_default)
        self._get_defaults_registry(identifier)[identifier] = default

    def _get_defaults_registry(self, identifier):
        """Return the registry dict for ``identifier``: the calling thread's dict if it was declared
        thread-local, otherwise the process-wide dict."""
        if not self._is_local.get(identifier, False):
            return self._defaults_global
        local = self._defaults_local
        if not hasattr(local, 'defaults'):
            # First access from this thread.
            local.defaults = dict()
        return local.defaults
# The shared, module-level manager used by the convenience functions below.
defaults_manager = DefaultsManager()

# Module-level shortcuts bound to the shared ``defaults_manager`` instance.
wrap_custom_as_default, gen_get_default, gen_set_default = (
    defaults_manager.wrap_custom_as_default,
    defaults_manager.gen_get_default,
    defaults_manager.gen_set_default,
)


class _LocalObjectSimulator(object):
    # Plain (not thread-local) holder with a single ``ctx`` slot; used by ``option_context`` in place of
    # ``threading.local`` when options should be shared across threads.
    __slots__ = ['ctx']
def option_context(name, is_local=True, **kwargs):
    """Create a new option-context class holding the options given in ``kwargs``.

    Each instance of the returned class stores one value per declared option. An instance can be
    activated with ``as_default()`` (a nestable context manager). ``get_default()`` returns the active
    instance, or a lazily-created default instance built from ``kwargs`` when none is active.

    Example::

        >>> Options = option_context('Options', verbose=False)
        >>> Options.get_option('verbose')
        False
        >>> with Options(verbose=True).as_default():
        ...     Options.get_option('verbose')
        True

    Args:
        name: the ``__name__`` given to the generated class.
        is_local: if True, the active and default contexts are kept in ``threading.local`` storage,
            so every thread has its own; otherwise they are shared by all threads.
        **kwargs: option names and their initial default values.

    Returns:
        The generated option-context class.
    """

    class OptionContext(object):
        def __init__(self, **init_kwargs):
            cls = self.__class__
            for k, v in kwargs.items():
                setattr(self, k, v)
            # Inherit the values of the currently active context (if any), so that a nested context
            # only overrides the options passed explicitly. ``ctx`` may be unset on this thread.
            if hasattr(cls, 'current_context') and getattr(cls.current_context, 'ctx', None) is not None:
                c = cls.get_default()
                for k in kwargs:
                    setattr(self, k, getattr(c, k))
            for k, v in init_kwargs.items():
                assert k in kwargs, 'Unknown option "{}" for {}.'.format(k, cls.__name__)
                setattr(self, k, v)

        @classmethod
        def get_option(cls, name):
            """Get the option value of the current context."""
            return getattr(cls.get_default(), name)

        @classmethod
        def set_default_option(cls, name, value):
            """Set the option value for the default context (the one used when no context is active)."""
            cls._create_default_context()
            setattr(cls.default_context.ctx, name, value)

        @classmethod
        def get_default(cls):
            """Get the current option context, falling back to the default context."""
            cls._create_current_context()
            if cls.current_context.ctx is not None:
                return cls.current_context.ctx
            cls._create_default_context()
            return cls.default_context.ctx

        @contextlib.contextmanager
        def as_default(self):
            """Make this option context the current context. It will overwrite the current option values."""
            cls = self.__class__
            cls._create_current_context()
            backup = cls.current_context.ctx
            cls.current_context.ctx = self
            try:
                yield
            finally:
                # Restore the previous context even if the with-body raised.
                cls.current_context.ctx = backup

        @classmethod
        def _create_default_context(cls):
            if not hasattr(cls, 'default_context'):
                cls.default_context = threading.local() if is_local else _LocalObjectSimulator()
            # With threading.local, ``ctx`` set on one thread is invisible to others, so initialise
            # it per thread on first access.
            if getattr(cls.default_context, 'ctx', None) is None:
                cls.default_context.ctx = cls(**kwargs)

        @classmethod
        def _create_current_context(cls):
            if not hasattr(cls, 'current_context'):
                cls.current_context = threading.local() if is_local else _LocalObjectSimulator()
            if not hasattr(cls.current_context, 'ctx'):
                cls.current_context.ctx = None

    OptionContext.__name__ = name
    return OptionContext
class FileOptions(object):
    """A store of options scoped to a single file (module).

    Every keyword passed at construction becomes a plain attribute of the instance. :meth:`set` later
    overwrites attributes that already exist.

    Example:

        .. code-block:: python

            # file: my_module.py
            options = FileOptions(__file__, number_to_add=1)

            def my_func(x: int) -> int:
                return x + options.number_to_add

            # file: my_script.py
            import my_module
            my_module.options.set(number_to_add=2)
            my_module.my_func(1)  # returns 3
    """

    def __init__(self, __file__, **init_kwargs):
        """
        Args:
            __file__: path of the owning file; stored on the instance and quoted in error messages.
            **init_kwargs: declared option names and their initial values.
        """
        self.__file__ = __file__
        for option_name in init_kwargs:
            setattr(self, option_name, init_kwargs[option_name])

    def set(self, **kwargs):
        """Overwrite the values of existing options, one keyword at a time.

        Raises:
            AssertionError: if a keyword does not name an existing attribute (checked via ``assert``,
                so the check is skipped under ``python -O``).
        """
        for option_name in kwargs:
            assert hasattr(self, option_name), '{} is not an option for file "{}".'.format(option_name, self.__file__)
            setattr(self, option_name, kwargs[option_name])
class _ArgDefType(object):
    # Type of the ``ARGDEF`` singleton sentinel; gives it a readable repr and makes copy/pickle preserve
    # identity, which ``default_args`` relies on (it tests ``value is ARGDEF``).
    __slots__ = ()

    def __repr__(self):
        return 'ARGDEF'

    def __reduce__(self):
        # A string return value tells pickle and copy to treat the object as the module-level global
        # named ``ARGDEF``, so copy(), deepcopy() and unpickling all yield the very same object.
        return 'ARGDEF'


ARGDEF = _ArgDefType()
"""A special value to indicate that the default value of an argument will be determined in a deferred manner. See :func:`default_args`."""
def default_args(func):
    """A decorator that handles "fall-through" default arguments.

    Suppose we have two functions: ``f`` and ``g``, and ``f`` calls ``g``. ``g`` has a default argument
    ``x``, e.g., ``x=1``. In many cases, we do not want to specify the default value of ``x`` in ``f``.
    One way to do this is to use ``None`` as the default value of ``x`` in ``f``, and then check if ``x``
    is ``None`` in ``g``. However, this does not handle cases where ``None`` is a legitimate value of
    ``x``, and it also requires additional checks in ``g``.

    With this decorator, we can simply write ``x=ARGDEF`` in ``f``, and then ``x`` will be set to ``1``
    in ``g``.

    Example:

        .. code-block:: python

            def f(x=ARGDEF):
                g(x)

            @default_args
            def g(x=1):
                print(x)

            f()  # prints 1
            f(2)  # prints 2

    Raises:
        TypeError: at call time, if ``ARGDEF`` is passed for a parameter of ``func`` that declares no
            default value (or if the arguments do not match ``func``'s signature).
    """
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Replacing values of existing keys does not resize the dict, so in-place update while iterating is safe.
        for param_name, value in bound.arguments.items():
            if value is not ARGDEF:
                continue
            default_value = sig.parameters[param_name].default
            if default_value is inspect.Parameter.empty:
                raise TypeError('Argument {!r} of {} received ARGDEF but declares no default value.'.format(
                    param_name, getattr(func, '__qualname__', repr(func))
                ))
            bound.arguments[param_name] = default_value
        return func(*bound.args, **bound.kwargs)

    return wrapped