Source code for jaclearn.dataflow.collections

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : collections.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import numpy as np

from jacinle.utils.argument import UniqueValueGetter

from .dataflow import SimpleDataFlowBase, ProxyDataFlowBase, RandomizedDataFlowBase

# Names exported by ``from jaclearn.dataflow.collections import *``.
# NOTE(review): ``RandomRepeatDataFlow`` is defined in this module but is not
# listed here -- confirm whether the omission is intentional.
__all__ = [
    'RandomIndexDataFlow',
    'DictDataFlowProxy', 'EmptyDictDataFlow',
    'QueueDataFlow', 'PoolDataFlow',
    'ListOfArrayDataFlow', 'DictOfArrayDataFlow',
    'DictToBatchDataFlow',
    'KVStoreDataFlow', 'KVStoreRandomSampleDataFlow',
    'PoolRandomSampleDataFlow', 'LOARandomSampleDataFlow', 'DOARandomSampleDataFlow'
]


class RandomIndexDataFlow(RandomizedDataFlowBase):
    """Yield the elements of an indexable source in a freshly shuffled order.

    Args:
        source: object supporting ``len()`` and integer indexing.
        nr_repeat: number of shuffled passes over ``source``; ``None`` means
            the flow keeps producing shuffled passes forever.
        seed: forwarded to the base-class constructor.
    """

    def __init__(self, source, nr_repeat=1, seed=None):
        super().__init__(seed=seed)
        self._nr_repeat = nr_repeat
        self._source = source

    def _gen_once(self):
        """Yield every element of the source exactly once, in random order."""
        size = len(self._source)
        order = np.arange(size)
        # ``self._rng`` is not created in this class; presumably the base
        # class provides it -- confirm against RandomizedDataFlowBase.
        self._rng.shuffle(order)
        for position in order:
            yield self._source[position]

    def _gen(self):
        """Chain ``nr_repeat`` shuffled passes (or infinitely many if None)."""
        if self._nr_repeat is not None:
            for _ in range(self._nr_repeat):
                yield from self._gen_once()
            return

        while True:
            yield from self._gen_once()

    def _len(self):
        """Return the length of the underlying source."""
        return len(self._source)
class DictDataFlowProxy(ProxyDataFlowBase):
    """Turn each sequence produced by ``iterable`` into a dict keyed by ``keys``.

    Args:
        keys: sized collection of dictionary keys, paired positionally with
            the values of every produced item.
        iterable: the wrapped flow; each item must have ``len(keys)`` entries.
    """

    def __init__(self, keys, iterable):
        super().__init__(iterable)
        self._iterable = iterable
        self._keys = keys

    def _gen(self):
        """Yield one ``{key: value}`` dict per item of the wrapped iterable."""
        for values in self._iterable:
            # Every item must supply exactly one value per key.
            assert len(self._keys) == len(values), 'DictDataFlowAdapter: length mismatched'
            yield {key: value for key, value in zip(self._keys, values)}
class EmptyDictDataFlow(SimpleDataFlowBase):
    """Endless flow whose every item is a new, empty dict."""

    def _gen(self):
        """Yield a fresh empty dict on every step, forever."""
        while True:
            yield dict()
class QueueDataFlow(SimpleDataFlowBase):
    """Endless flow that produces whatever ``queue.get()`` returns.

    Args:
        queue: object exposing a zero-argument ``get()`` method.
    """

    def __init__(self, queue):
        self._queue = queue

    def _gen(self):
        """Yield the result of ``queue.get()`` repeatedly, without end."""
        while True:
            item = self._queue.get()
            yield item
class PoolDataFlow(SimpleDataFlowBase):
    """Flow over an indexable pool, visiting indices ``0 .. len(pool) - 1``.

    Args:
        pool: object supporting ``len()`` and integer indexing. Its length is
            captured once, at construction time.
    """

    def __init__(self, pool):
        self._pool = pool
        self._length = len(pool)

    def _len(self):
        """Return the pool length recorded at construction."""
        return self._length

    def _gen(self):
        """Yield ``pool[0]``, ``pool[1]``, ... up to the recorded length."""
        yield from (self._pool[index] for index in range(self._length))
