Source code for jactorch.parallel.comm

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/27/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import queue
import collections
from typing import Any, Callable

from jacinle.concurrency.future import FutureResult

# Names exported by `from <this module> import *`.
__all__ = ['SlavePipe', 'SyncMaster']


# Registry entry the master keeps for each registered slave: `result` is the
# future-like object the master fills with the reply destined for that slave.
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
# Immutable record backing `SlavePipe`: the slave's identifier, the queue shared
# with the master, and the slot through which the master's reply arrives.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master-slave communication channel."""

    def run_slave(self, msg: Any):
        """Send `msg` to the master and wait for the master's reply.

        The message is pushed to the shared queue tagged with this pipe's
        identifier. After the reply has been read from `result`, a `True`
        token is pushed to the same queue to acknowledge receipt.

        Args:
            msg: the message to hand to the master.

        Returns:
            the value obtained from `result.get()`.
        """
        identifier, shared_queue, reply_slot = self

        shared_queue.put((identifier, msg))
        reply = reply_slot.get()
        # Acknowledge that the reply has been consumed.
        shared_queue.put(True)
        return reply
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices
      should call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
      to be passed back to each slave device.
    """

    def __init__(self, master_callback: Callable):
        """Initialize the master.

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
                It receives a list of `(identifier, message)` pairs (the master's pair first, with
                identifier 0) and must return an indexable sequence of `(identifier, reply)` pairs whose
                first entry belongs to the master.
        """
        self._master_callback = master_callback
        # Single queue shared by all slaves: carries both their messages and their acknowledgements.
        self._queue = queue.Queue()
        # identifier -> _MasterRegistry, in registration order.
        self._registry = collections.OrderedDict()
        # Set once `run_master` has been called; the next registration then starts a fresh round.
        self._activated = False

    def register_slave(self, identifier):
        """Register a slave device.

        If the master has already run since the previous round of registrations, the registry is
        cleared first so that a new set of slaves can be registered.

        Args:
            identifier: an identifier, usually the device id.

        Returns:
            a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()

        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg: Any):
        """Main entry for the master device in each forward pass.

        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the
                first message when calling `master_callback`. For detailed usage, see
                `_SynchronizedBatchNorm` for an example.

        Returns:
            the message to be sent back to the master device.
        """
        self._activated = True
        nr_slaves = self.nr_slaves

        # The master's own message always comes first, tagged with identifier 0.
        intermediates = [(0, master_msg)]
        intermediates.extend([self._queue.get() for _ in range(nr_slaves)])

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for identifier, reply in results:
            if identifier == 0:
                # The master's own reply is returned directly, not dispatched.
                continue
            self._registry[identifier].result.put(reply)

        for _ in range(nr_slaves):
            # Drain each acknowledgement OUTSIDE the assert: `python -O` strips assert statements,
            # and the queue must still be emptied or the next pass would read stale `True` tokens
            # as slave messages.
            ack = self._queue.get()
            assert ack is True, 'Expected an acknowledgement from a slave device.'

        return results[0][1]

    @property
    def nr_slaves(self) -> int:
        """The number of slave devices."""
        return len(self._registry)