Source code for jaclearn.dataflow.dataflow
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : dataflow.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import collections
import collections.abc

from jacinle.logging import get_logger
from jacinle.random import gen_rng
# Module-level logger, obtained from jacinle's get_logger using this file's path.
logger = get_logger(__file__)

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'DataFlowBase',
    'SimpleDataFlowBase',
    'ProxyDataFlowBase',
    'AdvancedDataFlowBase',
    'RandomizedDataFlowBase',
]
[docs]
# ``collections.abc`` must be imported explicitly at module level: a bare
# ``import collections`` does not guarantee the ``abc`` submodule attribute is bound.
@collections.abc.Iterator.register
class DataFlowBase(object):
    """Root marker class for every dataflow in this module.

    The class is registered as a virtual subclass of
    ``collections.abc.Iterator``, so ``isinstance(flow, Iterator)`` (and
    ``Iterable``) holds for every dataflow. ``ABCMeta.register`` returns the
    class it registers, which is what allows it to be used as a decorator here.

    NOTE(review): only ``AdvancedDataFlowBase`` defines ``__next__``. The
    ``SimpleDataFlowBase`` family are iterables whose ``__iter__`` returns a
    fresh generator, so the ``Iterator`` registration is looser than the
    protocol. It is kept unchanged because callers may rely on it.
    """
[docs]
class SimpleDataFlowBase(DataFlowBase):
    """Generator-based dataflow driven by overridable hook methods.

    Subclasses must implement ``_gen`` and may override ``_initialize``,
    ``_reset``, ``_finalize`` and ``_len``. ``__iter__`` is a generator
    function, so none of its work happens until the returned generator is
    first advanced. A pass then:

    1. runs ``_initialize`` once per instance (on the first pass only),
    2. runs ``_reset``,
    3. yields every item produced by ``_gen``,
    4. runs ``_finalize``, also when ``_gen`` fails or the generator is
       closed early.

    An ``Exception`` raised while generating is logged and ends the pass
    instead of propagating to the consumer.
    """

    # Name-mangled to ``_SimpleDataFlowBase__initialized``; set to True on the
    # instance once ``_initialize`` has completed.
    __initialized = False

    def _initialize(self):
        """Hook: one-time setup before the first pass. Default: no-op."""

    def _reset(self):
        """Hook: called at the start of every pass. Default: no-op."""

    def _gen(self):
        """Hook: produce the items of one pass. Must be overridden."""
        raise NotImplementedError()

    def _finalize(self):
        """Hook: called when a pass ends, however it ends. Default: no-op."""

    def _len(self):
        """Hook: number of items, or None when unknown (the default)."""
        return None

    def __len__(self):
        """Return ``self._len()``, or None if computing it raises ``TypeError``.

        NOTE(review): builtin ``len()`` raises ``TypeError`` itself when this
        method returns None.
        """
        try:
            length = self._len()
        except TypeError:
            length = None
        return length

    def __iter__(self):
        """Run one pass over the data produced by ``_gen``."""
        if not self.__initialized:
            self._initialize()
            self.__initialized = True
        self._reset()
        try:
            for item in self._gen():
                yield item
        except Exception as exc:
            # Best-effort: record the failure and end this pass quietly.
            message = '{} got exception {} during iter: {}.'.format(type(self), type(exc), exc)
            logger.exception(message)
        finally:
            self._finalize()
[docs]
class ProxyDataFlowBase(SimpleDataFlowBase):
    """Dataflow that forwards iteration and length to a wrapped object.

    Args:
        other: the wrapped iterable. It is iterated once per pass, and
            ``len(other)`` is reported as this dataflow's length.
    """

    def __init__(self, other):
        self._unwrapped = other

    @property
    def unwrapped(self):
        """The object passed to the constructor."""
        return self._unwrapped

    def _gen(self):
        source = self._unwrapped
        for element in source:
            yield element

    def _len(self):
        # When the wrapped object has no usable length, len() raises
        # TypeError, which SimpleDataFlowBase.__len__ turns into None.
        return len(self._unwrapped)
[docs]
class AdvancedDataFlowBase(DataFlowBase):
    """Cursor-style dataflow that acts as its own iterator.

    Subclasses implement a cursor protocol:

    * ``_initialize`` -- called by every ``iter()``. Afterwards ``_get`` must
      return the first item, because the first ``__next__`` of a pass calls
      ``_get`` without consulting ``_have_next``.
    * ``_get`` -- return the item under the cursor.
    * ``_have_next`` / ``_move_next`` -- test for and advance to the next item.
    * ``_count`` -- the value reported by ``len()``.
    * ``_finalize`` -- called once when a pass runs out of items (no-op by default).

    NOTE(review): because the first ``__next__`` calls ``_get``
    unconditionally, an empty source cannot be expressed through this
    protocol as written.
    """

    # True once StopIteration has been raised for the current pass. This keeps
    # ``_finalize`` from running more than once per pass, and stops later
    # ``next()`` calls from touching the finalized source. The class-level
    # default covers subclasses that do not call ``super().__init__()``.
    _exhausted = False

    def __init__(self):
        self._is_first_iter = True
        self._exhausted = False

    def __len__(self):
        return self._count()

    def __iter__(self):
        # Every iter() starts a fresh pass.
        self._initialize()
        self._is_first_iter = True
        self._exhausted = False
        return self

    def __next__(self):
        if self._exhausted:
            # Iterator protocol: keep raising StopIteration after exhaustion
            # without re-querying or re-finalizing the source.
            raise StopIteration()
        if self._is_first_iter:
            # The cursor already points at the first item after _initialize.
            self._is_first_iter = False
        elif self._have_next():
            self._move_next()
        else:
            self._exhausted = True
            self._finalize()
            raise StopIteration()
        return self._get()

    def _initialize(self):
        raise NotImplementedError()

    def _finalize(self):
        pass

    def _get(self):
        raise NotImplementedError()

    def _count(self):
        raise NotImplementedError()

    def _move_next(self):
        raise NotImplementedError()

    def _have_next(self):
        raise NotImplementedError()
[docs]
class RandomizedDataFlowBase(SimpleDataFlowBase):
    """Dataflow base class that owns a seeded generator from ``gen_rng``.

    ``_rng`` is None until ``_initialize`` runs. ``SimpleDataFlowBase`` calls
    that hook once, before the first pass. ``_rng`` then holds whatever
    ``jacinle.random.gen_rng(seed=...)`` returned for the stored seed.

    Args:
        seed: value forwarded to ``gen_rng`` as its ``seed`` argument
            (default None).
    """

    # Replaced in _initialize by the object returned from gen_rng.
    _rng = None

    def __init__(self, seed=None):
        self._seed = seed

    def _initialize(self):
        seed = self._seed
        self._rng = gen_rng(seed=seed)