class ListOfArrayDataFlow(SimpleDataFlowBase):
    """Flow over several equally long arrays, yielding one row at a time.

    Args:
        loa: iterable ("list of arrays") of sized, indexable sequences.
    """

    def __init__(self, loa):
        self._loa = loa

        # Feed every array length into the getter and keep the value it
        # reports. Presumably it raises with the given message when the
        # lengths differ -- confirm against jacinle.utils.argument.
        length_getter = UniqueValueGetter('ListOfArrayDataFlow length consistency check failed')
        for array in self._loa:
            length_getter.set(len(array))
        self._length = length_getter.get()

    def _len(self):
        """Return the common array length recorded at construction."""
        return self._length

    def _gen(self):
        """Yield ``[a[i] for a in loa]`` for every index ``i`` in order."""
        for row in range(self._length):
            yield [array[row] for array in self._loa]
def DictOfArrayDataFlow(doa):
    """Build a flow that yields one dict per row of a dict of arrays.

    Args:
        doa: mapping from key to an indexable array.

    Returns:
        A ``DictDataFlowProxy`` wrapping a ``ListOfArrayDataFlow`` over the
        mapping's values, keyed by the mapping's keys.
    """
    # Snapshot the keys. ``doa.keys()`` is a live view, whereas ``values`` is
    # a fixed list; keeping the view would let the two drift apart if ``doa``
    # is modified after the flow has been created.
    keys = list(doa.keys())
    values = [doa[k] for k in keys]
    return DictDataFlowProxy(keys, ListOfArrayDataFlow(values))
class DictToBatchDataFlow(ProxyDataFlowBase):
    """Add a leading axis of size one to the values of each dict item.

    Args:
        iterable: the wrapped flow; every item must be a dict-like object.
        excludes: optional iterable of keys whose values are left untouched.
    """

    def __init__(self, iterable, excludes=None):
        super().__init__(iterable)
        self._excludes = set(excludes) if excludes is not None else set()

    def _gen(self):
        """Yield each wrapped item after converting its values to arrays.

        Every value whose key is not excluded is replaced by
        ``np.array(value)[np.newaxis]``. The item is modified in place and the
        same object is then yielded.
        """
        for item in self.unwrapped:
            # Only existing keys are reassigned, so the dict does not change
            # size while ``items()`` is being iterated.
            for k, v in item.items():
                if k not in self._excludes:
                    item[k] = np.array(v)[np.newaxis]
            # Without this ``yield`` the method is a plain function returning
            # None and the flow never produces an item.
            yield item
class KVStoreDataFlow(SimpleDataFlowBase):
    """Flow over every value of a key-value store, in key-listing order.

    Args:
        kv_getter: zero-argument callable returning the store. The store must
            provide ``keys()`` and ``get(key)``. It is only called from
            ``_initialize``, not from the constructor.
    """

    def __init__(self, kv_getter):
        self._kv_getter = kv_getter
        self._keys = None
        self._kvstore = None

    def _initialize(self):
        """Open the store through ``kv_getter`` and take a list of its keys."""
        super()._initialize()
        self._kvstore = self._kv_getter()
        key_list = list(self._kvstore.keys())
        self._keys = key_list

    def _gen(self):
        """Yield ``kvstore.get(key)`` for each recorded key."""
        for key in self._keys:
            value = self._kvstore.get(key)
            yield value
class KVStoreRandomSampleDataFlow(RandomizedDataFlowBase):
    """Endless flow of values drawn at random from a key-value store.

    Args:
        kv_getter: zero-argument callable returning the store. The store must
            provide ``keys()`` and ``get(key)``. It is only called from
            ``_initialize``, not from the constructor.
        seed: forwarded to the base-class constructor.
    """

    def __init__(self, kv_getter, seed=None):
        super().__init__(seed=seed)
        self._kv_getter = kv_getter
        self._kvstore = None
        self._keys = None
        self._nr_keys = None

    def _initialize(self):
        """Open the store, list its keys and record how many there are."""
        super()._initialize()
        self._kvstore = self._kv_getter()
        self._keys = list(self._kvstore.keys())
        self._nr_keys = len(self._keys)

    def _gen(self):
        """Pick a key position with ``rng.choice`` and yield its value, forever."""
        while True:
            # ``self._rng`` is not created in this class; presumably the base
            # class provides it -- confirm against RandomizedDataFlowBase.
            pick = self._rng.choice(self._nr_keys)
            key = self._keys[pick]
            yield self._kvstore.get(key)
class PoolRandomSampleDataFlow(RandomizedDataFlowBase):
    """Endless flow that reshuffles a pool before each pass over it.

    Args:
        pool: sized, indexable collection. It is handed to
            ``rng.shuffle_list`` as-is, i.e. the caller's object is the one
            being shuffled.
        seed: forwarded to the base-class constructor.
    """

    # Class-level default, overwritten per instance in ``__init__``.
    _pool = None

    def __init__(self, pool, seed=None):
        super().__init__(seed=seed)
        self._pool = pool
        self._length = len(pool)

    def _gen(self):
        """Shuffle the pool, yield its first ``_length`` entries, and repeat."""
        while True:
            self._rng.shuffle_list(self._pool)
            yield from (self._pool[index] for index in range(self._length))
