Source code for jactorch.parallel.replication_callback
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : replication_callback.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/27/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import functools
from torch.nn.parallel.data_parallel import DataParallel
# Names exported by `from ... import *`.
__all__ = [
    'CallbackContext',
    'ReplicationCallbackDataParallel',
    'exec_data_parallel_replication_callback',
    'patch_replication_callback',
]
[docs]
class CallbackContext:
    """Empty attribute container shared by all copies of one sub-module.

    `exec_data_parallel_replication_callback` creates one instance per sub-module
    position and passes that same instance to every copy at that position, so the
    copies can exchange information by setting attributes on it.
    """
[docs]
def exec_data_parallel_replication_callback(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`, where `copy_id`
    is the index of the copy within `modules`.

    Note that, as all copies are isomorphic, we assign each sub-module a context
    (shared among the copies of that sub-module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callbacks on the master copy (the first copy) will be called ahead of the callbacks
    of any slave copies.

    Args:
        modules: the module copies produced by replication, master copy first. Any iterable is
            accepted; it is materialized once. An empty input is a no-op.
    """
    modules = list(modules)
    if not modules:
        # Nothing was replicated, so there are no callbacks to run.
        return

    master_copy = modules[0]
    # One shared context per sub-module position. Position j in every copy is treated as the
    # "same" sub-module, which assumes all copies have identical structure (as replication yields).
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # Copies are visited in order, so every callback of copy 0 (the master) runs before any
    # callback of copy 1, 2, ...
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
[docs]
class ReplicationCallbackDataParallel(DataParallel):
    """
    `DataParallel` that runs a replication callback on every replica.

    Once the stock `replicate` has produced the per-device copies, every (sub-)module that
    defines `__data_parallel_replicate__` is called as `__data_parallel_replicate__(ctx, copy_id)`.

    Examples::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` onto `device_ids`, then fire the replication callbacks on the copies."""
        replicas = super().replicate(module, device_ids)
        exec_data_parallel_replication_callback(replicas)
        return replicas
[docs]
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Patching is idempotent: patching an already-patched object leaves it unchanged, so the
    callbacks never run twice per replication. Any extra positional/keyword arguments passed to
    `replicate` (as a customized implementation may do) are forwarded unchanged.

    Args:
        data_parallel: the `DataParallel` instance to patch in place.

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance (unless asserts are
            disabled with `-O`).

    Examples::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        patch_replication_callback(sync_bn)
        # this is equivalent to
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel), (
        'patch_replication_callback expects a DataParallel instance, got {}.'.format(
            type(data_parallel).__name__))

    old_replicate = data_parallel.replicate
    # Wrapping an already-patched `replicate` again would run every callback twice per replication.
    if getattr(old_replicate, '_replication_callback_patched', False):
        return

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids, *args, **kwargs):
        modules = old_replicate(module, device_ids, *args, **kwargs)
        exec_data_parallel_replication_callback(modules)
        return modules

    new_replicate._replication_callback_patched = True
    # Assigned on the instance, so it shadows the class-level `replicate` for this object only.
    data_parallel.replicate = new_replicate