Source code for jactorch.nn.neural_logic.modules.dimension
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : dimension.py
# Author : Honghua Dong
# Email : dhh19951@gmail.com
# Date : 04/20/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import itertools
import torch
import torch.nn as nn
from jactorch.functional import broadcast
from ._utils import exclude_mask, mask_value
__all__ = ['Expander', 'Reducer', 'Permutation']
[docs]
class Expander(nn.Module):
[docs]
def __init__(self, dim):
super().__init__()
self.dim = dim
[docs]
def forward(self, input, n=None):
if self.dim == 0:
assert n is not None
elif n is None:
n = input.size(self.dim)
dim = self.dim + 1
return broadcast(input.unsqueeze(dim), dim, n)
[docs]
def get_output_dim(self, input_dim):
return input_dim
[docs]
class Reducer(nn.Module):
[docs]
def __init__(self, dim, exclude_self=True, exists=True, min_val=0., max_val=0.):
super().__init__()
self.dim = dim
self.exclude_self = exclude_self
self.exists = exists
self.min_val = min_val
self.max_val = max_val
[docs]
def forward(self, input, mask=None):
shape = input.size()
inp0, inp1 = input, input
if self.exclude_self:
mask_self = exclude_mask(input, cnt=self.dim, dim=-1 - self.dim)
if mask is not None:
mask = mask.unsqueeze(-1) * mask_self
else:
mask = mask_self
if mask is not None:
inp0 = mask_value(input, mask, self.min_val)
inp1 = mask_value(input, mask, self.max_val)
if self.exists:
shape = shape[:-2] + (shape[-1] * 2, )
exists = torch.max(inp0, dim=-2)[0]
forall = torch.min(inp1, dim=-2)[0]
return torch.stack((exists, forall), dim=-1).view(shape)
shape = shape[:-2] + (shape[-1], )
return torch.max(inp0, dim=-2)[0].view(shape)
[docs]
def get_output_dim(self, input_dim):
if self.exists:
return input_dim * 2
return input_dim
[docs]
class Permutation(nn.Module):
[docs]
def __init__(self, dim):
super().__init__()
self.dim = dim
[docs]
def forward(self, input):
if self.dim <= 1:
return input
nr_dims = len(input.size())
# Assume the last dim is channel.
index = tuple(range(nr_dims - 1))
start_dim = nr_dims - 1 - self.dim
assert start_dim > 0
res = []
for i in itertools.permutations(index[start_dim:]):
p = index[:start_dim] + i + (nr_dims - 1,)
res.append(input.permute(p))
return torch.cat(res, dim=-1)
[docs]
def get_output_dim(self, input_dim):
mul = 1
for i in range(self.dim):
mul *= i + 1
return input_dim * mul