jactorch.parallel.replication_callback
Classes
ReplicationCallbackDataParallel | Data Parallel with a replication callback.
Functions
exec_data_parallel_replication_callback | Execute the replication callback __data_parallel_replicate__ on each module created by the original replication.
patch_replication_callback | Monkey-patch an existing DataParallel object.
Class CallbackContext
Class ReplicationCallbackDataParallel
- class ReplicationCallbackDataParallel
Bases: DataParallel
Data Parallel with a replication callback.
The replication callback __data_parallel_replicate__ of each module is invoked after the module has been created by the original replicate function. The callback is invoked as __data_parallel_replicate__(ctx, copy_id).
Examples
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
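As a rough illustration of the callback contract described above, the sketch below defines a plain Python class with a __data_parallel_replicate__(ctx, copy_id) method and calls it by hand for two copies. Only the callback name and its two arguments come from this page; the class ToySyncLayer, the attributes is_master and shared, and the use of SimpleNamespace as a context are assumptions made for the example, and no PyTorch or GPU is involved.

```python
from types import SimpleNamespace


class ToySyncLayer:
    """Hypothetical stand-in for a module that needs cross-replica state."""

    def __init__(self):
        self.is_master = True
        self.shared = None

    def __data_parallel_replicate__(self, ctx, copy_id):
        # Assumption: copy_id 0 denotes the master copy, which publishes
        # the shared state on the context object.
        self.is_master = (copy_id == 0)
        if self.is_master:
            ctx.shared = {"owner": copy_id, "messages": []}
        # Every copy (master included) keeps a handle to the same object.
        self.shared = ctx.shared


# Simulate by hand what a replication step would do for two device copies.
ctx = SimpleNamespace()  # assumed: one context shared by all copies of this layer
replicas = [ToySyncLayer(), ToySyncLayer()]
for copy_id, replica in enumerate(replicas):
    replica.__data_parallel_replicate__(ctx, copy_id)
```

Because the master copy runs first, the later copy can rely on ctx.shared already existing when its own callback is called.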
- __init__(module, device_ids=None, output_device=None, dim=0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(*inputs, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- gather(outputs, output_device)
- parallel_apply(replicas, inputs, kwargs)
Functions
- exec_data_parallel_replication_callback(modules)
Execute the replication callback __data_parallel_replicate__ on each module created by the original replication.
The callback is invoked as __data_parallel_replicate__(ctx, copy_id).
Note that, because all replicas are isomorphic, each sub-module is assigned a context that is shared among the copies of that sub-module on different devices. Through this context, the copies can share information.
The callback on the master copy (the first copy) is guaranteed to be called before the callback on any slave copy.
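The behavior described above (one shared context per sub-module, master copy first) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's implementation: ToyContext, ToyModule, run_replication_callbacks, and make_replica are hypothetical names, and ToyModule only mimics the modules() traversal of a real module.

```python
class ToyContext:
    """Hypothetical empty namespace; callbacks attach what they share."""


class ToyModule:
    """Hypothetical stand-in exposing a modules() traversal."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.log = None

    def modules(self):
        # Yield self first, then descendants, in a deterministic order.
        yield self
        for child in self.children:
            yield from child.modules()

    def __data_parallel_replicate__(self, ctx, copy_id):
        if copy_id == 0:
            ctx.calls = []  # the master copy initializes the shared record
        ctx.calls.append((self.name, copy_id))
        self.log = ctx.calls


def run_replication_callbacks(replicas):
    """Sketch: call the callback on every sub-module of every replica."""
    n_submodules = len(list(replicas[0].modules()))
    # One context per sub-module position, shared across all replicas.
    contexts = [ToyContext() for _ in range(n_submodules)]
    for copy_id, replica in enumerate(replicas):  # master (index 0) first
        for j, sub in enumerate(replica.modules()):
            if hasattr(sub, "__data_parallel_replicate__"):
                sub.__data_parallel_replicate__(contexts[j], copy_id)


def make_replica():
    # Every replica has the same structure, so positions line up.
    return ToyModule("root", [ToyModule("bn1"), ToyModule("bn2")])


replicas = [make_replica() for _ in range(3)]
run_replication_callbacks(replicas)
bn1_calls = list(replicas[2].children[0].log)
```

Indexing contexts by traversal position only works because every replica has the same structure, which is the point of the isomorphism remark above.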
- patch_replication_callback(data_parallel)
Monkey-patch an existing DataParallel object to add the replication callback. Useful when you have a customized DataParallel implementation.
Examples
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
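One way such a monkey-patch could work is to wrap the replicate method of a single instance so that a hook runs on the freshly created replicas. The sketch below shows that idea on a stand-in object. ToyDataParallel and patch_with_hook are hypothetical names introduced for illustration, and whether the library patches replicate in exactly this way is an assumption.

```python
import functools


class ToyDataParallel:
    """Hypothetical stand-in for a DataParallel-like wrapper."""

    def __init__(self, module, device_ids):
        self.module = module
        self.device_ids = list(device_ids)

    def replicate(self, module, device_ids):
        # Pretend each device gets its own (module, device) copy.
        return [(module, dev) for dev in device_ids]


def patch_with_hook(data_parallel, hook):
    """Sketch: wrap the instance's replicate() so hook runs on its result."""
    old_replicate = data_parallel.replicate  # bound method of this instance

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        replicas = old_replicate(module, device_ids)
        hook(replicas)  # e.g. a function that executes replication callbacks
        return replicas

    # Assigning on the instance shadows the class method for this object only.
    data_parallel.replicate = new_replicate
    return data_parallel


seen = []
dp = patch_with_hook(ToyDataParallel("net", [0, 1]), seen.append)
out = dp.replicate(dp.module, dp.device_ids)

# An unpatched instance of the same class is unaffected.
other = ToyDataParallel("net", [0, 1])
other.replicate(other.module, other.device_ids)
```

Patching the instance rather than the class leaves other wrappers untouched, which fits the case of a customized DataParallel subclass that cannot easily be swapped for another class.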