# Source code for jactorch.parallel.user_scattered
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : user_scattered.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/27/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch.cuda as cuda
from torch.nn.parallel.data_parallel import DataParallel
from jactorch.cuda.copy import async_copy_to
# Public API of this module.
__all__ = ['UserScatteredDataParallel', 'use_user_scattered']
class UserScatteredDataParallel(DataParallel):
    """A :class:`DataParallel` whose input is already split per device by the caller.

    Instead of chunking one batch, :meth:`scatter` expects exactly one positional argument: a
    tuple/list holding one input per entry of ``device_ids``. Each entry is copied to its own
    device by :func:`use_user_scattered`. The keyword arguments reaching :meth:`scatter` must be
    empty.
    """

    # When True, copies are issued on per-device background streams; otherwise without one.
    use_copy_stream = True

    def scatter(self, inputs, kwargs, device_ids):
        """Override :meth:`DataParallel.scatter` to use the caller-provided per-device split.

        Returns the ``(inputs, kwargs)`` pair produced by :func:`use_user_scattered`.
        """
        return use_user_scattered(inputs, kwargs, device_ids, use_stream=self.use_copy_stream)
def use_user_scattered(inputs, kwargs, device_ids, use_stream=True):
    """Scatter inputs that the caller has already split into one item per device.

    Used as the ``scatter`` step of :class:`UserScatteredDataParallel`. It returns per-device
    positional and keyword arguments in the same ``(inputs, kwargs)`` form.

    Args:
        inputs: positional arguments of the call. Must contain exactly one element: a tuple/list
            with one entry per device in ``device_ids``.
        kwargs: keyword arguments of the call. Must be empty (``None`` is accepted as empty).
        device_ids: target devices, one per entry of ``inputs[0]``.
        use_stream: if True, copy through per-device background streams
            (``_async_copy_stream``); otherwise copy without a dedicated stream (``_async_copy``).

    Returns:
        ``(inputs, kwargs)`` where ``inputs[k]`` is a one-element list holding the copy of
        ``inputs[0][k]`` for ``device_ids[k]``. ``kwargs`` is a list of empty dicts of the same
        length.

    Raises:
        AssertionError: if ``inputs`` does not hold exactly one element or ``kwargs`` is non-empty.
            The copy helpers additionally assert on the type and length of ``inputs[0]``.
    """
    assert len(inputs) == 1, \
        'Expected exactly one positional argument (the per-device inputs), got {}.'.format(len(inputs))
    # Check kwargs before any device copy is launched, so a bad call fails without side effects.
    assert not kwargs, 'Keyword arguments are not supported with user-scattered inputs.'

    per_device = inputs[0]
    if use_stream:
        per_device = _async_copy_stream(per_device, device_ids)
    else:
        per_device = _async_copy(per_device, device_ids)

    # One list of positional arguments per device, matching what scatter is expected to return.
    scattered_inputs = [[item] for item in per_device]
    scattered_kwargs = [{} for _ in scattered_inputs]
    return scattered_inputs, scattered_kwargs
def _async_copy(inputs, device_ids):
    """Copy ``inputs[k]`` to ``device_ids[k]`` for every k, without a dedicated copy stream.

    Args:
        inputs: a tuple or list, subclasses such as namedtuples included, with one entry per device.
        device_ids: target devices, in the same order as ``inputs``.

    Returns:
        A tuple of the copied entries, in ``device_ids`` order.

    Raises:
        AssertionError: if ``inputs`` is not a tuple/list or its length differs from ``device_ids``.
    """
    assert isinstance(inputs, (tuple, list)), \
        'Expected a tuple or list of per-device inputs, got {}.'.format(type(inputs))
    assert len(inputs) == len(device_ids), \
        'Got {} inputs for {} devices.'.format(len(inputs), len(device_ids))

    outputs = []
    for item, dev in zip(inputs, device_ids):
        # Make ``dev`` the current CUDA device while its copy is issued.
        with cuda.device(dev):
            outputs.append(async_copy_to(item, dev))
    return tuple(outputs)
def _async_copy_stream(inputs, device_ids):
    """Copy ``inputs[k]`` to ``device_ids[k]`` through a per-device background copy stream.

    For each device, the copy is issued inside ``cuda.stream(<background stream>)``. The device's
    current ("main") stream is then made to wait on that background stream, so work queued on the
    main stream afterwards is ordered after the copy. ``main_stream`` is also passed to
    ``async_copy_to``; presumably that lets the copied tensors be tied to it (TODO confirm
    against ``jactorch.cuda.copy``).

    Args:
        inputs: a tuple or list, subclasses such as namedtuples included, with one entry per device.
        device_ids: target devices, in the same order as ``inputs``.

    Returns:
        A list of the copied entries, in ``device_ids`` order.

    Raises:
        AssertionError: if ``inputs`` is not a tuple/list or its length differs from ``device_ids``.
    """
    assert isinstance(inputs, (tuple, list)), \
        'Expected a tuple or list of per-device inputs, got {}.'.format(type(inputs))
    assert len(inputs) == len(device_ids), \
        'Got {} inputs for {} devices.'.format(len(inputs), len(device_ids))

    streams = [_get_stream(d) for d in device_ids]
    outputs = []
    for item, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(stream):
                outputs.append(async_copy_to(item, dev, main_stream=main_stream))
            # _get_stream returns None for device -1; there is no background stream to wait on then.
            if stream is not None:
                main_stream.wait_stream(stream)
    return outputs
# Adapted from: torch/nn/parallel/_functions.py

# Per-device cache of background streams used for copying; allocated on first use.
_streams = None


def _get_stream(device):
    """Return the background stream used for copying between CPU and GPU on ``device``.

    Device ``-1`` yields ``None`` and leaves the cache untouched. For any other device index, a
    ``cuda.Stream`` is created on the first request and the same object is returned afterwards.
    """
    global _streams
    if device == -1:
        return None
    if _streams is None:
        _streams = [None] * cuda.device_count()

    cached = _streams[device]
    if cached is not None:
        return cached

    cached = cuda.Stream(device)
    _streams[device] = cached
    return cached