Source code for jactorch.nn.gumbel_softmax.functional

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : functional.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 04/01/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn.functional as F
from jacinle.utils.enum import JacEnum

from jactorch.functional.indexing import set_index_one_hot_, one_hot_nd
from jactorch.functional.masking import masked_softmax

__all__ = ['greedy_softmax', 'gumbel_softmax', 'SoftmaxImplmentation', 'general_softmax']


def _sample_gumbel(shape, eps=1e-10, out=None):
    """
    Sample from Gumbel(0, 1)
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
    return - torch.log(eps - torch.log(U + eps))


def _gumbel_softmax_sample(logits, dim=-1, tau=1, eps=1e-10, mask=None):
    """
    Draw a sample from the Gumbel-Softmax distribution
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)
    """
    gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.new())
    y = logits + gumbel_noise
    return masked_softmax(y / tau, mask, dim=dim)

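# Why this works (the Gumbel-max trick): argmax_i(logits_i + g_i), with
# g_i ~ Gumbel(0, 1) i.i.d., is an exact sample from Categorical(softmax(logits));
# softmax((logits + g) / tau) is its continuous relaxation, approaching one-hot
# as tau -> 0. A minimal sanity-check sketch (illustrative only, not part of the module):
#
#   logits = torch.tensor([1.0, 2.0, 3.0])
#   g = _sample_gumbel((100000, 3))
#   freq = torch.bincount((logits + g).argmax(dim=-1), minlength=3) / 100000.0
#   # freq should be close to F.softmax(logits, dim=-1)
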

def greedy_softmax(logits, dim=-1, mask=None):
    # TODO(Jiayuan Mao @ 07/29): add support for dim != -1.
    assert dim == -1, 'Greedy softmax supports only dim=-1.'
    if mask is not None:
        probs = masked_softmax(logits, mask=mask, dim=dim)
    else:
        probs = logits  # we only need the argmax, so the unnormalized logits suffice
    one_hot = one_hot_nd(probs.max(dim)[1], logits.size(dim))
    return one_hot
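
# Illustrative usage (a sketch, assuming logits of shape [batch, n_class]):
#
#   logits = torch.randn(4, 10)
#   y = greedy_softmax(logits)   # one-hot along the last dimension
#   assert (y.sum(dim=-1) == 1).all()
#   assert y.argmax(dim=-1).equal(logits.argmax(dim=-1))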


def gumbel_softmax(logits, dim=-1, tau=1, hard=False, mask=None, eps=1e-10):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs.
        dim: the dimension along which the softmax is performed.
        tau: positive scalar temperature.
        hard: if True, take argmax, but differentiate w.r.t. the soft sample y.
        mask: optional mask of valid entries, forwarded to the masked softmax.
        eps: small constant for numerical stability.

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, the returned sample will be one-hot; otherwise it will be a
        probability distribution that sums to 1 across classes.

    Based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license).
    """
    y_soft = _gumbel_softmax_sample(logits, dim=dim, tau=tau, eps=eps, mask=mask)
    if hard:
        with torch.no_grad():
            _, k = y_soft.max(dim=dim)
            # this bit is based on
            # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
            y_hard = torch.zeros_like(logits)
            y_hard.requires_grad = False
            set_index_one_hot_(y_hard, dim, k, 1.0)
        # this cool bit of code achieves two things:
        # - makes the output value exactly one-hot (since we add and then
        #   subtract the y_soft value)
        # - makes the gradient equal to the y_soft gradient (since we strip
        #   all other gradients)
        y = (y_hard - y_soft).detach() + y_soft
    else:
        y = y_soft
    return y
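
# Illustrative usage (a sketch; the downstream loss is hypothetical):
#
#   logits = torch.randn(4, 10, requires_grad=True)
#   y = gumbel_softmax(logits, tau=0.5, hard=True)   # one-hot in the forward pass
#   loss = (y * torch.randn(10)).sum()               # any downstream loss
#   loss.backward()                                  # gradients flow through y_soft
#   assert logits.grad is not None
#   assert torch.allclose(y.detach().sum(dim=-1), torch.ones(4))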


class SoftmaxImplmentation(JacEnum):
    STANDARD = 'standard'
    GUMBEL = 'gumbel'
    GUMBEL_HARD = 'gumbel_hard'


def general_softmax(logits, dim=-1, tau=1, impl='standard', mask=None, training=False):
    impl = SoftmaxImplmentation.from_string(impl)
    if impl is SoftmaxImplmentation.STANDARD:
        return masked_softmax(logits / tau, mask=mask, dim=dim)
    elif impl in (SoftmaxImplmentation.GUMBEL, SoftmaxImplmentation.GUMBEL_HARD):
        if not training:
            # At test time, take the deterministic argmax instead of sampling;
            # dividing by tau does not change the argmax, so it is skipped.
            return greedy_softmax(logits, dim=dim, mask=mask)
        if impl is SoftmaxImplmentation.GUMBEL:
            return gumbel_softmax(logits, dim=dim, tau=tau, hard=False, mask=mask)
        else:
            return gumbel_softmax(logits, dim=dim, tau=tau, hard=True, mask=mask)
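
# Illustrative usage (a sketch; `logits` is assumed to be a [batch, n_class] tensor):
#
#   # Training: draw hard one-hot samples with straight-through gradients.
#   y_train = general_softmax(logits, tau=0.5, impl='gumbel_hard', training=True)
#   # Evaluation: deterministic greedy one-hot (no sampling; tau is ignored).
#   y_eval = general_softmax(logits, impl='gumbel_hard', training=False)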