Source code for jactorch.utils.meta

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : meta.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import functools
import numpy as np

import torch

from jacinle.utils.deprecated import deprecated
from jacinle.utils.meta import stmap

SKIP_TYPES = (str, bytes)

__all__ = ['mark_volatile', 'as_tensor', 'as_variable', 'as_numpy', 'as_float', 'as_cuda', 'as_cpu', 'as_detached']


def _mark_volatile(o):
    """Set the legacy ``volatile`` flag on a single tensor or Variable.

    Raw tensors are first wrapped in ``torch.autograd.Variable``; the
    ``volatile`` attribute is then set to ``True`` on anything that is a
    ``Variable``. Every other object is returned unchanged.
    """
    from torch.autograd import Variable

    wrapped = Variable(o) if torch.is_tensor(o) else o
    if isinstance(wrapped, Variable):
        wrapped.volatile = True
    return wrapped


@deprecated
def mark_volatile(obj):
    """DEPRECATED(Jiayuan Mao): mark_volatile has been deprecated and will be removed by 10/23/2018; please use torch.no_grad instead."""
    # NOTE(review): the docstring above is kept verbatim on purpose; the
    # `deprecated` decorator may surface it at runtime -- confirm before editing.
    marked = stmap(_mark_volatile, obj)
    return marked
def _as_tensor(o):
    """Convert one leaf value to a tensor.

    Values whose type is listed in ``SKIP_TYPES``, Variables, and tensors are
    returned as-is. Anything else is converted with ``np.array`` and then
    handed to ``torch.from_numpy``.
    """
    from torch.autograd import Variable

    passthrough = (
        isinstance(o, SKIP_TYPES)
        or isinstance(o, Variable)
        or torch.is_tensor(o)
    )
    if passthrough:
        return o

    array = np.array(o)
    return torch.from_numpy(array)
def as_tensor(obj):
    """Convert the elements of a Python data structure to tensors.

    The per-element conversion is done by ``_as_tensor`` and applied through
    ``stmap``. Supported element types: tensor, variable (PyTorch 0.x),
    numpy array, and Python scalars.
    """
    converted = stmap(_as_tensor, obj)
    return converted
def _as_variable(o):
    """Convert one leaf value to a ``torch.autograd.Variable``.

    Values whose type is listed in ``SKIP_TYPES`` and existing Variables are
    returned as-is. Tensors are wrapped directly; anything else goes through
    ``np.array`` and ``torch.from_numpy`` before being wrapped.
    """
    from torch.autograd import Variable

    if isinstance(o, SKIP_TYPES) or isinstance(o, Variable):
        return o

    tensor = o if torch.is_tensor(o) else torch.from_numpy(np.array(o))
    return Variable(tensor)
@deprecated
def as_variable(obj):
    """DEPRECATED(Jiayuan Mao): as_variable has been deprecated and will be removed by 10/23/2018; please use as_tensor instead."""
    # NOTE(review): the docstring above is kept verbatim on purpose; the
    # `deprecated` decorator may surface it at runtime -- confirm before editing.
    wrapped = stmap(_as_variable, obj)
    return wrapped
def _as_numpy(o):
    """Convert one leaf value to a numpy array.

    Values whose type is listed in ``SKIP_TYPES`` are returned as-is. Tensors
    are detached, moved to CPU, and converted with ``.numpy()``. Anything
    else is passed to ``np.array``.
    """
    if isinstance(o, SKIP_TYPES):
        return o
    # The former `isinstance(o, Variable): o = o` branch was a no-op
    # self-assignment and has been dropped; `torch.is_tensor` is the only
    # check that selects the tensor path.
    if torch.is_tensor(o):
        return o.detach().cpu().numpy()
    return np.array(o)
def as_numpy(obj):
    """Convert the elements of a Python data structure to numpy arrays.

    The per-element conversion is done by ``_as_numpy`` and applied through
    ``stmap``. Supported element types: tensor, variable (PyTorch 0.x),
    numpy array, and Python scalars.
    """
    converted = stmap(_as_numpy, obj)
    return converted
def _as_float(o):
    """Convert one leaf value to a Python ``float``.

    Values whose type is listed in ``SKIP_TYPES`` are returned as-is. Tensors
    are read with ``.item()``; everything else is first converted with
    ``as_numpy`` and must hold exactly one element.

    Raises:
        AssertionError: if the converted numpy array does not contain exactly
            one element.
    """
    if isinstance(o, SKIP_TYPES):
        return o
    if torch.is_tensor(o):
        # `.item()` yields an int/bool for non-floating tensors; coerce so the
        # tensor path returns a float just like the numpy path below.
        return float(o.item())

    arr = as_numpy(o)
    assert arr.size == 1, 'as_float expects exactly one element, got {}.'.format(arr.size)
    # `float(arr)` on an array with ndim > 0 is deprecated in recent NumPy;
    # extract the single element explicitly instead.
    return float(arr.item())
def as_float(obj):
    """Convert the elements of a Python data structure to Python floats.

    The per-element conversion is done by ``_as_float`` and applied through
    ``stmap``. Supported element types: tensor, variable (PyTorch 0.x),
    numpy array, and Python scalars.
    """
    converted = stmap(_as_float, obj)
    return converted
def _as_cpu(o):
    """Return ``o.cpu()`` for tensors and Variables; return anything else unchanged."""
    from torch.autograd import Variable

    if not (isinstance(o, Variable) or torch.is_tensor(o)):
        return o
    return o.cpu()
def as_cpu(obj):
    """Move the elements of a Python data structure to CPU.

    ``_as_cpu`` is applied through ``stmap``; only tensors and variables
    (PyTorch 0.x) are changed, every other element is left untouched.
    """
    moved = stmap(_as_cpu, obj)
    return moved
def _as_cuda(o):
    """Return ``o.cuda()`` for tensors and Variables; return anything else unchanged."""
    from torch.autograd import Variable

    is_tensor_like = isinstance(o, Variable) or torch.is_tensor(o)
    return o.cuda() if is_tensor_like else o
def as_cuda(obj):
    """Move the elements of a Python data structure to a CUDA device.

    ``_as_cuda`` is applied through ``stmap``; it calls ``.cuda()`` on
    tensors and variables (PyTorch 0.x) and leaves every other element
    untouched.
    """
    return stmap(_as_cuda, obj)
def _as_detached(o, clone=False):
    """Detach a single tensor or Variable; return anything else unchanged.

    Args:
        o: the value to process.
        clone: if True, the tensor is cloned first and the clone is detached.
    """
    from torch.autograd import Variable

    if not (isinstance(o, Variable) or torch.is_tensor(o)):
        return o

    source = o.clone() if clone else o
    return source.detach()
def as_detached(obj, clone: bool = False):
    """Detach the elements of a Python data structure.

    ``_as_detached`` is applied through ``stmap``; only tensors and variables
    (PyTorch 0.x) are changed.

    Args:
        clone: if True, clone the tensor before detaching.
    """
    def detach_leaf(leaf):
        return _as_detached(leaf, clone=clone)

    return stmap(detach_leaf, obj)