jactorch.data.collate.collate_v3

Classes

VarLengthCollateV3

Collate a batch of data from multiple workers.

VarLengthCollateV3Mode

Class VarLengthCollateV3

class VarLengthCollateV3

Bases: object

Collate a batch of data from multiple workers. It supports data of variable length. For example, a batch may contain sentences of different lengths that are to be processed by an LSTM model. Usually, we pad the shorter sentences so that all sentences have the same length, which allows them to be processed as a single batch.
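A minimal, framework-free sketch of the padding idea, assuming sentences are represented as lists of token ids and that 0 is the padding value (both assumptions for illustration). It pads every sequence to the batch maximum and also returns the original lengths, which sequence models such as LSTMs typically need in order to ignore padded positions. The helper `pad_batch` is hypothetical and not part of jactorch.

```python
from typing import List, Sequence, Tuple


def pad_batch(seqs: Sequence[Sequence[int]], pad_value: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Pad variable-length sequences to a common length (hypothetical helper).

    Returns the padded batch (one row per sequence) and the original lengths.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    # Append pad_value to each sequence until it reaches max_len.
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


batch, lengths = pad_batch([[5, 2, 9], [7, 1], [3, 3, 3, 3]])
print(batch)    # [[5, 2, 9, 0], [7, 1, 0, 0], [3, 3, 3, 3]]
print(lengths)  # [3, 2, 4]
```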

To support such batches, this module provides fine-grained control over how each input field is collated and supports multiple collate modes. It assumes that the input data is a dict. Example:

>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> collate_fn({
...     'sentence': [torch.rand(3), torch.rand(4)],
...     'image': [torch.rand(3, 16, 14), torch.rand(3, 8, 12)]
... })

It can be passed directly to a DataLoader as its collate_fn argument.

>>> from torch.utils.data.dataloader import DataLoader
>>> from torch.utils.data.dataset import Dataset
>>> dataset = Dataset()  # placeholder: substitute your own Dataset subclass
>>> collate_fn = VarLengthCollateV2({'sentence': 'pad', 'image': 'padimage'})
>>> dataloader = DataLoader(dataset, collate_fn=collate_fn)
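With automatic batching, `torch.utils.data.DataLoader` calls `collate_fn` with a list holding one sample per index drawn from the dataset. If each sample is a dict, a field-wise collate function first has to regroup the samples by key before it can apply a per-field mode. The sketch below shows only that regrouping step in plain Python; `group_by_field` is a hypothetical helper (not part of jactorch) and assumes every sample has the same keys.

```python
from typing import Any, Dict, List


def group_by_field(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Turn a list of per-sample dicts into a dict of per-field lists (hypothetical helper)."""
    if not samples:
        return {}
    keys = samples[0].keys()
    for s in samples:
        # Every sample must expose exactly the same fields.
        if s.keys() != keys:
            raise KeyError('all samples must share the same keys')
    return {k: [s[k] for s in samples] for k in keys}


samples = [
    {'sentence': [4, 8, 15], 'id': 'a'},
    {'sentence': [16, 23], 'id': 'b'},
]
print(group_by_field(samples))
# {'sentence': [[4, 8, 15], [16, 23]], 'id': ['a', 'b']}
```

The regrouped dict has the same shape as the argument in the first example above, one list of values per field.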

Here is the complete list of supported collate modes:

1. skip: the field is skipped and no collation is done. This is useful when you are passing meta information through to the model.
2. concat: assumes the data is one-dimensional. The data is concatenated along this dimension.
3. pad: assumes the data is one-dimensional. The data is padded to the same length (the maximum length in the batch) and stacked along a new dimension.
4. pad2d: similar to pad, but takes 2D inputs (h, w) and pads them along both dimensions.
5. padimage: similar to pad2d, except that it takes 3D inputs (d, h, w); the d dimension is not padded.
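To make the modes concrete, here is a minimal pure-Python sketch that applies one mode per field, modelling tensors as nested lists and assuming 0 as the padding value. All names (`collate_fields`, `collate_pad`, `MODES`, and so on) are hypothetical; the sketch only illustrates the shapes each mode produces, whereas jactorch itself operates on tensors, as in the example above.

```python
from typing import Any, Callable, Dict, List

# Hypothetical pure-Python stand-ins for the five collate modes. Tensors are
# modelled as nested lists, and 0 is assumed to be the padding value.


def pad_1d(x: List[Any], n: int) -> List[Any]:
    """Right-pad a 1-D list with zeros to length n."""
    return list(x) + [0] * (n - len(x))


def pad_2d(x: List[List[Any]], h: int, w: int) -> List[List[Any]]:
    """Pad every row to width w, then append zero rows up to height h."""
    rows = [pad_1d(r, w) for r in x]
    return rows + [[0] * w for _ in range(h - len(rows))]


def collate_skip(values: List[Any]) -> List[Any]:
    return values  # leave the field untouched (e.g. metadata)


def collate_concat(values: List[List[Any]]) -> List[Any]:
    return [v for x in values for v in x]  # join 1-D items end to end


def collate_pad(values: List[List[Any]]) -> List[List[Any]]:
    n = max(len(x) for x in values)
    return [pad_1d(x, n) for x in values]  # shape: (batch, max_len)


def collate_pad2d(values: List[List[List[Any]]]) -> List[List[List[Any]]]:
    h = max(len(x) for x in values)
    w = max(len(r) for x in values for r in x)
    return [pad_2d(x, h, w) for x in values]  # shape: (batch, max_h, max_w)


def collate_padimage(values: List[List[List[List[Any]]]]) -> List[List[List[List[Any]]]]:
    # Inputs are (d, h, w). Only h and w are padded; d is assumed to agree.
    h = max(len(ch) for x in values for ch in x)
    w = max(len(r) for x in values for ch in x for r in ch)
    return [[pad_2d(ch, h, w) for ch in x] for x in values]


MODES: Dict[str, Callable[[List[Any]], Any]] = {
    'skip': collate_skip,
    'concat': collate_concat,
    'pad': collate_pad,
    'pad2d': collate_pad2d,
    'padimage': collate_padimage,
}


def collate_fields(fields: Dict[str, List[Any]], layout: Dict[str, str]) -> Dict[str, Any]:
    """Apply the mode named in `layout` to each field's list of values."""
    return {k: MODES[layout[k]](v) for k, v in fields.items()}


out = collate_fields(
    {'sentence': [[1, 2, 3], [4]], 'meta': ['x', 'y'], 'grid': [[[1]], [[2, 3], [4, 5]]]},
    {'sentence': 'pad', 'meta': 'skip', 'grid': 'pad2d'},
)
print(out['sentence'])  # [[1, 2, 3], [4, 0, 0]]
print(out['grid'])      # [[[1, 0], [0, 0]], [[2, 3], [4, 5]]]
```

Keeping each mode as a separate function behind a name-to-function table makes the per-field layout a plain dict, which mirrors the `{'sentence': 'pad', 'image': 'padimage'}` style used in the examples above.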

__call__(batch, flatten_key=None, layout_spec=None)

Call self as a function.

__init__(layout, mode='collate', gather_device=None, gather_dim=0)
__new__(**kwargs)

Class VarLengthCollateV3Mode

class VarLengthCollateV3Mode

Bases: JacEnum

__new__(value)
classmethod assert_valid(value)

Assert that the value is a valid choice.

classmethod choice_names()

Return the list of names of all possible choices.

classmethod choice_objs()

Return the list of objects of all possible choices.

classmethod choice_values()

Return the list of values of all possible choices.

classmethod from_string(value)
Parameters:

value (str | JacEnum)

Return type:

JacEnum

classmethod is_valid(value)

Check if the value is a valid choice.

classmethod type_name()

Return the type name of the enum.

COLLATE = 'collate'
GATHER = 'gather'
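The JacEnum helpers listed above follow a common pattern for string-valued enums. The sketch below approximates them with the standard library `enum` module, assuming that `from_string` accepts either a member or its string value and rejects anything else. `CollateModeSketch` is a hypothetical class with the same two members as VarLengthCollateV3Mode; JacEnum's real behaviour (for example, case handling or the error type raised) may differ.

```python
import enum
from typing import List, Union


class CollateModeSketch(enum.Enum):
    """Hypothetical stand-in with the same members as VarLengthCollateV3Mode."""

    COLLATE = 'collate'
    GATHER = 'gather'

    @classmethod
    def choice_names(cls) -> List[str]:
        return [m.name for m in cls]

    @classmethod
    def choice_values(cls) -> List[str]:
        return [m.value for m in cls]

    @classmethod
    def choice_objs(cls) -> List['CollateModeSketch']:
        return list(cls)

    @classmethod
    def is_valid(cls, value: Union[str, 'CollateModeSketch']) -> bool:
        return isinstance(value, cls) or value in cls.choice_values()

    @classmethod
    def assert_valid(cls, value: Union[str, 'CollateModeSketch']) -> None:
        # Raise instead of using `assert` so the check survives `python -O`.
        if not cls.is_valid(value):
            raise ValueError(f'Invalid {cls.__name__}: {value!r}; expected one of {cls.choice_values()}')

    @classmethod
    def from_string(cls, value: Union[str, 'CollateModeSketch']) -> 'CollateModeSketch':
        # Members pass through unchanged; strings are looked up by value.
        if isinstance(value, cls):
            return value
        cls.assert_valid(value)
        return cls(value)


mode = CollateModeSketch.from_string('gather')
print(mode.name)                              # GATHER
print(CollateModeSketch.choice_values())      # ['collate', 'gather']
print(CollateModeSketch.is_valid('scatter'))  # False
```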