Source code for jacinle.random.rng

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : rng.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import os
import contextlib
import random as sys_random

import numpy as np
import numpy.random as npr

from typing import Optional
from jacinle.utils.defaults import defaults_manager
from jacinle.utils.registry import Registry

__all__ = ['JacRandomState', 'get_default_rng', 'gen_seed', 'gen_rng', 'reset_global_seed', 'with_seed', 'seed']


[docs] class JacRandomState(npr.RandomState):
[docs] def choice_list(self, list_, size=1, replace=False, p=None): """Efficiently draw an element from an list, if the rng is given, use it instead of the system one.""" if size == 1: if type(list_) in (list, tuple): return list_[self.choice(len(list_), p=p)] return self.choice(list_, p=p) else: if type(list_) in (list, tuple): inds = self.choice(len(list_), size=size, replace=replace, p=p) return [list_[i] for i in inds] return self.choice(list_, size=size, replace=replace, p=p)
[docs] def shuffle_list(self, list_): if type(list_) is list: sys_random.shuffle(list_, random=self.random_sample) else: self.shuffle(list_)
[docs] def shuffle_multi(self, *arrs): length = len(arrs[0]) for a in arrs: assert len(a) == length, 'non-compatible length when shuffling multiple arrays' inds = np.arange(length) self.shuffle(inds) return tuple(map(lambda x: x[inds], arrs))
[docs] @defaults_manager.wrap_custom_as_default(is_local=True) def as_default(self): yield self
_rng = JacRandomState() _get_default_rng = defaults_manager.gen_get_default(JacRandomState, default_getter=lambda: _rng)
[docs] def get_default_rng() -> JacRandomState: """Get the default random number generator.""" return _get_default_rng()
[docs] def gen_seed() -> int: """Generate a random seed (a.k.a. a random integer in [0, 2^32)).""" return get_default_rng().randint(4294967296)
[docs] def gen_rng(seed: Optional[int] = None) -> JacRandomState: """Generate a random number generator with the given seed.""" return JacRandomState(seed)
global_rng_registry = Registry() global_rng_registry.register('jacinle', lambda: _rng.seed) global_rng_registry.register('numpy', lambda: npr.seed) global_rng_registry.register('sys', lambda: sys_random.seed) global_rng_state_registry = Registry() global_rng_state_registry.register('jacinle', lambda: (_rng.get_state, _rng.set_state)) global_rng_state_registry.register('numpy', lambda: (npr.get_state, npr.set_state)) global_rng_state_registry.register('sys', lambda: (sys_random.getstate, sys_random.setstate))
[docs] def reset_global_seed(seed: Optional[int] = None, verbose: bool = False) -> int: """Reset the global seed for all random number generators. Args: seed: the seed to use. If None, a random seed will be generated. verbose: whether to print the seed. Returns: the seed used. """ if seed is None: seed = gen_seed() for k, seed_getter in global_rng_registry.items(): if verbose: from jacinle.logging import get_logger logger = get_logger(__file__) logger.critical('Reset random seed for: {} (pid={}, seed={}).'.format(k, os.getpid(), seed)) seed_getter()(seed) return seed
[docs] def seed(seed): reset_global_seed(seed)
[docs] @contextlib.contextmanager def with_seed(seed: Optional[int] = None, verbose: bool = False): """A context manager that sets the global seed to the given value, and restores it after the context. Note that when the given seed is None, this function will not do anything. Example: .. code-block:: python with with_seed(123): # do something Args: seed: the seed to set. """ if seed is None: yield return states = dict() setter_functions = dict() for k in global_rng_state_registry.keys(): if global_rng_registry.has(k) and global_rng_state_registry.has(k): seed_func = global_rng_registry.lookup(k)() state_getter, state_setter = global_rng_state_registry.lookup(k)() states[k] = state_getter() setter_functions[k] = state_setter seed_func(seed) if verbose: from jacinle.logging import get_logger logger = get_logger(__file__) logger.critical('Reset random seed for: {} (pid={}, seed={}).'.format(k, os.getpid(), seed)) yield for k, state in states.items(): setter_functions[k](state)
def _initialize_global_seed(): seed = os.environ.get('JACINLE_GLOBAL_SEED', None) if seed is not None: seed = int(seed) reset_global_seed(seed, verbose=True) _initialize_global_seed()