jactorch.nn.gumbel_softmax.functional#

Classes

SoftmaxImplmentation

Functions

general_softmax(logits[, dim, tau, impl, ...])

greedy_softmax(logits[, dim, mask])

gumbel_softmax(logits[, dim, tau, hard, ...])

Sample from the Gumbel-Softmax distribution and optionally discretize.

Class SoftmaxImplmentation

class SoftmaxImplmentation[source]#

Bases: JacEnum

__new__(value)#
classmethod assert_valid(value)#

Assert that the value is a valid choice.

classmethod choice_names()#

Return the list of names of all possible choices.

classmethod choice_objs()#

Return the list of objects of all possible choices.

classmethod choice_values()#

Return the list of values of all possible choices.

classmethod from_string(value)#
Parameters:

value (str | JacEnum)

Return type:

JacEnum

classmethod is_valid(value)#

Check if the value is a valid choice.

classmethod type_name()#

Return the type name of the enum.

GUMBEL = 'gumbel'#
GUMBEL_HARD = 'gumbel_hard'#
STANDARD = 'standard'#
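The `JacEnum` interface documented above can be illustrated with a plain-Python stand-in (a sketch only: `JacEnum` itself lives in the Jacinle library, and this stand-in merely mirrors the documented classmethods, not their actual implementation):

```python
from enum import Enum


class SoftmaxImplmentation(Enum):
    """Stand-in mirroring the documented JacEnum interface (sketch, not jactorch code)."""
    STANDARD = 'standard'
    GUMBEL = 'gumbel'
    GUMBEL_HARD = 'gumbel_hard'

    @classmethod
    def choice_names(cls):
        # Names of all possible choices.
        return [member.name for member in cls]

    @classmethod
    def choice_values(cls):
        # Values of all possible choices.
        return [member.value for member in cls]

    @classmethod
    def is_valid(cls, value):
        # Check if the value is a valid choice.
        return value in cls.choice_values()

    @classmethod
    def from_string(cls, value):
        # Accept either an enum member or its string value.
        if isinstance(value, cls):
            return value
        return cls(value.lower())


impl = SoftmaxImplmentation.from_string('gumbel_hard')
print(impl.value)                            # gumbel_hard
print(SoftmaxImplmentation.choice_values())  # ['standard', 'gumbel', 'gumbel_hard']
```

Accepting `str | JacEnum` in `from_string`, as the signature above documents, lets callers pass either the enum member or its string value (e.g. `impl='standard'` in `general_softmax`).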

Functions

general_softmax(logits, dim=-1, tau=1, impl='standard', mask=None, training=False)[source]#
greedy_softmax(logits, dim=-1, mask=None)[source]#
gumbel_softmax(logits, dim=-1, tau=1, hard=False, mask=None, eps=1e-10)[source]#

Sample from the Gumbel-Softmax distribution and optionally discretize.

Parameters:
  • logits – unnormalized log-probabilities of shape [batch_size, n_class]

  • dim – the dimension along which the softmax is performed

  • tau – non-negative scalar temperature

  • hard – if True, take the argmax in the forward pass, but differentiate w.r.t. the soft sample y

  • eps – a small constant added for numerical stability

Returns:

A [batch_size, n_class] sample from the Gumbel-Softmax distribution. If hard=True, the returned sample is one-hot; otherwise it is a probability distribution that sums to 1 across classes.

Based on ericjang/gumbel-softmax (MIT license).
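The sampling procedure described above can be sketched in plain PyTorch. This is an illustration of the standard Gumbel-Softmax technique (Jang et al., 2017), not the jactorch implementation; the `hard=True` branch uses the usual straight-through trick (one-hot in the forward pass, soft gradients in the backward pass):

```python
import torch
import torch.nn.functional as F


def gumbel_softmax_sketch(logits, dim=-1, tau=1.0, hard=False, eps=1e-10):
    """Sample from the Gumbel-Softmax distribution; optionally discretize.

    A sketch of the standard technique, not the jactorch implementation.
    """
    # Sample Gumbel(0, 1) noise: g = -log(-log(U)), with U ~ Uniform(0, 1).
    # eps guards the logs against taking log(0).
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + eps) + eps)

    # Softmax of the perturbed logits at temperature tau.
    y_soft = F.softmax((logits + gumbel) / tau, dim=dim)

    if not hard:
        return y_soft

    # Straight-through estimator: the forward value is one-hot,
    # but gradients flow through the soft sample y_soft.
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


logits = torch.randn(4, 10)

y = gumbel_softmax_sketch(logits, tau=0.5)
print(y.sum(dim=-1))  # each row sums to 1 (up to float rounding)

y_hard = gumbel_softmax_sketch(logits, tau=0.5, hard=True)
print(y_hard)  # (approximately) one-hot rows, still differentiable
```

Lowering `tau` makes the soft samples approach one-hot vectors; raising it makes them approach the uniform distribution. Note that the jactorch signature also accepts a `mask` argument not shown in this sketch.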