# Source code for jactorch.nn.gumbel_softmax.softmax
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : softmax.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/01/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch.nn as nn
from .functional import gumbel_softmax, greedy_softmax, general_softmax, SoftmaxImplmentation
# Public API of this module. ``SoftmaxImplmentation`` (sic — the misspelling is
# part of the established API of ``.functional``) is re-exported for convenience.
__all__ = ['SoftmaxImplmentation', 'GumbelSoftmax', 'GeneralSoftmax']
class GumbelSoftmax(nn.Module):
    """Gumbel-Softmax module.

    In training mode, the forward pass draws a Gumbel-Softmax sample from the
    input logits (optionally discretized via the straight-through estimator when
    ``hard=True``). In evaluation mode it falls back to a deterministic greedy
    (argmax-style) softmax.
    """

    def __init__(self, dim=-1, tau=1.0, hard=False, eps=1e-10):
        """Initialize the module.

        Args:
            dim (int): the dimension along which the softmax is computed.
            tau (float): the temperature of the Gumbel-Softmax distribution.
            hard (bool): if True, the returned samples are discretized as
                one-hot vectors (straight-through gradient estimator).
            eps (float): a small epsilon for numerical stability of the
                Gumbel noise.
        """
        super().__init__()
        self.dim = dim
        self.tau = tau
        self.hard = hard
        self.eps = eps

    def forward(self, logits, mask=None):
        """Apply the (Gumbel-)softmax to ``logits``.

        Args:
            logits: the input logits tensor.
            mask: an optional mask passed through to the underlying
                softmax implementation.

        Returns:
            A stochastic Gumbel-Softmax sample in training mode, or a
            deterministic greedy softmax in evaluation mode.
        """
        # Stochastic sampling only makes sense during training; at eval
        # time we want a deterministic output.
        if self.training:
            return gumbel_softmax(logits, dim=self.dim, tau=self.tau, hard=self.hard, mask=mask, eps=self.eps)
        else:
            return greedy_softmax(logits, dim=self.dim, mask=mask)
class GeneralSoftmax(nn.Module):
    """A general softmax module that dispatches to a configurable implementation.

    The implementation is selected by name (e.g. ``'standard'``) and resolved to
    a :class:`SoftmaxImplmentation` enum value at construction time; the actual
    computation is delegated to :func:`general_softmax`.
    """

    def __init__(self, dim=-1, tau=1.0, impl='standard'):
        """Initialize the module.

        Args:
            dim (int): the dimension along which the softmax is computed.
            tau (float): the softmax temperature.
            impl (str): the name of the softmax implementation; resolved via
                ``SoftmaxImplmentation.from_string``.
        """
        super().__init__()
        self.dim = dim
        self.tau = tau
        # Resolve the string once here so invalid names fail fast at
        # construction time rather than on the first forward pass.
        self.impl = SoftmaxImplmentation.from_string(impl)

    def forward(self, logits, mask=None):
        """Apply the configured softmax to ``logits``.

        Args:
            logits: the input logits tensor.
            mask: an optional mask passed through to ``general_softmax``.

        Returns:
            The softmax output; behavior may differ between training and
            evaluation mode depending on the chosen implementation, hence
            ``training=self.training`` is forwarded.
        """
        return general_softmax(logits, dim=self.dim, tau=self.tau, impl=self.impl, mask=mask, training=self.training)