# Source code for jaclearn.dataflow.batch
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batch.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
from copy import deepcopy
from threading import Thread, Event
import traceback
from jacinle.concurrency.event import MTBooleanEvent
from jacinle.logging import get_logger
from jacinle.utils.meta import gofor
from .dataflow import SimpleDataFlowBase
# Module-level logger, obtained from jacinle's get_logger using this file's path.
logger = get_logger(__file__)
# Names exported by `from ... import *`; batch_default_filler is not listed.
__all__ = ['BatchDataFlow', 'EpochDataFlow']
# [docs]
def batch_default_filler(buffer, idx, val):
    """Write one sample into position ``idx`` of a batch buffer.

    The ``(key, value)`` pairs produced by ``gofor(val)`` are visited in
    order; a value is stored at ``buffer[key][idx]`` only when ``key`` is
    already present in ``buffer``.  Other keys are ignored.

    Args:
        buffer: container supporting ``in`` and item access by key, whose
            entries support item assignment by ``idx``.
        idx: position of the sample inside the batch.
        val: the sample; anything ``gofor`` can iterate over.
    """
    for key, item in gofor(val):
        if key not in buffer:
            # Fields the buffer does not declare are skipped.
            continue
        buffer[key][idx] = item
# [docs]
class BatchDataFlow(SimpleDataFlowBase):
    """Assemble fixed-size batches from ``source`` with a background filler thread.

    Two deep copies of ``sample_dict`` are used as alternating buffers: while
    the consumer holds one, the filler thread writes the next batch into the
    other.  The object yielded by ``_gen`` is the buffer itself, not a copy,
    so it is handed back to the filler (and may be overwritten) as soon as the
    consumer asks for the next batch.

    A batch that could not be completed -- because the source ran out or the
    filler raised -- is never yielded.
    """

    _buffer = None
    _cond = None
    _filler_thread = None
    _stop_event = None
    # Index of the buffer whose fill was aborted; None while the filler is
    # still running normally.
    _stop_slot = None

    def __init__(self, source, batch_size, sample_dict, filler=batch_default_filler):
        """Store the configuration; buffers and the thread are created in ``_initialize``.

        Args:
            source: iterable providing individual samples.
            batch_size: number of samples written into each batch buffer.
            sample_dict: template that is deep-copied to create the buffers.
            filler: callable ``filler(buffer, idx, sample)`` that stores one
                sample at position ``idx`` of ``buffer``.
        """
        super().__init__()
        self._source = source
        self._batch_size = batch_size
        self._sample_dict = sample_dict
        self._filler = filler

    def _initialize(self):
        """Allocate the buffers, then start the filler thread."""
        self._initialize_buffer()
        self._initialize_filler()

    def _initialize_buffer(self):
        """Create the two independent batch buffers from the template."""
        self._buffer = [deepcopy(self._sample_dict) for _ in range(2)]

    def _initialize_filler(self):
        """Create the per-buffer flags and launch the daemon filler thread."""
        # One flag per buffer.  NOTE(review): presumably False means "free for
        # the filler" and True means "ready for the consumer", with a fresh
        # MTBooleanEvent starting out False -- confirm against MTBooleanEvent.
        self._cond = [MTBooleanEvent() for _ in range(2)]
        self._stop_event = Event()
        self._stop_slot = None
        self._filler_thread = Thread(target=self._filler_mainloop, name=str(self) + ':filler', daemon=True)
        self._filler_thread.start()

    def _filler_mainloop(self):
        """Body of the filler thread: fill the buffers alternately until the source ends or fails."""
        current = 0
        try:
            # Creating the iterator inside the ``try`` ensures the consumer is
            # still released if the source cannot be iterated at all.
            it = iter(self._source)
            while True:
                self._cond[current].wait_false()
                for i in range(self._batch_size):
                    self._filler(self._buffer[current], i, next(it))
                self._cond[current].set_true()
                current = 1 - current
        except StopIteration:
            # The source is exhausted: a normal end of data.  The batch that
            # was being assembled (if any) is incomplete and gets discarded.
            pass
        except Exception as e:
            logger.warning('{} got exception {} in filler thread: {}.'.format(type(self), type(e), e))
            traceback.print_exc()

        # Record WHICH buffer is unusable before waking the consumer.  A single
        # global stop flag is not enough: the other buffer may still hold a
        # complete batch that has not been consumed, and the consumer must not
        # observe this buffer as "ready" before the stop marker is visible.
        self._stop_slot = current
        self._stop_event.set()
        self._cond[current].set_true()

    def _gen(self):
        """Yield the filled buffers in order until the aborted buffer is reached."""
        current = 0
        while True:
            self._cond[current].wait_true()
            if self._stop_slot == current:
                # This buffer was signalled only to wake us up; it does not
                # contain a complete batch.
                return
            yield self._buffer[current]
            # The consumer is done with this buffer: hand it back to the filler.
            self._cond[current].set_false()
            current = 1 - current

    def _len(self):
        """Return the number of complete batches, or None if the source length is None."""
        length = len(self._source)
        return None if length is None else length // self._batch_size
# [docs]
class EpochDataFlow(SimpleDataFlowBase):
    """Serve at most ``epoch_size`` items from ``source`` per pass.

    A single iterator over ``source`` is created in ``_initialize`` and kept
    on the instance; ``_gen`` draws from that stored iterator instead of
    restarting the source.
    """

    def __init__(self, source, epoch_size):
        """Store the configuration; the source iterator is created in ``_initialize``.

        Args:
            source: iterable providing the items.
            epoch_size: maximum number of items yielded by one run of ``_gen``;
                also the value reported by ``_len``.
        """
        # Initialise the base class the same way BatchDataFlow does.
        super().__init__()
        self._source = source
        self._source_iter = None
        self._epoch_size = epoch_size

    def _initialize(self):
        """Create the iterator that ``_gen`` reads from."""
        self._source_iter = iter(self._source)

    def _gen(self):
        """Yield up to ``epoch_size`` items, stopping early if the source runs out."""
        for _ in range(self._epoch_size):
            try:
                yield next(self._source_iter)
            except StopIteration:
                # Finish the epoch early; a StopIteration must not escape
                # from inside a generator body.
                return

    def _len(self):
        """Return the configured epoch size."""
        return self._epoch_size