Source code for jactorch.functional.grad

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : grad.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

"""Gradient modification functions."""

from torch.autograd import Function

# Public names exported by ``from <module> import *``; ``zero_grad`` is the
# alias bound at the bottom of this module.
__all__ = ['grad_multi', 'zero_grad']


class GradMulti(Function):
    """Autograd function: identity (cloned) forward, scaled gradient backward."""

    @staticmethod
    def forward(ctx, input, grad_multi):
        """Remember the scaling factor and return a copy of ``input``.

        Args:
            ctx: autograd context; the factor is stashed on it for backward.
            input (Tensor): the input tensor.
            grad_multi (float): factor applied to the gradient in backward.

        Returns:
            Tensor: a clone of ``input`` (same values, distinct tensor).
        """
        # The factor is stored as a plain attribute rather than through
        # save_for_backward, which is intended for tensors.
        ctx.grad_multi = grad_multi
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Multiply the incoming gradient by the stored factor.

        Returns:
            tuple: ``(grad_output * factor, None)``. The second entry is the
            gradient for ``grad_multi``, which is treated as a constant.
        """
        factor = ctx.grad_multi
        grad_input = grad_output * factor
        return grad_input, None
def grad_multi(input, grad_multi):
    """Scale the gradient with respect to the input.

    The forward pass is an identity on values; only back-propagation is
    affected.

    Args:
        input (Tensor): the input tensor.
        grad_multi (float): the factor by which the gradient flowing back
            into ``input`` is multiplied (values below 1 shrink it, above 1
            enlarge it).

    Returns:
        Tensor: a tensor with the same value as ``input``. During
        back-propagation, the gradient passing through it is multiplied by
        ``grad_multi`` before reaching ``input``.
    """
    return GradMulti.apply(input, grad_multi)
class ZeroGradV1(Function):
    """Autograd function: identity (cloned) forward, no gradient backward."""

    @staticmethod
    def forward(ctx, input):
        """Return a copy of ``input``.

        Args:
            ctx: autograd context (unused).
            input (Tensor): the input tensor.

        Returns:
            Tensor: a clone of ``input``.
        """
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Discard ``grad_output``.

        Returns ``None`` as the gradient for ``input``, so nothing is
        propagated back to it through this node.
        """
        return None
def zero_grad_v1(input):
    """Block gradient flow to ``input`` using :class:`ZeroGradV1`.

    Args:
        input (Tensor): the input tensor.

    Returns:
        Tensor: a clone of ``input``. Its backward returns ``None`` for
        ``input``, so no gradient reaches ``input`` through it.
    """
    blocked = ZeroGradV1.apply(input)
    return blocked
def zero_grad_v2(v):
    """Cut ``v`` out of the autograd graph.

    Args:
        v (Tensor): tensor to stop gradients for (assumed to be a
            ``torch.Tensor`` — TODO confirm against callers).

    Returns:
        Tensor: ``v.detach()``, which has the same value as ``v`` but does
        not require grad, so no gradient flows back to ``v`` through it.
    """
    detached = v.detach()
    return detached
# Exported alias (listed in ``__all__``): the detach()-based implementation is
# the one used by ``zero_grad``; ``zero_grad_v1`` remains available explicitly.
zero_grad = zero_grad_v2