jactorch.data.collate.collate_v2#
Classes
VarLengthCollateV2 | Collate a batch of data from multiple workers.
Class VarLengthCollateV2
- class VarLengthCollateV2[source]#
Bases:
object
Collate a batch of data from multiple workers. It supports data of varying lengths. For example, a batch may contain sentences of different lengths to be processed using LSTM models. Usually, we choose to pad the shorter sentences so that all sentences in the batch have the same length and can be processed together.
To achieve this, this module provides fine-grained collation control over each input field and supports multiple ways of collating the data. It assumes that the input data is a dict. Example:
>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> collate_fn({
...     'sentence': [torch.rand(3), torch.rand(4)],
...     'image': [torch.rand(3, 16, 14), torch.rand(3, 8, 12)]
... })
It can be passed directly to the DataLoader as the collate_fn parameter.
>>> from torch.utils.data.dataloader import DataLoader
>>> from torch.utils.data.dataset import Dataset
>>> dataset = Dataset()
>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> dataloader = DataLoader(dataset, collate_fn=collate_fn)
Here is a complete list of the supported collate modes:
1. skip: the field will be skipped; no collation will be done. This is useful when you are transmitting meta information to the model.
2. concat: assumes the data is one-dimensional. The data will be concatenated along this dimension.
3. pad: assumes the data is one-dimensional. The data will be padded to the same length (the maximum length of all data) and concatenated along a new dimension.
4. pad2d: similar to the pad mode, but it takes 2D inputs (h, w) and pads both dimensions.
5. padimage: similar to pad2d, except that it takes 3D inputs (d, h, w), where the d dimension is not padded.
6. stack: the default mode. It assumes the data is a list of tensors. The data will be stacked into a tensor of shape (batch_size, ...).
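The pad mode described above can be sketched in plain Python. The `pad_collate` helper below is a hypothetical illustration of the semantics only, not the library's implementation, which operates on tensors:

```python
def pad_collate(sequences, pad_value=0):
    """Pad 1-D sequences to the maximum length, then stack them
    along a new batch dimension (here: a list of equal-length lists)."""
    max_len = max(len(s) for s in sequences)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in sequences]

# Two "sentences" of lengths 3 and 2; the shorter one is zero-padded to 3.
batch = pad_collate([[1, 2, 3], [4, 5]])
```

This mirrors what the pad mode does with tensors: the resulting batch has a uniform trailing dimension equal to the longest input, so it can be fed to an LSTM in one call.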
- __new__(**kwargs)#