Source code for jactorch.nn.neural_logic.modules.neural_logic

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : neural_logic.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/28/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch.nn as nn

from jacinle.utils.enum import JacEnum
from jactorch.nn.cnn import MLPLayer
from jactorch.nn.quickaccess import get_activation

# Public names exported by ``from ... import *``.
__all__ = [
    'NeuralLogicInferenceMethod',
    'NeuralLogicInferenceBase',
    'NeuralLogicInference',
    'NeuralLogitsInference',
]


class NeuralLogicInferenceMethod(JacEnum):
    """Selects how the neural logic inference modules transform their input.

    Members (in declaration order):
        SKIP ('skip'): the input is passed through unchanged.
        MLP ('mlp'): the input is transformed by an MLP over its last dimension.
    """

    SKIP = 'skip'
    MLP = 'mlp'
class NeuralLogicInferenceBase(nn.Module):
    """Apply an inference layer independently at every position of a tensor.

    The method is chosen by ``model``, a string parsed with
    ``NeuralLogicInferenceMethod.from_string``:

    * ``'mlp'``: an ``MLPLayer(input_dim, output_dim, hidden_dim)`` wrapped in
      an ``nn.Sequential`` (so subclasses can append further modules) is
      applied to the last dimension of the input.
    * ``'skip'``: no parameters are created and ``forward`` returns its input
      unchanged.

    Args:
        model: name of the inference method (``'mlp'`` or ``'skip'``).
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output in MLP mode.
        hidden_dim: hidden-layer configuration, forwarded as-is to ``MLPLayer``.

    Raises:
        NotImplementedError: if the parsed method is not handled here.
    """

    def __init__(self, model, input_dim, output_dim, hidden_dim):
        super().__init__()
        self.method = NeuralLogicInferenceMethod.from_string(model)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        if self.method is NeuralLogicInferenceMethod.SKIP:
            # Identity mapping: nothing to build. ``forward`` returns before
            # touching ``self.layer``; the attribute is set so it always exists.
            self.layer = None
        elif self.method is NeuralLogicInferenceMethod.MLP:
            self.layer = nn.Sequential(MLPLayer(input_dim, output_dim, hidden_dim))
        else:
            raise NotImplementedError('Unknown logic inference method: {}.'.format(self.method))

    def forward(self, input):
        """Apply the inference layer over the last dimension of ``input``.

        Args:
            input: tensor of shape ``(..., C_in)``.

        Returns:
            In SKIP mode, ``input`` itself. Otherwise a tensor of shape
            ``(..., C_out)`` with the same leading dimensions as ``input``.
        """
        if self.method is NeuralLogicInferenceMethod.SKIP:
            return input

        leading_shape = input.shape[:-1]
        # Collapse all leading dimensions into a single batch axis. ``reshape``
        # (unlike ``view``) also accepts non-contiguous tensors, e.g. the
        # output of ``transpose``/``permute``.
        f = input.reshape(-1, input.size(-1))
        f = self.layer(f)
        # Restore the leading dimensions using the explicit output width rather
        # than -1, which is ambiguous (and raises) when there are zero leading
        # elements. Assumes the layer maps (N, C_in) -> (N, C_out).
        return f.reshape(*leading_shape, f.size(-1))

    def get_output_dim(self, input_dim):
        """Return the size of the last output dimension for a given input size.

        Args:
            input_dim: size of the last dimension of the input.

        Returns:
            ``input_dim`` in SKIP mode (identity), otherwise ``self.output_dim``.
        """
        if self.method is NeuralLogicInferenceMethod.SKIP:
            return input_dim
        return self.output_dim
class NeuralLogicInference(NeuralLogicInferenceBase):
    """Neural logic inference with an output activation appended to the MLP.

    Behaves like ``NeuralLogicInferenceBase``. In ``'mlp'`` mode it also appends
    ``get_activation(activation)`` (``'sigmoid'`` by default) as the last module
    of ``self.layer``. In ``'skip'`` mode nothing is added and the input passes
    through unchanged.

    Args:
        model: name of the inference method (``'mlp'`` or ``'skip'``).
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output in MLP mode.
        hidden_dim: hidden-layer configuration forwarded to ``MLPLayer``.
        activation: activation spec passed to ``get_activation``.

    Raises:
        NotImplementedError: if the parsed method is not handled here.
    """

    def __init__(self, model, input_dim, output_dim, hidden_dim, activation='sigmoid'):
        super().__init__(model, input_dim, output_dim, hidden_dim)
        if self.method is NeuralLogicInferenceMethod.MLP:
            # nn.Sequential runs children in insertion order; registering the
            # activation under the next integer key makes it the final step.
            self.layer.add_module(str(len(self.layer)), get_activation(activation))
        elif self.method is not NeuralLogicInferenceMethod.SKIP:
            # SKIP is an identity mapping and needs no activation; anything
            # else is a method this class does not know how to build.
            raise NotImplementedError('Unknown logic inference method: {}.'.format(self.method))
class NeuralLogitsInference(NeuralLogicInferenceBase):
    """Neural logic inference without an output activation.

    Adds nothing to ``NeuralLogicInferenceBase``: in ``'mlp'`` mode the result
    is the MLP output as-is (no activation is appended), which presumably is
    meant to be consumed as logits.
    """