jactorch.nn.gumbel_softmax.functional
Classes
- SoftmaxImplmentation
Functions
- gumbel_softmax(logits[, dim, tau, hard, mask, eps]): Sample from the Gumbel-Softmax distribution and optionally discretize.
Class SoftmaxImplmentation
- class SoftmaxImplmentation
Bases: JacEnum
- __new__(value)
- classmethod assert_valid(value)
Assert that the value is a valid choice.
- classmethod choice_names()
Return the list of names of all possible choices.
- classmethod choice_objs()
Return the list of objects of all possible choices.
- classmethod choice_values()
Return the list of values of all possible choices.
- classmethod is_valid(value)
Check whether the value is a valid choice.
- classmethod type_name()
Return the type name of the enum.
- GUMBEL = 'gumbel'
- GUMBEL_HARD = 'gumbel_hard'
- STANDARD = 'standard'
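JacEnum is Jacinle's own enum base class. To illustrate what the documented classmethods return, here is a minimal sketch that approximates them with the standard library `enum.Enum`. This is an assumption for illustration only, not the JacEnum implementation; the class name `SoftmaxImplSketch` is hypothetical, and the exception type raised by `assert_valid` in the real library may differ.

```python
from enum import Enum


class SoftmaxImplSketch(Enum):
    """Hypothetical stand-in for SoftmaxImplmentation, built on enum.Enum (not JacEnum)."""

    STANDARD = 'standard'
    GUMBEL = 'gumbel'
    GUMBEL_HARD = 'gumbel_hard'

    @classmethod
    def choice_names(cls):
        # Member names, e.g. 'GUMBEL'.
        return [m.name for m in cls]

    @classmethod
    def choice_values(cls):
        # Underlying string values, e.g. 'gumbel'.
        return [m.value for m in cls]

    @classmethod
    def choice_objs(cls):
        # The member objects themselves.
        return list(cls)

    @classmethod
    def is_valid(cls, value):
        # Accept either a member or one of the members' string values.
        return isinstance(value, cls) or value in cls.choice_values()

    @classmethod
    def assert_valid(cls, value):
        # Raise on invalid input (the exception type JacEnum uses may differ).
        if not cls.is_valid(value):
            raise ValueError(f'Invalid {cls.__name__} value: {value!r}')


if __name__ == '__main__':
    print(SoftmaxImplSketch.choice_values())
    print(SoftmaxImplSketch('gumbel_hard'))
```

Accepting plain strings in `is_valid` lets a configuration value such as `'gumbel_hard'` be validated before it is converted to a member with `SoftmaxImplSketch('gumbel_hard')`.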
Functions
- gumbel_softmax(logits, dim=-1, tau=1, hard=False, mask=None, eps=1e-10)
Sample from the Gumbel-Softmax distribution and optionally discretize.
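For reference, the standard Gumbel-Softmax sampling step (as in the ericjang/gumbel-softmax reference credited at the end of this entry) can be written as follows, with \ell the logits along `dim`, \tau = `tau`, and \mathrm{sg}[\cdot] a stop-gradient operator. That this function follows exactly this form, including where `eps` enters the logarithms and how `mask` is applied, is an assumption; both are omitted here.

```latex
% Gumbel noise, one independent draw per class i
g_i = -\log\left(-\log u_i\right), \qquad u_i \sim \mathrm{Uniform}(0, 1)

% Soft sample: softmax of the perturbed logits at temperature \tau
y_i = \frac{\exp\left((\ell_i + g_i)/\tau\right)}{\sum_{j=1}^{n_{\text{class}}} \exp\left((\ell_j + g_j)/\tau\right)}

% Hard sample (straight-through): one-hot in the forward pass, gradient of y in the backward pass
y^{\text{hard}} = \mathrm{sg}\left[\mathrm{onehot}\left(\arg\max_i y_i\right) - y\right] + y
```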
- Parameters:
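logits – unnormalized log-probabilities of shape [batch_size, n_class]
dim – dimension along which the softmax is computed
tau – positive scalar temperature
hard – if True, return the one-hot argmax of the sample in the forward pass, but differentiate with respect to the soft sample y (straight-through estimator)
eps – small constant for numerical stability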
- Returns:
[batch_size, n_class] sample from the Gumbel-Softmax distribution. If hard=True, the returned sample is one-hot; otherwise it is a probability distribution that sums to 1 across classes.
Based on ericjang/gumbel-softmax (MIT license).
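The following is a minimal pure-Python sketch of the same computation for a single row of logits, assuming the standard Gumbel-Softmax recipe that the ericjang/gumbel-softmax reference implements. It is not the jactorch code: it works on plain lists rather than tensors, ignores `dim` and `mask`, and cannot express the straight-through gradient of `hard=True`. The helper names `sample_gumbel`, `softmax`, and `gumbel_softmax_sketch` are hypothetical, and clamping `u` with `eps` is an assumed way of keeping the logarithms finite.

```python
import math
import random


def sample_gumbel(n, rng, eps=1e-10):
    """Draw n approximately Gumbel(0, 1) samples via g = -log(-log(u))."""
    samples = []
    for _ in range(n):
        # Clamp u into [eps, 1 - eps] so both logarithms stay finite (assumed use of eps).
        u = min(max(rng.random(), eps), 1.0 - eps)
        samples.append(-math.log(-math.log(u)))
    return samples


def softmax(xs):
    """Numerically stable softmax over a plain list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sketch(logits, tau=1.0, hard=False, eps=1e-10, rng=None):
    """Hypothetical single-row Gumbel-Softmax sample (no dim, no mask, no autograd)."""
    if tau <= 0:
        raise ValueError('tau must be positive')
    if rng is None:
        rng = random.Random()
    noise = sample_gumbel(len(logits), rng, eps)
    # Soft sample: softmax of the noise-perturbed logits at temperature tau.
    y = softmax([(x + g) / tau for x, g in zip(logits, noise)])
    if not hard:
        return y
    # Hard sample: one-hot at the argmax. The library version additionally
    # backpropagates through the soft y (straight-through), which plain
    # Python lists cannot express.
    k = max(range(len(y)), key=y.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(y))]


if __name__ == '__main__':
    logits = [1.0, 2.0, 0.5]
    print(gumbel_softmax_sketch(logits, tau=0.5, rng=random.Random(0)))
    print(gumbel_softmax_sketch(logits, hard=True, rng=random.Random(0)))
```

Because the argmax of the perturbed logits does not depend on `tau`, the class chosen by a hard sample is distributed as softmax(logits) at any temperature (the Gumbel-max trick); `tau` only controls how peaked the soft sample is, with smaller values pushing it toward one-hot.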