Source code for jactorch.nn.container

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : container.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 04/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

__all__ = ['SequentialN']


class SequentialN(nn.Sequential):
    """A :class:`torch.nn.Sequential` whose stages may take/return several values.

    Each sub-module is called with the previous stage's output unpacked as
    positional arguments. A ``tuple``/``list`` output (including namedtuples,
    which are tuples) is splatted into the next call. Any other output (a single
    tensor, a dict, ``None``, ...) is passed through as one argument.
    """

    def forward(self, *inputs, return_all=False):
        """Run the registered sub-modules in order.

        Args:
            *inputs: positional arguments for the first sub-module.
            return_all: if True, also return every intermediate value.

        Returns:
            The output of the last sub-module. If there are no sub-modules, the
            ``inputs`` tuple itself is returned. When ``return_all`` is True, a
            pair ``(output, all_values)`` is returned instead: ``all_values[0]``
            is the ``inputs`` tuple and ``all_values[i]`` is the output of the
            ``i``-th sub-module (1-based).
        """
        all_values = [inputs]
        for module in self._modules.values():
            # Only genuine sequences are unpacked. Wrapping everything else
            # keeps e.g. a dict output from being splatted into its keys, and
            # keeps None/scalars from raising on ``*`` unpacking.
            if not isinstance(inputs, (tuple, list)):
                inputs = [inputs]
            inputs = module(*inputs)
            all_values.append(inputs)
        if return_all:
            return inputs, all_values
        return inputs