Source code for jactorch.data.collate.collate_v1

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : collate_v1.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 09/21/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

"""
Functions for collating data generated by multiple workers.
It supports collating data with variable lengths.
"""

import re
import collections
import collections.abc

import torch

from jacinle.utils.argument import UniqueValueGetter
from .utils import use_shared_memory, numpy_type_map, VarLengthCollateMode

__all__ = ['VarLengthCollateV1']


class VarLengthCollateV1(object):
    def __init__(self, fields, mode='pad'):
        """
        VarLengthCollateV1 is deprecated in favor of VarLengthCollateV2.

        Args:
            fields: the fields to be collated with variable length.
            mode: the collate mode.
        """
        self._fields = fields
        self._mode = VarLengthCollateMode.from_string(mode)

    def __call__(self, batch, process=False):
        error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
        elem_type = type(batch[0])

        if process:
            assert torch.is_tensor(batch[0]) or (
                elem_type.__module__ == 'numpy' and
                elem_type.__name__ != 'str_' and elem_type.__name__ != 'string_'
            )

        if torch.is_tensor(batch[0]):
            return self._stack(batch, process)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            elem = batch[0]
            if elem_type.__name__ == 'ndarray':
                # Arrays of strings and objects cannot be collated into tensors.
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(error_msg.format(elem.dtype))
                return self._stack([torch.from_numpy(b) for b in batch], process)
            if elem.shape == ():  # scalars
                py_type = float if elem.dtype.name.startswith('float') else int
                return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        elif isinstance(batch[0], int):
            return torch.LongTensor(batch)
        elif isinstance(batch[0], float):
            return torch.DoubleTensor(batch)
        elif isinstance(batch[0], (str, bytes)):
            return batch
        elif isinstance(batch[0], collections.abc.Mapping):
            result = {}
            for key in batch[0]:
                values = [d[key] for d in batch]
                if key in self._fields:
                    values, lengths = self(values, True)
                    result[key] = values
                    result[key + '_length'] = lengths
                else:
                    result[key] = self(values)
            return result
        elif isinstance(batch[0], collections.abc.Sequence):
            transposed = zip(*batch)
            return [self(samples) for samples in transposed]

        raise TypeError(error_msg.format(type(batch[0])))

    def _stack(self, values, process):
        uvg = UniqueValueGetter('Tensor sizes should match except the first dim.')
        for v in values:
            uvg.set(v.size()[1:])
        uvg.get()

        out = None
        if use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = 0
            if process:
                if self._mode is VarLengthCollateMode.PAD:
                    numel = max([x.numel() for x in values]) * len(values)
                elif self._mode is VarLengthCollateMode.CONCAT:
                    numel = sum([x.numel() for x in values])
                elif self._mode is VarLengthCollateMode.PAD2D:
                    max_h = max([x.size(0) for x in values])
                    max_w = max([x.size(1) for x in values])
                    hw = max_h * max_w
                    numel = sum([x.numel() // x.size(0) // x.size(1) * hw for x in values])
            else:
                numel = sum([x.numel() for x in values])
            if numel > 0:
                storage = values[0].storage()._new_shared(numel)
                out = values[0].new(storage)

        if not process:
            return torch.stack(values, 0, out=out)

        lengths = [v.size(0) for v in values]
        if self._mode is VarLengthCollateMode.CONCAT:
            return torch.cat(values, 0, out=out), torch.LongTensor(lengths)
        elif self._mode is VarLengthCollateMode.PAD:
            max_length = max(lengths)
            result = []
            for v in values:
                if v.size(0) < max_length:
                    v = torch.cat([v, v.new(*((max_length - v.size(0), ) + v.size()[1:])).zero_()], dim=0)
                result.append(v)
            return torch.stack(result, 0, out=out), torch.LongTensor(lengths)
        else:
            raise ValueError('Unknown collation mode: {}.'.format(self._mode))
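
The sketch below shows one plausible way to wire this collator into a PyTorch DataLoader. It is not part of the module: ToyDataset and the 'tokens'/'label' field names are made up for illustration. Keys listed in fields are padded along the first dimension (in 'pad' mode) and returned alongside a companion '<key>_length' tensor holding the original lengths; every other key goes through the usual stacking path.

# Minimal usage sketch, assuming a hypothetical ToyDataset that yields
# dicts with a variable-length 'tokens' tensor and an integer 'label'.
import torch
from torch.utils.data import Dataset, DataLoader

from jactorch.data.collate.collate_v1 import VarLengthCollateV1


class ToyDataset(Dataset):
    """Hypothetical dataset returning variable-length 1-D token tensors."""
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, index):
        n = self.lengths[index]
        return {'tokens': torch.arange(n), 'label': index % 2}


if __name__ == '__main__':
    loader = DataLoader(
        ToyDataset([3, 5, 2, 4]),
        batch_size=4,
        # 'tokens' is declared as a variable-length field; 'pad' mode
        # zero-pads each sequence to the longest one in the batch.
        collate_fn=VarLengthCollateV1(fields={'tokens'}, mode='pad'),
    )
    batch = next(iter(loader))
    # batch['tokens'] has shape (4, 5); batch['tokens_length'] holds the
    # original lengths [3, 5, 2, 4]; batch['label'] is a LongTensor of size 4.
    print(batch['tokens'].shape, batch['tokens_length'], batch['label'])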