Source code for jaclearn.data.sampler

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : sampler.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

from jacinle.random.rng import gen_rng
from jacinle.nd import index_select_batch

__all__ = ['EpochBatchSampler', 'SimpleBatchSampler']


class _SizedGenerator(object):
    def __init__(self, generator, length):
        self._generator = generator
        self._length = length

    def __iter__(self):
        for i in self._generator:
            yield i

    def __len__(self):
        return self._length


class RenamedDictSamplerBase(object):
    """Base class for samplers that yield dict batches with optional re-keying.

    Subclasses implement :meth:`_gen`, which produces dict batches indexed by
    ``keys``. :meth:`_gen_renamed` wraps that stream and can expose each entry
    under a different output key.
    """

    def _gen_renamed(self, data, keys, renames=None):
        """Yield the batches of :meth:`_gen`, optionally under new key names.

        Args:
            data: forwarded unchanged to :meth:`_gen`.
            keys: the keys forwarded to :meth:`_gen`; each produced batch is
                indexed with them.
            renames: optional sequence with the same length as ``keys``.
                When given, the value stored under ``keys[i]`` is emitted
                under ``renames[i]``. When ``None``, batches are passed
                through untouched.

        Yields:
            One dict per batch produced by :meth:`_gen`.

        Raises:
            AssertionError: if ``renames`` is given and its length differs
                from ``len(keys)`` (raised when iteration starts).
        """
        if renames is None:
            # This function is a generator (it contains ``yield``), so a
            # plain ``return self._gen(...)`` would finish immediately and
            # produce no batches at all. The batches must be re-yielded.
            for batch in self._gen(data, keys):
                yield batch
            return

        assert len(renames) == len(keys)
        for batch in self._gen(data, keys):
            yield {new_key: batch[old_key] for new_key, old_key in zip(renames, keys)}

    def _gen(self, data, keys):
        """Produce the raw dict batches; must be overridden by subclasses."""
        raise NotImplementedError()


class EpochBatchSampler(RenamedDictSamplerBase):
    """Sampler that emits a fixed number of randomly indexed batches per call.

    Every batch is built from ``batch_size`` indices obtained from
    ``rng.randint(n, size=batch_size)``, where ``n`` is the length of the
    first selected field.
    NOTE(review): with a NumPy ``RandomState`` those indices are independent
    draws, so an index may repeat within a batch -- confirm what ``gen_rng``
    returns.
    """

    def __init__(self, batch_size, epoch_size, rng=None):
        """Store the sampling configuration.

        Args:
            batch_size: number of indices drawn for each batch.
            epoch_size: number of batches produced per call.
            rng: random generator providing ``randint(n, size=...)``. A falsy
                value is replaced by a fresh ``gen_rng()``.
        """
        self._rng = rng or gen_rng()
        self._batch_size = batch_size
        self._epoch_size = epoch_size

    def _gen(self, data, keys):
        """Yield ``epoch_size`` dict batches, each indexed by ``keys``."""
        # The dataset size is taken from the first requested field only.
        total = len(data[keys[0]])
        for _ in range(self._epoch_size):
            picked = self._rng.randint(total, size=self._batch_size)
            yield {key: index_select_batch(data[key], picked) for key in keys}

    def __call__(self, data, keys, renames=None):
        """Return a sized iterable over the sampled batches.

        Args:
            data: mapping from key to an indexable field.
            keys: the fields to include in every batch.
            renames: optional output names, forwarded to ``_gen_renamed``.

        Returns:
            A ``_SizedGenerator`` whose ``len()`` is ``epoch_size``.
        """
        batches = self._gen_renamed(data, keys, renames)
        return _SizedGenerator(batches, self._epoch_size)


class SimpleBatchSampler(RenamedDictSamplerBase):
    """Sampler that walks a fresh permutation of the data ``nr_repeat`` times.

    In each repeat the indices ``0 .. n-1`` are permuted and cut into
    consecutive slices of ``batch_size``. Only ``n // batch_size`` full
    batches are emitted per repeat; the remaining ``n % batch_size`` indices
    of that permutation are not used.
    """

    def __init__(self, batch_size, nr_repeat, rng=None):
        """Store the sampling configuration.

        Args:
            batch_size: number of indices in each batch.
            nr_repeat: how many permutations of the data are traversed.
            rng: random generator providing ``permutation(n)``. A falsy value
                is replaced by a fresh ``gen_rng()``.
        """
        self._rng = rng or gen_rng()
        self._batch_size = batch_size
        self._nr_repeat = nr_repeat

    def _gen(self, data, keys):
        """Yield the dict batches for all repeats, each indexed by ``keys``."""
        # The dataset size is taken from the first requested field only.
        total = len(data[keys[0]])
        for _ in range(self._nr_repeat):
            order = self._rng.permutation(total)
            for batch_no in range(total // self._batch_size):
                start = batch_no * self._batch_size
                chosen = order[start:start + self._batch_size]
                yield {key: index_select_batch(data[key], chosen) for key in keys}

    def _len(self, data, keys):
        """Return the number of batches ``_gen`` yields for ``data``."""
        per_repeat = len(data[keys[0]]) // self._batch_size
        return self._nr_repeat * per_repeat

    def __call__(self, data, keys, renames=None):
        """Return a sized iterable over the sampled batches.

        Args:
            data: mapping from key to an indexable field.
            keys: the fields to include in every batch.
            renames: optional output names, forwarded to ``_gen_renamed``.

        Returns:
            A ``_SizedGenerator`` whose ``len()`` is
            ``nr_repeat * (n // batch_size)``.
        """
        batches = self._gen_renamed(data, keys, renames)
        return _SizedGenerator(batches, self._len(data, keys))