Source code for jactorch.functional.range
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : meshgrid.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/31/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""Numeric range functions."""
import torch
from .shape import broadcast, concat_shape
__all__ = ['meshgrid', 'meshgrid_exclude_self']
[docs]
def meshgrid(input1, input2=None, dim=-1):
    """Build an ``ij``-indexed mesh of two tensors along axis ``dim``.

    A new axis is inserted right after ``dim``. With ``n = input1.size(dim)``
    and ``m = input2.size(dim)``, both outputs have size ``n`` at ``dim`` and
    ``m`` at ``dim + 1``. The first output varies only with the index along
    ``dim`` (values taken from ``input1``). The second varies only with the
    index along ``dim + 1`` (values taken from ``input2``). This is the same
    pairing as ``np.meshgrid(a, b, indexing='ij')``.

    NOTE(review): this reading assumes ``broadcast(t, d, k)`` from ``.shape``
    stretches the size-1 axis ``d`` of ``t`` to size ``k`` -- confirm there.

    Args:
        input1: tensor supplying the values indexed along ``dim``.
        input2: tensor supplying the values indexed along the new axis;
            defaults to ``input1`` when omitted.
        dim: axis to mesh over; a negative value is resolved against
            ``input1.dim()``.

    Returns:
        A tuple ``(x, y)`` holding the two meshed tensors.
    """
    other = input1 if input2 is None else input2
    axis = dim + input1.dim() if dim < 0 else dim

    len_first, len_second = input1.size(axis), other.size(axis)
    # input1 gets its new singleton axis *after* `axis` and is stretched to
    # input2's length; input2 gets it *at* `axis` and is stretched to input1's.
    first = input1.unsqueeze(axis + 1)
    second = other.unsqueeze(axis)
    return (
        broadcast(first, axis + 1, len_second),
        broadcast(second, axis, len_first),
    )
[docs]
def meshgrid_exclude_self(input, dim=1):
    """Drop the diagonal of a square grid spanning axes ``dim`` and ``dim + 1``.

    Given a tensor ``a`` whose axes ``dim`` and ``dim + 1`` both have size
    ``n`` (think ``a[..., i, j, ...]``), return a tensor in which those two
    axes become ``n`` and ``n - 1``. Row ``i`` keeps every
    ``a[..., i, j, ...]`` with ``j != i``, in increasing ``j`` order. All
    other axes are kept unchanged.

    Args:
        input: tensor with at least ``dim + 2`` dimensions.
        dim: first of the two grid axes; a negative value counts from the end.

    Returns:
        A new tensor (a copy, not a view) of shape
        ``input.shape[:dim] + (n, max(n - 1, 0)) + input.shape[dim + 2:]``.

    Raises:
        ValueError: if axes ``dim`` and ``dim + 1`` have different sizes.
    """
    if dim < 0:
        dim += input.dim()
    n = input.size(dim)
    m = input.size(dim + 1)
    # Explicit check instead of `assert`: under `python -O` the assert is
    # stripped, and a size-1 axis would then silently broadcast against the
    # (n, n) mask and yield a wrong-but-plausible result.
    if m != n:
        raise ValueError(
            'meshgrid_exclude_self: axes {} and {} must have equal size, '
            'got {} and {}.'.format(dim, dim + 1, n, m)
        )

    # Off-diagonal mask: mask[i, j] is True iff i != j (exclude self-pairs).
    rng = torch.arange(n, dtype=torch.long, device=input.device)
    mask = rng.unsqueeze(1) != rng.unsqueeze(0)
    # Pad with singleton axes so the (n, n) mask lines up with axes dim and
    # dim + 1 of `input` and broadcasts over every other axis.
    n_trailing = input.dim() - dim - 2
    mask = mask.view((1,) * dim + (n, n) + (1,) * n_trailing)

    # masked_select walks the broadcast shape in row-major order, so row i
    # emits its off-diagonal entries (each with all trailing axes) contiguously,
    # which is exactly the layout of the target shape.
    # max(n - 1, 0): a literal -1 would be read by view() as "infer this size"
    # and fail on the empty tensor produced when n == 0.
    target_shape = (
        tuple(input.shape[:dim]) + (n, max(n - 1, 0)) + tuple(input.shape[dim + 2:])
    )
    return input.masked_select(mask).view(target_shape)