Source code for jactorch.nn.simple

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : simple.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/13/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

# Names exported by `from jactorch.nn.simple import *`.
__all__ = ['Identity', 'TorchApplyRecorderMixin', 'AutoResetParametersMixin']


class Identity(nn.Module):
    """A no-op module that hands its positional inputs straight back.

    Called with exactly one positional argument, that argument is returned
    unchanged. With any other number of arguments (including none), the
    tuple of arguments is returned.
    """

    def forward(self, *args):
        # Anything other than a single input is passed back as the tuple itself.
        if len(args) != 1:
            return args
        (only_input,) = args
        return only_input
class TorchApplyRecorderMixin(nn.Module):
    """Module mixin exposing the ``dtype`` and ``device`` the module has been moved to.

    A frozen (``requires_grad=False``) scalar parameter with value 0 is
    registered at construction time, initially float32 on the CPU. It is an
    ordinary registered parameter, so nn.Module conversion methods such as
    ``to``, ``cuda`` or ``double`` convert it along with the module's other
    parameters. The ``dtype``/``device`` properties therefore report the
    result of the most recent such conversion.
    """

    def __init__(self):
        super().__init__()
        indicator = torch.tensor(0, dtype=torch.float32, device=torch.device('cpu'))
        # Frozen: it exists only to be converted, never to be trained.
        self._apply_recorder_indicator = nn.Parameter(indicator, requires_grad=False)

    @property
    def dtype(self):
        """Current dtype of the internal indicator parameter."""
        return self._apply_recorder_indicator.dtype

    @property
    def device(self):
        """Current device of the internal indicator parameter."""
        return self._apply_recorder_indicator.device
class AutoResetParametersMixin(object):
    """Mixin that gives a composite module a ``reset_parameters`` method.

    Meant to be combined with ``nn.Module``: it relies on ``self.children()``.
    Calling ``reset_parameters`` invokes ``reset_parameters`` on every
    descendant submodule that defines one.
    """

    def reset_parameters(self):
        """Call ``reset_parameters`` on each descendant submodule that defines it.

        Descendants are visited in the same pre-order as ``self.modules()``
        (the module itself is excluded), and a submodule reachable along
        several paths is visited once.

        If a descendant uses this mixin's own, non-overridden
        ``reset_parameters``, its call already resets its whole subtree. That
        subtree is therefore not walked again here, which avoids resetting
        nested modules once per enclosing mixin. Descendants that override
        ``reset_parameters`` are still walked into, because an override gives
        no guarantee about what it covers.
        """
        own_impl = AutoResetParametersMixin.reset_parameters
        # nn.Module hashes by identity, so this is an identity-based "visited" set.
        seen = {self}

        def visit(parent):
            for child in parent.children():
                if child in seen:
                    continue
                seen.add(child)
                if hasattr(child, 'reset_parameters'):
                    child.reset_parameters()
                # A nested mixin with the inherited implementation has just reset
                # its entire subtree; descending would reset those modules again.
                if getattr(type(child), 'reset_parameters', None) is own_impl:
                    continue
                visit(child)

        visit(self)