class LOARandomSampleDataFlow(RandomizedDataFlowBase):
    """Endless flow over a list of arrays, reshuffled before every pass.

    Args:
        loa: iterable ("list of arrays") of sized, indexable sequences; each
            one is passed to ``rng.shuffle`` directly.
        seed: forwarded to the base-class constructor.
    """

    # Class-level defaults, overwritten per instance by ``_set_loa``.
    _loa = None
    _length = None

    def __init__(self, loa, seed=None):
        super().__init__(seed=seed)
        self._set_loa(loa)

    def _set_loa(self, loa):
        """Store ``loa`` and record the length reported for its arrays."""
        self._loa = loa

        # Presumably the getter raises with this message when the lengths
        # differ -- confirm against jacinle.utils.argument.
        length_getter = UniqueValueGetter('LOARandomSampleDataFlow length consistency check failed')
        for array in self._loa:
            length_getter.set(len(array))
        self._length = length_getter.get()

    def _gen(self):
        """Shuffle all arrays from one RNG state, yield row by row, repeat."""
        while True:
            # The RNG is rewound to the same state before each shuffle, so
            # every array is shuffled from identical random draws --
            # presumably to keep rows aligned across arrays.
            snapshot = self._rng.get_state()
            for array in self._loa:
                self._rng.set_state(snapshot)
                self._rng.shuffle(array)

            for row in range(self._length):
                yield [array[row] for array in self._loa]
def DOARandomSampleDataFlow(doa, seed=None):
    """Build an endless, shuffled flow of dicts from a dict of arrays.

    Args:
        doa: mapping from key to an indexable array.
        seed: forwarded to ``LOARandomSampleDataFlow``.

    Returns:
        A ``DictDataFlowProxy`` wrapping a ``LOARandomSampleDataFlow`` over
        the mapping's values, keyed by the mapping's keys.
    """
    # Snapshot the keys. ``doa.keys()`` is a live view, whereas ``values`` is
    # a fixed list; keeping the view would let the two drift apart if ``doa``
    # is modified after the flow has been created.
    keys = list(doa.keys())
    values = [doa[k] for k in keys]
    return DictDataFlowProxy(keys, LOARandomSampleDataFlow(values, seed=seed))
class RandomRepeatDataFlow(RandomizedDataFlowBase):
    """Read a source in chunks and re-emit random items from each chunk.

    The source is consumed ``cache_size`` items at a time. For every chunk:

    * ``block=False``: each item is yielded once as it is read, then
      ``(nr_repeat - 1) * chunk_len`` items are drawn at random from the chunk.
    * ``block=True``: nothing is yielded while the chunk is being filled, then
      ``nr_repeat * chunk_len`` items are drawn at random from the chunk.

    When the source runs out, the partly filled last chunk is handled by the
    same rule and the flow then ends.

    Args:
        source: iterable providing the items.
        nr_repeat: repeat factor described above.
        cache_size: number of items per chunk.
        block: whether to withhold items until the chunk is full.
        seed: forwarded to the base-class constructor.
    """

    def __init__(self, source, nr_repeat, cache_size, block=False, seed=None):
        super().__init__(seed=seed)
        self._source = source
        self._nr_repeat = nr_repeat
        self._cache_size = cache_size
        self._block = block

    def _gen(self):
        """Yield items chunk by chunk as described in the class docstring."""
        it = iter(self._source)
        # In non-block mode every item has already been yielded once while
        # the chunk was being filled, so one repeat fewer remains.
        nr_repeat = self._nr_repeat if self._block else self._nr_repeat - 1

        exhausted = False
        while not exhausted:
            data = []
            for _ in range(self._cache_size):
                try:
                    d = next(it)
                except StopIteration:
                    # A StopIteration escaping a generator body is converted
                    # into RuntimeError (PEP 479); end the flow cleanly.
                    exhausted = True
                    break
                data.append(d)
                if not self._block:
                    yield d

            if not data:
                # Nothing to resample from (source exhausted on a chunk
                # boundary, or a non-positive cache size): stop.
                return

            # ``len(data)`` equals ``cache_size`` for every full chunk.
            for _ in range(nr_repeat * len(data)):
                idx = self._rng.randint(len(data))
                yield data[idx]