# Source code for jactorch.parallel.dict_gather

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dict_gather.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/04/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import collections.abc
import functools

import torch
from torch.autograd import Variable
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel._functions import Gather

from jactorch.data.collate.collate_v3 import VarLengthCollateV3


# Names exported by ``from jactorch.parallel.dict_gather import *``.
__all__ = [
    'data_parallel_dict_gather',
    'DictGatherDataParallel',
    'patch_dict_gathering',
    'dict_gather_v1',
    'dict_gather_v2',
]


def data_parallel_dict_gather(data_parallel, outputs, output_device, layout=None):
    """Gather per-replica ``outputs`` onto ``output_device`` for a DataParallel module.

    Args:
        data_parallel: the DataParallel wrapper; only its ``dim`` attribute is read,
            and it is used as the gather dimension.
        outputs: the list of per-replica outputs to merge.
        output_device: the device to gather onto.
        layout: optional layout forwarded to :func:`dict_gather_v2`; ``None``
            selects the plain recursive gather.

    Returns:
        Whatever :func:`dict_gather_v2` produces for these arguments.
    """
    gather_dim = data_parallel.dim
    return dict_gather_v2(outputs, output_device, layout=layout, dim=gather_dim)
class DictGatherDataParallel(DataParallel):
    """DataParallel subclass that can gather modules whose forward returns dicts."""

    def gather(self, outputs, output_device):
        """Merge the replicas' ``outputs`` onto ``output_device``.

        Delegates to :func:`data_parallel_dict_gather`, passing ``self`` so that
        this module's ``dim`` is used as the gather dimension.
        """
        merged = data_parallel_dict_gather(self, outputs, output_device)
        return merged
def patch_dict_gathering(data_parallel):
    """Patch a DataParallel instance in place so its ``gather`` supports dict outputs.

    After patching, ``data_parallel.gather(outputs, output_device)`` is routed to
    :func:`data_parallel_dict_gather` with ``data_parallel`` as its first argument.

    Args:
        data_parallel: the :class:`DataParallel` instance to patch.

    Returns:
        The same ``data_parallel`` instance, for call chaining. Existing callers
        that ignore the return value are unaffected.

    Raises:
        AssertionError: if ``data_parallel`` is not a :class:`DataParallel`.
    """
    assert isinstance(data_parallel, DataParallel)
    # Bind the instance positionally. DataParallel calls
    # ``self.gather(outputs, output_device)`` with positional arguments. With a
    # keyword binding (``data_parallel=...``), ``outputs`` would fill the
    # ``data_parallel`` slot, and the call would fail with
    # "got multiple values for argument 'data_parallel'".
    data_parallel.gather = functools.partial(data_parallel_dict_gather, data_parallel)
    return data_parallel
def dict_gather_v1(outputs, target_device, dim=0):
    """Gathers variables from different GPUs on a specified device (-1 means the CPU), with dictionary support.

    The first element of ``outputs`` determines how the elements are combined:

    * tensors are concatenated with ``Gather`` along ``dim``; 0-dim tensors are
      first unsqueezed to shape ``(1,)``;
    * ``None`` stays ``None``;
    * mappings are gathered key by key into a plain ``dict``;
    * ``str`` / ``bytes`` values are returned as the collection of per-replica
      values;
    * other sequences (including namedtuples) are gathered element-wise and
      rebuilt with the same type;
    * anything else is returned as the collection of per-replica values.

    Args:
        outputs: a list holding one output per replica.
        target_device: the device to gather tensors onto (-1 means the CPU).
        dim: the dimension along which tensors are concatenated.

    Returns:
        The gathered structure, mirroring the structure of ``outputs[0]``.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, Variable) or torch.is_tensor(out):
            if out.dim() == 0:
                # Scalars have no axis to concatenate along; lift them to (1,).
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        # ``collections.abc`` is required: the ``collections.Mapping`` and
        # ``collections.Sequence`` aliases were removed in Python 3.10.
        if isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        if isinstance(out, (str, bytes)):
            # Strings are sequences too; keep them whole instead of zipping chars.
            return outputs
        if isinstance(out, collections.abc.Sequence):
            gathered = map(gather_map, zip(*outputs))
            if isinstance(out, tuple) and hasattr(out, '_fields'):
                # A namedtuple constructor takes fields positionally, not one iterable.
                return type(out)._make(gathered)
            return type(out)(gathered)
        return outputs

    return gather_map(outputs)
def dict_gather_v2(outputs, target_device, layout=None, dim=0):
    """Gather per-replica ``outputs`` onto ``target_device``, optionally driven by a layout.

    Args:
        outputs: the list of per-replica outputs.
        target_device: the device to gather onto.
        layout: when given, gathering is delegated to ``VarLengthCollateV3`` in
            ``'gather'`` mode with this layout. When ``None``, falls back to
            :func:`dict_gather_v1`.
        dim: the gather dimension passed to whichever implementation is used.

    Returns:
        The gathered result from the selected implementation.
    """
    if layout is not None:
        collate = VarLengthCollateV3(layout, mode='gather', gather_device=target_device, gather_dim=dim)
        return collate(outputs)
    return dict_gather_v1(outputs, target_device, dim=dim)