jactorch.parallel.replication_callback

Classes

CallbackContext

ReplicationCallbackDataParallel

Data Parallel with a replication callback.

Functions

exec_data_parallel_replication_callback(modules)

Execute a replication callback, __data_parallel_replicate__, on each module created by the original replication.

patch_replication_callback(data_parallel)

Monkey-patch an existing DataParallel object.

Class CallbackContext

class CallbackContext

Bases: object

__init__()
__new__(**kwargs)

Class ReplicationCallbackDataParallel

class ReplicationCallbackDataParallel

Bases: DataParallel

Data Parallel with a replication callback.

A replication callback, __data_parallel_replicate__, of each module is invoked after the module is created by the original replicate function. The callback is invoked with the arguments __data_parallel_replicate__(ctx, copy_id).

Examples

    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
    # sync_bn.__data_parallel_replicate__ will be invoked.
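The callback protocol can be illustrated without PyTorch. The sketch below is a hypothetical module (HypotheticalSyncLayer is not part of jactorch) that implements __data_parallel_replicate__(ctx, copy_id); the driver loop only imitates the wrapper by passing one shared context object (types.SimpleNamespace as a stand-in for CallbackContext) and calling the master copy first. It assumes copy_id 0 denotes the master copy, consistent with the description of the first copy as the master.

```python
from types import SimpleNamespace


class HypotheticalSyncLayer:
    """Illustrative module that opts into the replication callback.

    Only the callback name and its (ctx, copy_id) arguments come from the
    documentation; the attributes set below are invented for this sketch.
    """

    def __data_parallel_replicate__(self, ctx, copy_id):
        # Assumption: copy_id 0 is the master copy, the rest are slave copies.
        self.is_master = copy_id == 0
        if self.is_master:
            # The master is called first, so it can publish state on ctx.
            ctx.master = self
        # Every copy (master included) can now reach the master via ctx.
        self.master = ctx.master


# Simulate what the wrapper does for two devices: one shared ctx, master first.
ctx = SimpleNamespace()
copies = [HypotheticalSyncLayer() for _ in range(2)]
for copy_id, layer in enumerate(copies):
    layer.__data_parallel_replicate__(ctx, copy_id)

print(copies[1].master is copies[0])  # True
```

Because the master runs first, slave copies can rely on whatever the master stored on the shared context.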

__init__(module, device_ids=None, output_device=None, dim=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Return type: None

forward(*inputs, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling forward() directly, since the former takes care of running the registered hooks while the latter silently ignores them.

Return type: Any

gather(outputs, output_device)

Return type: Any

parallel_apply(replicas, inputs, kwargs)

Return type: List[Any]

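replicate(module, device_ids)

Calls the original DataParallel replicate function and then invokes the replication callbacks on the resulting copies.
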
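The subclass pattern (override replicate, let the parent build the copies, then run the callbacks) can be sketched in plain Python. StandInDataParallel and CallbackDataParallelSketch are hypothetical names; the stand-in models only replicate() and uses deep copies instead of placing modules on devices, and the sketch uses a single shared context because its toy module has no sub-modules.

```python
import copy


class StandInDataParallel:
    """Hypothetical stand-in for torch.nn.DataParallel; only replicate() is modeled."""

    def replicate(self, module, device_ids):
        # The real method places one copy on each device; deep copies suffice here.
        return [copy.deepcopy(module) for _ in device_ids]


class CallbackDataParallelSketch(StandInDataParallel):
    """Override replicate(): build copies first, then run the callbacks."""

    def replicate(self, module, device_ids):
        replicas = super().replicate(module, device_ids)
        ctx = {}  # one shared context (this sketch has no sub-modules)
        for copy_id, replica in enumerate(replicas):  # master (copy 0) first
            callback = getattr(replica, "__data_parallel_replicate__", None)
            if callback is not None:
                callback(ctx, copy_id)
        return replicas


class Layer:
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id


replicas = CallbackDataParallelSketch().replicate(Layer(), device_ids=[0, 1])
print([r.copy_id for r in replicas])  # [0, 1]
```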
scatter(inputs, kwargs, device_ids)

Return type: Any

Functions

exec_data_parallel_replication_callback(modules)

Execute a replication callback, __data_parallel_replicate__, on each module created by the original replication.

The callback is invoked with the arguments __data_parallel_replicate__(ctx, copy_id).

Because all replicas are isomorphic, each sub-module is assigned a context that is shared among the copies of that sub-module on different devices. Through this context, the copies can share information.

The callback on the master copy (the first copy) is guaranteed to be called before the callback of any slave copy.
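The dispatch order described for exec_data_parallel_replication_callback can be modeled in plain Python. This is a minimal sketch, not jactorch's code: it assumes each replica has already been flattened into a list of sub-modules in a consistent order (which isomorphic replicas allow), and it uses types.SimpleNamespace as a stand-in for CallbackContext. Recorder and Plain are hypothetical test modules.

```python
from types import SimpleNamespace


def exec_replication_callbacks_sketch(copies):
    """Sketch of the dispatch: one context per sub-module position, master first.

    Assumption: each element of ``copies`` is one replica, given as a list of
    sub-modules in the same order for every replica.
    """
    nr_submodules = len(copies[0])
    # One context per sub-module position, shared by all copies of it.
    contexts = [SimpleNamespace() for _ in range(nr_submodules)]
    # Copies are visited in order, so copy 0 (the master) is handled first.
    for copy_id, submodules in enumerate(copies):
        for position, submodule in enumerate(submodules):
            callback = getattr(submodule, "__data_parallel_replicate__", None)
            if callback is not None:
                callback(contexts[position], copy_id)


calls = []


class Recorder:
    def __init__(self, name):
        self.name = name

    def __data_parallel_replicate__(self, ctx, copy_id):
        self.ctx = ctx
        ctx.seen = getattr(ctx, "seen", []) + [copy_id]
        calls.append((self.name, copy_id))


class Plain:
    pass  # no callback defined, so it is skipped


copies = [[Recorder("bn"), Plain()] for _ in range(3)]
exec_replication_callbacks_sketch(copies)
print(calls)  # [('bn', 0), ('bn', 1), ('bn', 2)]
```

All three Recorder copies receive the same context object, and the master's call is recorded first.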

patch_replication_callback(data_parallel)

Monkey-patch an existing DataParallel object to add the replication callback. This is useful when you have a customized DataParallel implementation.

Examples

    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
    > patch_replication_callback(sync_bn)
    # this is equivalent to
    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
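Monkey-patching can be sketched by wrapping one instance's bound replicate method rather than subclassing. The sketch assumes a hypothetical CustomDataParallel stand-in (not torch) and a hypothetical patch_replication_callback_sketch; it is not the library's implementation, only an illustration of patching a single object so that callbacks run after the original replication.

```python
import copy
import functools


class CustomDataParallel:
    """Hypothetical user-defined DataParallel variant (a stand-in, not torch)."""

    def replicate(self, module, device_ids):
        return [copy.deepcopy(module) for _ in device_ids]


def patch_replication_callback_sketch(data_parallel):
    """Wrap this instance's replicate() so callbacks run after replication."""
    old_replicate = data_parallel.replicate  # bound method of this instance

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        replicas = old_replicate(module, device_ids)
        ctx = {}  # shared context; copies are visited master-first
        for copy_id, replica in enumerate(replicas):
            callback = getattr(replica, "__data_parallel_replicate__", None)
            if callback is not None:
                callback(ctx, copy_id)
        return replicas

    # Assigning on the instance shadows the class method for this object only.
    data_parallel.replicate = new_replicate


class Layer:
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id


patched, untouched = CustomDataParallel(), CustomDataParallel()
patch_replication_callback_sketch(patched)
replicas = patched.replicate(Layer(), [0, 1])
print([r.copy_id for r in replicas])  # [0, 1]
```

Only the patched object gains the callback behavior; other instances of the same class keep the original replicate.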