# Source code for jactorch.nn.container
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : container.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn
__all__ = ['SequentialN']
# [docs]
class SequentialN(nn.Sequential):
    """A :class:`nn.Sequential` whose stages may take and return several values.

    The output of each child module becomes the positional arguments of the
    next one. A tensor output is forwarded as a single argument; any other
    output (e.g. a tuple of tensors) is star-unpacked into the next call.
    """

    def forward(self, *inputs, return_all=False):
        """Run the child modules in registration order.

        Args:
            *inputs: positional arguments handed to the first child module.
            return_all: when True, also return every intermediate value.

        Returns:
            The last child's output, or ``(output, history)`` when
            ``return_all`` is True. ``history[0]`` is the tuple of original
            inputs and ``history[i]`` is the output of the i-th child.
            If there are no child modules, the output is the ``inputs``
            tuple itself.
        """
        current = inputs
        history = [current]
        stages = self._modules.values()
        for stage in stages:
            # A lone tensor must be wrapped so that ``*`` passes it as one
            # argument instead of iterating over its first dimension.
            args = [current] if torch.is_tensor(current) else current
            current = stage(*args)
            history.append(current)
        return (current, history) if return_all else current