# Source code for jactorch.cuda.copy

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : copy.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import collections.abc

import torch

# Names exported by ``from jactorch.cuda.copy import *``.
__all__ = [
    'async_copy_to',
]


def async_copy_to(obj, dev, main_stream=None):
    """
    Copy an object to a specific CUDA device asynchronously.

    Tensors are moved with ``Tensor.cuda(dev, non_blocking=True)``. The host-side copy only
    overlaps with other work when the source tensor lives in pinned memory. Containers are
    traversed recursively:

    * any ``collections.abc.Mapping`` is rebuilt as a plain ``dict``;
    * any ``list`` or ``tuple`` (including namedtuples) is rebuilt as a plain ``list``;
    * every other value (including strings and numbers) is returned unchanged.

    If `main_stream` is provided, each copied tensor is registered on it via
    ``Tensor.record_stream``. The CUDA caching allocator then will not reuse the tensor's
    memory while work queued on `main_stream` may still use it. This does NOT make
    `main_stream` wait for the copy to finish. The caller is still responsible for that
    synchronization (e.g. ``main_stream.wait_stream(copy_stream)``).

    Args:
        obj (Iterable[Tensor] or Tensor): a structure (e.g., a list or a dict) containing pytorch tensors.
        dev (int or torch.device): the target device, passed straight to ``Tensor.cuda``.
        main_stream (torch.cuda.Stream, optional): the stream that will consume the copied tensors.

    Returns:
        a copy of the container structure (see the conversion rules above), with each tensor
        copied to the device. Non-tensor leaves are shared with the input, not copied.
    """
    # Adapted from: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/_functions.py
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            # The new tensor is allocated on whichever stream is current during the copy.
            # Record that `main_stream` also uses it, so its memory is not recycled early.
            v.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, (tuple, list)):
        # NOTE: tuples (and namedtuples) come back as plain lists; callers may rely on this.
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj