# Source code for jactorch.functional.quantization
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : quantization.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 06/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""Basic quantization functions with a straight-through gradient estimator."""
__all__ = ['quantize', 'randomized_quantize']
import torch
import torch.autograd as ag
class _Quantize(ag.Function):
    """Binarize a tensor with a straight-through gradient estimator.

    Forward thresholds the input at 0.5, producing a 0/1 float tensor;
    backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        # Hard threshold at 0.5; torch.gt yields a bool mask, cast to float.
        return torch.gt(x, 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient.
        return grad_output
class _RandomizedQuantize(ag.Function):
    """Stochastically binarize a tensor with a straight-through estimator.

    Forward draws uniform noise and returns ``(rand > x).float()`` (so an
    entry is 1 with probability ``1 - x`` for ``x`` in [0, 1]); backward
    passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        # Bug fix: torch.rand(x.size()) always allocates on CPU with the
        # default dtype, which breaks for CUDA (or non-default-dtype) inputs.
        # rand_like matches the input's device and dtype while keeping the
        # original CPU/float32 behavior identical.
        rand = torch.rand_like(x)
        return (rand > x).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient.
        return grad_output
# [docs]
def quantize(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to binary values: ``(x > 0.5).float()``.

    Gradients flow through unchanged (straight-through estimator), so this
    is safe to use inside a differentiable computation graph.

    Args:
        x: the input tensor.

    Returns:
        the quantized tensor.
    """
    return _Quantize.apply(x)
# [docs]
def randomized_quantize(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to binary values: ``(rand() > x).float()``.

    Each entry is binarized stochastically against uniform noise; gradients
    flow through unchanged (straight-through estimator).

    Args:
        x: the input tensor.

    Returns:
        the quantized tensor.
    """
    return _RandomizedQuantize.apply(x)