Source code for jactorch.functional.sampling
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : sample.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""Sampling functions."""
import torch
import torch.autograd as autograd
from jacinle.utils.vendor import requires_vendors
from .indexing import one_hot_nd
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['sample_bernoulli', 'sample_multinomial', 'choice']
[docs]
class SampleBernoulli(autograd.Function):
    """Element-wise Bernoulli sampling with a straight-through gradient.

    The forward pass returns a float tensor of 0/1 values in which each element
    is 1 with the probability given by the matching element of the input. The
    backward pass returns the incoming gradient unchanged, so gradients flow
    "through" the non-differentiable sampling step.
    """

    @staticmethod
    def forward(ctx, x):
        """Sample ``Bernoulli(x)`` element-wise.

        Args:
            ctx: autograd context (unused).
            x: floating-point tensor of probabilities; each element is the
                chance that the corresponding output element is 1.

        Returns:
            A float32 tensor with the same shape as ``x`` holding 0.0 / 1.0.
        """
        # Uniform noise in [0, 1) with the same shape, dtype and device as ``x``.
        noise = torch.rand_like(x)
        # P(noise < x) == x. Comparing with ``>`` would produce 1 with
        # probability ``1 - x``, i.e. the complementary distribution.
        return (noise < x).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass ``grad_output`` back unchanged."""
        return grad_output
[docs]
def sample_bernoulli(x):
    """Draw element-wise samples from Bernoulli distributions.

    Sampling is performed by :class:`SampleBernoulli`, whose backward pass
    returns the incoming gradient unchanged (straight-through).

    Args:
        x: tensor whose elements are the Bernoulli probabilities.

    Returns:
        A tensor shaped like ``x``; each entry is a sample from the Bernoulli
        distribution parameterised by the matching entry of ``x``.
    """
    sampler = SampleBernoulli
    return sampler.apply(x)
[docs]
class SampleMultinomial(autograd.Function):
    """Draw one category per distribution, one-hot encoded, with a straight-through gradient.

    The forward pass treats every slice of the input along ``dim`` as a set of
    category weights, samples one category from each, and returns a float
    tensor with the same shape as the input that is one-hot along ``dim``. The
    backward pass returns the incoming gradient unchanged for the input and
    ``None`` for ``dim``.
    """

    @staticmethod
    def forward(ctx, x, dim):
        """Sample one category per slice along ``dim``.

        Args:
            ctx: autograd context (unused).
            x: floating-point tensor of category weights. Each slice along
                ``dim`` must be a valid input row for ``torch.multinomial``
                (non-negative, not necessarily normalized).
            dim: the dimension of ``x`` that enumerates the categories.

        Returns:
            A float32 tensor with the same shape as ``x``, one-hot along ``dim``.
        """
        # Move the category axis last and flatten everything else, so each
        # distribution becomes one row of the 2-D matrix torch.multinomial expects.
        moved = x.transpose(dim, -1)
        rows = moved.contiguous().view(-1, moved.size(-1))
        picked = torch.multinomial(rows, 1)  # (num_rows, 1) sampled category indices
        # One-hot over the category axis, which now lives at -1. Reading the
        # class count from ``moved.size(dim)`` would return the size of whatever
        # axis was swapped into ``dim``, which is wrong for any ``dim != -1``.
        one_hot = torch.zeros_like(rows).scatter_(1, picked, 1.0)
        output = one_hot.view(moved.size()).transpose(dim, -1)
        return output.float()

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient for ``x``; ``dim`` is not differentiable."""
        return grad_output, None
[docs]
def sample_multinomial(x, dim=-1):
    """Draw one category per distribution and return its one-hot encoding.

    Sampling is performed by :class:`SampleMultinomial`, whose backward pass
    returns the incoming gradient unchanged (straight-through).

    Args:
        x: tensor of category probabilities.
        dim: the axis of ``x`` that enumerates the categories (default: last).

    Returns:
        A tensor shaped like ``x`` marking, along ``dim``, the category sampled
        from each distribution.
    """
    category_dim = dim
    return SampleMultinomial.apply(x, category_dim)
[docs]
@requires_vendors('pytorch_reservoir')
def choice(a, k=1, replace=True, p=None, dtype=None, device=None):
    """
    Generates a random sample from a given 1-D array.

    Args:
        a (torch.Tensor | int | sequence): 1-D tensor (or anything ``torch.tensor``
            accepts) to sample from. An int ``n`` is treated as ``torch.arange(n)``.
        k (int, optional): number of samples to be drawn.
        replace (bool, optional): whether the sample is with or without replacement.
        p (torch.Tensor | sequence, optional): an optional weight parameter with the
            same shape as ``a``. Not necessarily to be normalized. It is placed on
            the same device as ``a`` before sampling.
        dtype (torch.dtype, optional): dtype to cast ``a`` to; ``None`` keeps it.
        device (torch.device, optional): device to move ``a`` to; ``None`` keeps it.

    Returns:
        torch.Tensor: 1-D outputs of k sampled data from `a`.

    Raises:
        AssertionError: if ``a`` is not 1-D or ``p`` does not match ``a``'s shape.
    """
    import pytorch_reservoir

    if isinstance(a, int):
        a = torch.arange(a, dtype=dtype, device=device)
    if not torch.is_tensor(a):
        a = torch.tensor(a)
    assert a.dim() == 1, 'jactorch.choice supports only 1-D input.'
    a = a.to(dtype=dtype, device=device)

    if p is None:
        return pytorch_reservoir.choice(a, replace, k)

    # Follow ``a``'s device rather than ``device``: when ``device`` is None the
    # weights would otherwise stay wherever they were, e.g. on CPU next to a CUDA ``a``.
    p = torch.as_tensor(p, device=a.device)
    assert a.size() == p.size(), 'jactorch.choice requires p to have the same shape as a.'
    return pytorch_reservoir.choice(a, p, replace, k)