#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : recurrent_layer.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 06/10/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import functools

import torch
import torch.nn as nn

from jactorch.functional import concat_shape
from .modules.neural_logic import NeuralLogicInferenceMethod
from .modules.dimension import Reducer
from .layer import NeuralLogicMachine, _get_tuple_n

__all__ = ['RecurrentNeuralLogicMachine']


class RecurrentNeuralLogicMachine(nn.Module):
    """A recurrent Neural Logic Machine (NLM).

    The machine is composed of three stages: an optional pre-processing NLM, a
    recurrently applied core NLM, and an optional post-processing NLM. The core
    is unrolled until the minimum confidence across all arities exceeds
    ``epsilon``, or until ``max_depths`` steps have been taken.
    """
    def __init__(
        self, breadth, depth1, depth2, depth3,
        input_dims, imm_dims1, imm_dims2, output_dims,
        epsilon, logic_model, logic_hidden_dim,
        exclude_self=True, pre_residual=False, connections=None
    ):
        super().__init__()

        def add(x, y):
            """output[i] = x[i] + y[i]"""
            return [i + j for i, j in zip(x, y)]

        def mul(x, m):
            """output[i] = x[i] * m"""
            return [i * m for i in x]

        self.breadth = breadth
        self.input_dims = _get_tuple_n(input_dims, breadth + 1, int)
        self.imm_dims1 = _get_tuple_n(imm_dims1, breadth + 1, int)
        self.imm_dims2 = _get_tuple_n(imm_dims2, breadth + 1, int)
        self.output_dims = _get_tuple_n(output_dims, breadth + 1, int)

        kwargs = dict(logic_model=logic_model, logic_hidden_dim=logic_hidden_dim,
                      exclude_self=exclude_self, connections=connections)

        # Stage 1: an optional pre-processing NLM that maps the raw inputs to
        # the intermediate representation consumed by the recurrent core.
        if depth1 > 0:
            self.model_pre = NeuralLogicMachine(
                depth1, breadth, self.input_dims, self.imm_dims1,
                residual=pre_residual, **kwargs)
            pre_output_dims = self.model_pre.output_dims
        else:
            self.add_module('model_pre', None)
            pre_output_dims = self.input_dims  # No pre-processing: the core sees the raw inputs.

        # Stage 2: the recurrent core. Per arity, it consumes the pre-processed
        # features concatenated with the current features and confidences, and
        # emits both the next features and the next confidences.
        input_dims = add(pre_output_dims, mul(self.imm_dims2, 2))
        output_dims = mul(self.imm_dims2, 2)
        self.model = NeuralLogicMachine(depth2, breadth, input_dims, output_dims,
                                        residual=False, **kwargs)

        # Stage 3: an optional post-processing NLM.
        if depth3 > 0:
            self.model_post = NeuralLogicMachine(
                depth3, breadth, self.model.output_dims, self.output_dims,
                residual=False, **kwargs)
        else:
            self.add_module('model_post', None)

        # One reducer per arity >= 1, used to reduce a confidence tensor of
        # that arity down to a scalar.
        self.reducers = nn.ModuleList()
        for i in range(self.breadth + 1):
            if i == 0:
                self.reducers.append(None)
            else:
                self.reducers.append(Reducer(i, exclude_self=exclude_self, exists=False))

        self.epsilon = epsilon
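    # A worked illustration of the dimension bookkeeping above (hypothetical
    # numbers, not from the original source). With breadth=2,
    # imm_dims2=(8, 8, 8) and model_pre.output_dims == (8, 8, 8):
    #
    #   core input dims  = add((8, 8, 8), mul((8, 8, 8), 2)) = (24, 24, 24)
    #   core output dims = mul((8, 8, 8), 2)                 = (16, 16, 16)
    #
    # Per arity, the core thus consumes the pre-processed features concatenated
    # with the current features and confidences, and its output is later
    # chunked in half into the next features and the next confidences.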
    def forward(self, inputs, max_depths=1000):
        # Infer the batch size, object count, dtype and device from the first
        # non-None input tensor.
        for i in inputs:
            if i is not None:
                batch_size, n = i.size()[:2]
                dtype, device = i.dtype, i.device
                break

        def zeros(d, h):
            """zeros(batch_size, n, ..., h), with d copies of n."""
            if d is None:
                return None
            return torch.zeros(concat_shape(batch_size, [n for _ in range(d)], h),
                               dtype=dtype, device=device)

        def _format(x):
            for i in x:
                if i is None:
                    yield i
                else:
                    yield i.size()

        def print_list(*args, **kwargs):
            """Debugging helper: print the shapes of lists of tensors."""
            print(*[list(_format(x)) for x in args], **kwargs)

        def concat(x, y, z):
            def _concat():
                """output[i] = concat(x[i], y[i], z[i]); x[i] can be None."""
                for i, j, k in zip(x, y, z):
                    if i is None:
                        yield torch.cat((j, k), dim=-1)
                    else:
                        yield torch.cat((i, j, k), dim=-1)
            return list(_concat())

        def chunk(x, n, dim):
            """Chunk each tensor in x into n parts; return n lists of tensors."""
            def _chunk():
                for i in x:
                    yield i.chunk(n, dim=dim)
            chunks = list(_chunk())
            return list(zip(*chunks))

        def elem_op(x, y, op):
            def _elem_op():
                """output[i] = op(x[i], y[i])"""
                for i, j in zip(x, y):
                    yield op(i, j)
            return list(_elem_op())

        add = functools.partial(elem_op, op=lambda x, y: x + y)
        mul = functools.partial(elem_op, op=lambda x, y: x * y)
        div = functools.partial(elem_op, op=lambda x, y: x / y)

        def reduce(x, d):
            """Return min(x); d is the arity of x."""
            if d == 0:
                return x.min()
            return reduce(self.reducers[d](x), d - 1)

        def get_confidence(cs):
            """Compute the overall confidence: the minimum over all arities."""
            cs = [reduce(c, i) for i, c in enumerate(cs)]
            conf = None
            for c in cs:
                if conf is None or conf.item() > c.item():
                    conf = c
            return conf

        if self.model_pre is not None:
            f_pre = self.model_pre(inputs)
        else:
            f_pre = inputs

        fs = [zeros(i, d) for i, d in enumerate(self.imm_dims2)]
        cs = [zeros(i, d) for i, d in enumerate(self.imm_dims2)]
        conf = torch.tensor(0, dtype=fs[0].dtype, device=fs[0].device)
        acc_fs = [zeros(i, d) for i, d in enumerate(self.imm_dims2)]
        acc_cs = [zeros(i, d) for i, d in enumerate(self.imm_dims2)]

        for i in range(max_depths):
            next_fs, next_cs = map(list, chunk(self.model(concat(f_pre, fs, cs)), 2, dim=-1))
            for j, (f, c, next_f, next_c) in enumerate(zip(fs, cs, next_fs, next_cs)):
                # Entries whose confidence already exceeds epsilon are frozen;
                # only the unconverged entries are overwritten.
                mask = (c > self.epsilon).float()
                next_fs[j] = f * mask + next_f * (1 - mask)
                next_cs[j] = c * mask + next_c * (1 - mask)

            # Update the features and confidences.
            fs = next_fs
            last_cs, cs = cs, add(cs, next_cs)
            last_conf, conf = conf, get_confidence(cs)
            acc_fs = add(acc_fs, mul(fs, cs))
            acc_cs = add(acc_cs, cs)

            if conf > self.epsilon:
                break

        # Confidence-weighted average of the features across all steps.
        fs = div(acc_fs, acc_cs)
        if self.model_post is not None:
            f_post = self.model_post(fs)
        else:
            f_post = fs
        return f_post, torch.tensor(i, dtype=torch.float), last_cs, last_conf
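    # A scalar illustration of one step of the update above (hypothetical
    # numbers; suppose epsilon = 0.9). For a single confidence channel:
    #
    #   c = 0.95 > epsilon  =>  mask = 1: the entry is frozen, and the new
    #                           prediction (next_f, next_c) is discarded;
    #   c = 0.40 <= epsilon =>  mask = 0: the entry is overwritten with
    #                           (next_f, next_c).
    #
    # Across steps, confidences accumulate (cs <- cs + next_cs) and features
    # are averaged with confidence weights (acc_fs / acc_cs at the end); the
    # loop terminates once the minimum confidence over all arities exceeds
    # epsilon, or after max_depths steps.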
    __hyperparams__ = (
        'breadth', 'depth1', 'depth2', 'depth3',
        'input_dims', 'imm_dims1', 'imm_dims2', 'output_dims',
        'epsilon', 'logic_model', 'logic_hidden_dim',
        'exclude_self', 'pre_residual'
    )

    __hyperparam_defaults__ = {
        'logic_model': 'mlp',
        'exclude_self': True,
        'pre_residual': False,
    }
    @classmethod
    def make_prog_block_parser(cls, parser, defaults, prefix=None):
        for k, v in cls.__hyperparam_defaults__.items():
            defaults.setdefault(k, v)

        if prefix is None:
            prefix = '--'
        else:
            prefix = '--' + str(prefix) + '-'

        parser.add_argument(prefix + 'breadth', type=int, default=defaults['breadth'], metavar='N')
        parser.add_argument(prefix + 'depth1', type=int, default=defaults['depth1'], metavar='N')
        parser.add_argument(prefix + 'depth2', type=int, default=defaults['depth2'], metavar='N')
        parser.add_argument(prefix + 'depth3', type=int, default=defaults['depth3'], metavar='N')
        parser.add_argument(prefix + 'imm-dims1', type=int, default=defaults['imm_dims1'], metavar='N')
        parser.add_argument(prefix + 'imm-dims2', type=int, default=defaults['imm_dims2'], metavar='N')
        # epsilon is a floating-point confidence threshold.
        parser.add_argument(prefix + 'epsilon', type=float, default=defaults['epsilon'], metavar='F')
        parser.add_argument(prefix + 'logic-model', default=defaults['logic_model'],
                            choices=NeuralLogicInferenceMethod.choice_values(), metavar='T')
        parser.add_argument(prefix + 'logic-hidden-dim', type=int, nargs='+',
                            default=defaults['logic_hidden_dim'], metavar='N')
        parser.add_argument(prefix + 'exclude-self', type='bool', default=defaults['exclude_self'], metavar='B')
        parser.add_argument(prefix + 'pre-residual', type='bool', default=defaults['pre_residual'], metavar='B')
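    # Hypothetical usage of the parser helper above (the prefix 'nlm' is
    # illustrative): make_prog_block_parser(parser, defaults, prefix='nlm')
    # registers --nlm-breadth, --nlm-depth1, ..., --nlm-pre-residual, which
    # argparse stores as args.nlm_breadth, args.nlm_depth1, etc.; these are
    # exactly the '<prefix>_<hyperparam>' names that from_args below reads back
    # with getattr(args, prefix + k). Note that type='bool' assumes a parser
    # that registers a 'bool' type (e.g., Jacinle's JacArgumentParser); a plain
    # argparse.ArgumentParser would reject it.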
    @classmethod
    def from_args(cls, input_dims, output_dims, args, prefix=None, **kwargs):
        if prefix is None:
            prefix = ''
        else:
            prefix = str(prefix) + '_'

        setattr(args, prefix + 'input_dims', input_dims)
        setattr(args, prefix + 'output_dims', output_dims)
        init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
        init_params.update(kwargs)

        return cls(**init_params)
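
# A minimal usage sketch (a hypothetical example, not part of the original
# module): all dimensions, the epsilon value and the object count below are
# illustrative assumptions. Inputs follow the NLM convention of one tensor per
# arity: (B, h0) for nullary, (B, n, h1) for unary, and (B, n, n, h2) for
# binary predicates.
if __name__ == '__main__':
    rnlm = RecurrentNeuralLogicMachine(
        breadth=2, depth1=1, depth2=2, depth3=1,
        input_dims=8, imm_dims1=8, imm_dims2=8, output_dims=8,
        epsilon=0.9, logic_model='mlp', logic_hidden_dim=[],
    )
    B, n = 4, 10
    inputs = [torch.rand(B, 8), torch.rand(B, n, 8), torch.rand(B, n, n, 8)]
    outputs, depth, last_cs, last_conf = rnlm(inputs, max_depths=100)
    print([o.size() for o in outputs], depth.item(), last_conf.item())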