Source code for jaclearn.dataflow.utils
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : utils.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import queue
import threading
from itertools import repeat, cycle as cached_cycle
from itertools import takewhile, dropwhile, filterfalse
from itertools import chain
from itertools import starmap
from itertools import islice
from itertools import tee
from jacinle.utils.meta import map_exec
from .dataflow import SimpleDataFlowBase, ProxyDataFlowBase
__all__ = [
'cycle', 'cycle_n', 'cached_cycle', 'repeat', 'repeat_n',
'chain',
'takewhile', 'dropwhile',
'filter', 'filtertrue', 'filterfalse',
'map', 'starmap', 'ssmap',
'islice', 'truncate',
'tee',
'MapDataFlow', 'DataFlowMixer'
]
map = map
filter = filter
filtertrue = filter
repeat_n = repeat
truncate = islice
# implement cycle self, without any cache
[docs]
def cycle(iterable, times=None):
if times is None:
while True:
for v in iterable:
yield v
else:
for i in range(times):
for v in iterable:
yield v
cycle_n = cycle
[docs]
def ssmap(function, iterable):
for args in iterable:
yield function(**args)
[docs]
class MapDataFlow(ProxyDataFlowBase):
[docs]
def __init__(self, other, map_func=None):
super().__init__(other)
self.__map_func = map_func
def _map(self, data):
return self.__map_func(data)
def _gen(self):
for data in self.unwrapped:
yield self._map(data)
[docs]
class DataFlowMixer(SimpleDataFlowBase):
[docs]
def __init__(self, dataflows, buflen=None):
if buflen is None:
buflen = len(dataflows)
self._dataflows = dataflows
self._queue = queue.Queue(maxsize=buflen)
self._stop_signal = threading.Event()
self._comsumers = []
def _initialize(self):
self._consumers = [
threading.Thread(target=self._consumer, args=(ind, df), daemon=True)
for ind, df in enumerate(self._dataflows)]
map_exec(threading.Thread.start, self._consumers)
def _finalize(self):
self._stop_signal.set()
while True:
try:
self._queue.get_nowait()
except queue.Empty:
break
map_exec(threading.Thread.join, self._consumers)
def _consumer(self, ind, df):
for data in df:
self._queue.put(self._wrapper(data, ind))
if self._stop_signal.is_set():
break
def _gen(self):
while True:
yield self._queue.get()
def _wrapper(self, data, ind):
return data, ind