#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : collate_v2.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/09/2020
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import re
import collections.abc
import torch
from jacinle.utils.argument import UniqueValueGetter
from .utils import use_shared_memory, numpy_type_map, VarLengthCollateMode
__all__ = ['VarLengthCollateV2']
class VarLengthCollateV2(object):
"""
Collate a batch of data from multiple workers.

    It supports data of variable length. For example, a batch may contain sentences of different lengths to be
    processed using LSTM models. Usually, we choose to pad the shorter sentences so that all sentences in the
    batch have the same length and can be processed together.

    To achieve this, this module provides fine-grained collate control over each input field and supports multiple
    ways of collating the data. It assumes that each item in the batch is a dict. Example:
>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> collate_fn({
>>> 'sentence': [torch.rand(3), torch.rand(4)],
>>> 'image': [torch.rand(3, 16, 14), torch.rand(3, 8, 12)]
>>> })
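
    The call above returns a dict with the collated tensors, plus a ``<key>_length`` entry for every
    variable-length field. For the example above, the result looks like (shapes only)::

        {
            'sentence': ...,         # shape (2, 4); the shorter sentence padded with 0
            'sentence_length': ...,  # LongTensor([3, 4])
            'image': ...,            # shape (2, 3, 16, 14)
            'image_length': ...,     # LongTensor([[16, 14], [8, 12]])
        }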

    It can be directly passed to a DataLoader as the parameter `collate_fn`.
>>> from torch.utils.data.dataloader import DataLoader
>>> from torch.utils.data.dataset import Dataset
>>> dataset = Dataset()
>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> dataloader = DataLoader(dataset, collate_fn=collate_fn)

    Here is a complete list of the supported collate modes:

    1. skip: the field will be skipped; no collation will be done. This is useful when you are transmitting
       meta information to the model.
    2. concat: the tensors will be concatenated along their first dimension. Their sizes must match except for
       the first dimension.
    3. pad: the tensors will be padded along their first dimension to the same length (the maximum length of all
       data) and stacked along a new batch dimension.
    4. pad2d: similar to pad, except that it takes inputs of shape (h, w, ...) and pads the first two dimensions.
    5. padimage: similar to pad2d, except that it takes 3d inputs (d, h, w), where the d dimension will not be
       padded.
    6. stack: the default mode. It assumes the data is a list of tensors of the same shape, which will be stacked
       into a tensor of shape (batch_size, ...).
"""
def __init__(self, fields):
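        """
        Args:
            fields (dict): a mapping from field name to its collate mode. Each mode is either a string
                (e.g., 'pad') or a tuple of (mode, *parameters), e.g., ('pad', -1).
        """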
self._fields = fields
def __call__(self, batch, key=None):
error_msg = "Batch must contain tensors, numbers, dicts or lists; found {}."
elem_type = type(batch[0])
if key is not None:
assert torch.is_tensor(batch[0]) or (elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_'
and elem_type.__name__ != 'string_'), 'Invalid field: {}.'.format(key)
if torch.is_tensor(batch[0]):
return self._stack(batch, key)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
                # arrays of string or object dtype are not supported
if re.search('[SaUO]', elem.dtype.str) is not None:
raise TypeError(error_msg.format(elem.dtype))
return self._stack([torch.from_numpy(b) for b in batch], key)
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
elif isinstance(batch[0], int):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], (str, bytes)):
return batch
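        # A dict input: collate each field according to its configured mode. Variable-length fields
        # additionally produce a '<key>_length' entry recording the original sizes.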
        elif isinstance(batch[0], collections.abc.Mapping):
            result = {}
            for k in batch[0]:
                values = [d[k] for d in batch]
                if k in self._fields:
                    spec = self._fields[k]
                    mode = VarLengthCollateMode.from_string(spec) if isinstance(spec, (str, bytes)) else None
                    if mode is VarLengthCollateMode.SKIP:
                        result[k] = values
                    elif mode is VarLengthCollateMode.STACK:
                        result[k] = self(values)
                    else:
                        values, lengths = self(values, key=k)
                        result[k] = values
                        result[k + '_length'] = lengths
                else:
                    result[k] = self(values)
            return result
elif isinstance(batch[0], collections.abc.Sequence):
transposed = zip(*batch)
return [self(samples) for samples in transposed]
        raise TypeError(error_msg.format(elem_type))
def _stack(self, values, key=None):
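        """Stack or concatenate a list of tensors for a single field, following the collate mode configured
        for `key`. For variable-length modes (concat/pad/pad2d/padimage), a tuple of (collated tensor,
        LongTensor of the original lengths) is returned; when `key` is None, the tensors are simply stacked.
        """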
mode, parameters = None, None
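        # A field spec is either a plain string naming the mode (e.g., 'pad') or a tuple whose first element
        # is the mode and whose remaining elements are mode parameters, e.g., ('pad', -1) to pad with -1.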
if key is not None:
mode_spec = self._fields[key]
if isinstance(mode_spec, tuple):
mode = VarLengthCollateMode.from_string(mode_spec[0])
parameters = mode_spec[1:]
else:
mode = VarLengthCollateMode.from_string(mode_spec)
parameters = tuple()
if use_shared_memory():
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = 0
if key is not None:
if mode is VarLengthCollateMode.PAD:
numel = max([x.numel() for x in values]) * len(values)
elif mode is VarLengthCollateMode.CONCAT:
numel = sum([x.numel() for x in values])
elif mode is VarLengthCollateMode.PAD2D:
max_h = max([x.size(0) for x in values])
max_w = max([x.size(1) for x in values])
hw = max_h * max_w
numel = sum([x.numel() // x.size(0) // x.size(1) * hw for x in values])
else:
numel = sum([x.numel() for x in values])
if numel > 0:
pass
# TODO(Jiayuan Mao @ 2023/03/18): add back out=out optimization.
# storage = values[0].storage()._new_shared(numel)
# out = values[0].new(storage)
if key is None:
return torch.stack(values, 0)
if mode is VarLengthCollateMode.CONCAT:
uvg = UniqueValueGetter('Tensor sizes should match except the first dim.')
for v in values:
uvg.set(v.size()[1:])
lengths = [v.size(0) for v in values]
return torch.cat(values, 0), torch.LongTensor(lengths)
elif mode is VarLengthCollateMode.PAD:
uvg = UniqueValueGetter('Tensor sizes should match except the first dim.')
for v in values:
uvg.set(v.size()[1:])
pad_value = parameters[0] if len(parameters) > 0 else 0
lengths = [v.size(0) for v in values]
max_length = max(lengths)
result = []
for v in values:
if v.size(0) < max_length:
v = torch.cat([v, v.new(*((max_length - v.size(0), ) + v.size()[1:])).fill_(pad_value)], dim=0)
result.append(v)
return torch.stack(result, 0), torch.LongTensor(lengths)
elif mode is VarLengthCollateMode.PAD2D:
uvg = UniqueValueGetter('Tensor sizes should match except the first 2 dims.')
for v in values:
uvg.set(v.size()[2:])
rest_size = uvg.get() or []
pad_value = parameters[0] if len(parameters) > 0 else 0
lengths = [v.size()[:2] for v in values]
max_h, max_w = max([x[0] for x in lengths]), max([x[1] for x in lengths])
result = []
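            # Copy each tensor into the top-left corner of a (max_h, max_w, ...) canvas pre-filled
            # with pad_value.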
for v in values:
u = v.new(*(max_h, max_w, *rest_size)).fill_(pad_value)
u[:v.size(0), :v.size(1)] = v
result.append(u)
return torch.stack(result, 0), torch.LongTensor(lengths)
elif mode is VarLengthCollateMode.PADIMAGE:
uvg = UniqueValueGetter('Tensor sizes should match except the last 2 dims.')
for v in values:
                assert v.dim() == 3, 'Only 3-dimensional inputs are supported.'
uvg.set(v.size(0))
pad_value = parameters[0] if len(parameters) > 0 else 0
lengths = [v.size()[-2:] for v in values]
max_h, max_w = max([x[0] for x in lengths]), max([x[1] for x in lengths])
result = []
for v in values:
u = v.new(*(uvg.get(), max_h, max_w)).fill_(pad_value)
# TODO(Jiayuan Mao @ 07/19): support input with dim > 3.
u[:, :v.size(1), :v.size(2)] = v
result.append(u)
return torch.stack(result, 0), torch.LongTensor(lengths)
else:
raise ValueError('Unknown collation mode: {}.'.format(mode))
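

if __name__ == '__main__':
    # A minimal smoke test; a sketch, not part of the library API. It collates a fake batch of two samples
    # with a variable-length 'sentence' field, a variable-size 'image' field, and a scalar 'label' field.
    collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
    batch = [
        {'sentence': torch.rand(3), 'image': torch.rand(3, 16, 14), 'label': 0},
        {'sentence': torch.rand(4), 'image': torch.rand(3, 8, 12), 'label': 1},
    ]
    feed_dict = collate_fn(batch)
    print(feed_dict['sentence'].shape)   # torch.Size([2, 4]); padded with 0
    print(feed_dict['sentence_length'])  # tensor([3, 4])
    print(feed_dict['image'].shape)      # torch.Size([2, 3, 16, 14])
    print(feed_dict['image_length'])     # tensor([[16, 14], [ 8, 12]])
    print(feed_dict['label'])            # tensor([0, 1])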