Source code for jactorch.data.dataloader.dataloader_torch030

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dataloader_torch030.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/01/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import threading
import queue
import multiprocessing

try:
    from torch.utils.data.dataloader import _worker_loop, _pin_memory_loop, default_collate, DataLoader, DataLoaderIter
except ImportError:
    DataLoader = object
    DataLoaderIter = object
    _worker_loop = object
    _pin_memory_loop = object
    default_collate = object

from jacinle.random import reset_global_seed, gen_rng

__all__ = ['JacDataLoader', 'JacDataLoaderIter']


def _worker_loop_seed(
        worker_id, dataset, index_queue, data_queue, collate_fn,
        seed, worker_init_fn, worker_init_args, worker_init_kwargs):
    """Entry point of a worker process: seed, run the init hook, then load.

    Args:
        worker_id: index of this worker; forwarded to ``worker_init_fn`` as
            its first positional argument.
        dataset, index_queue, data_queue, collate_fn: passed through
            unchanged to ``_worker_loop``.
        seed: value handed to ``reset_global_seed``.
        worker_init_fn: optional callable; skipped when it is ``None``.
        worker_init_args: positional arguments unpacked into the hook call.
        worker_init_kwargs: keyword arguments unpacked into the hook call.
    """
    # Seeding happens before the hook and before the loading loop start.
    reset_global_seed(seed)

    hook = worker_init_fn
    if hook is not None:
        hook(worker_id, *worker_init_args, **worker_init_kwargs)

    _worker_loop(dataset, index_queue, data_queue, collate_fn)


class JacDataLoaderIter(DataLoaderIter):
    """Batch iterator for a loader that seeds its workers individually.

    The constructor does not call the parent ``__init__``. It sets up the
    iterator state itself so that each worker process runs
    :func:`_worker_loop_seed` (seed reset plus optional init hook) rather than
    being started on ``_worker_loop`` directly.

    NOTE(review): the attribute names set here are presumably the ones the
    inherited ``DataLoaderIter`` methods rely on -- confirm against the torch
    version this module targets before renaming any of them.
    """

    def __init__(self, loader):
        """Copy the loader's settings and start workers when requested.

        Args:
            loader: the owning data loader. This constructor reads its
                ``dataset``, ``collate_fn``, ``batch_sampler``,
                ``num_workers``, ``pin_memory``, ``worker_init_fn``,
                ``worker_init_args`` and ``worker_init_kwargs`` attributes,
                and calls ``loader.gen_seeds()`` when ``num_workers > 0``.
        """
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()
        self.worker_init_fn = loader.worker_init_fn
        self.worker_init_args = loader.worker_init_args
        self.worker_init_kwargs = loader.worker_init_kwargs

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # One seed per worker; worker ``i`` also receives the i-th entry
            # of the per-worker args / kwargs sequences.
            self.seeds = loader.gen_seeds()
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop_seed,
                    args=(
                        i, self.dataset, self.index_queue, self.data_queue, self.collate_fn,
                        self.seeds[i], self.worker_init_fn,
                        self.worker_init_args[i], self.worker_init_kwargs[i],
                    ))
                for i in range(self.num_workers)
            ]

            for worker in self.workers:
                worker.daemon = True  # ensure that the worker exits on process exit
                worker.start()

            if self.pin_memory:
                # Workers keep writing to the original queue; a daemon thread
                # running ``_pin_memory_loop`` reads from it and feeds the
                # in-process queue that replaces ``self.data_queue``.
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            # (``_put_indices`` is not defined in this class; presumably
            # inherited from ``DataLoaderIter``.)
            for _ in range(2 * self.num_workers):
                self._put_indices()
        elif self.worker_init_fn is not None:
            # No worker processes: run the hook once in the current process,
            # with -1 as the worker id and the args / kwargs unpacked as-is.
            self.worker_init_fn(-1, *self.worker_init_args, **self.worker_init_kwargs)
class JacDataLoader(DataLoader):
    """Data loader with per-worker seeds and an optional worker init hook.

    In addition to the arguments forwarded to the parent ``DataLoader``, the
    constructor accepts ``seed``, ``worker_init_fn``, ``worker_init_args`` and
    ``worker_init_kwargs``. Iterating the loader yields a
    :class:`JacDataLoaderIter`.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 seed=None, worker_init_fn=None, worker_init_args=None, worker_init_kwargs=None):
        """Initialise the parent loader and record the worker-init settings.

        Args:
            dataset, batch_size, shuffle, sampler, batch_sampler, num_workers,
            collate_fn, pin_memory, drop_last: forwarded unchanged to
                ``DataLoader.__init__``.
            seed: passed to ``gen_rng`` to build ``self.seed_generator``; only
                used when ``num_workers > 0``.
            worker_init_fn: optional hook stored on the loader for the
                iterator to call.
            worker_init_args: when ``num_workers > 0``, a sequence indexed by
                worker id (defaults to one empty tuple per worker); otherwise
                a single tuple of positional arguments (defaults to ``()``).
            worker_init_kwargs: when ``num_workers > 0``, a sequence indexed
                by worker id (defaults to one empty dict per worker);
                otherwise a single dict of keyword arguments (defaults to
                ``{}``).
        """
        super().__init__(
            dataset,
            batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
            num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last)

        self.worker_init_fn = worker_init_fn

        if num_workers > 0:
            # The seed generator only exists in the multi-worker case;
            # ``gen_seeds`` guards against using it otherwise.
            self.seed_generator = gen_rng(seed)
            if worker_init_args is None:
                worker_init_args = [tuple() for _ in range(num_workers)]
            if worker_init_kwargs is None:
                worker_init_kwargs = [{} for _ in range(num_workers)]
        else:
            if worker_init_args is None:
                worker_init_args = tuple()
            if worker_init_kwargs is None:
                worker_init_kwargs = {}

        self.worker_init_args = worker_init_args
        self.worker_init_kwargs = worker_init_kwargs

    def __iter__(self):
        """Return a new :class:`JacDataLoaderIter` bound to this loader."""
        return JacDataLoaderIter(self)

    def gen_seeds(self):
        """Return a list with one freshly drawn seed per worker.

        Calls ``self.seed_generator.randint(4294967296, size=num_workers)``
        and converts the result with ``.tolist()``.
        NOTE(review): ``seed_generator`` comes from ``jacinle.random.gen_rng``
        and is presumably a numpy ``RandomState``, which would put the seeds
        in ``[0, 2**32)`` -- confirm.

        Raises:
            AssertionError: if ``num_workers`` is not positive.
        """
        assert self.num_workers > 0, 'gen_seeds() requires num_workers > 0.'
        return self.seed_generator.randint(4294967296, size=self.num_workers).tolist()