Source code for jactorch.nn.normalization

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : normalization.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 04/10/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

from jacinle.utils.numeric import prod

__all__ = ['LayerNorm']


def _unsqueeze_ft(tensor):
    return tensor.unsqueeze(0).unsqueeze(-1)


[docs] class LayerNorm(nn.Module):
[docs] def __init__(self, num_features, dim=-1, eps=1e-5, affine=True): super().__init__() self.eps = eps self.affine = affine self.dim = dim if self.affine: self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features))
[docs] def reset_parameters(self): if self.affine: self.weight.data.fill_(1) self.bias.data.zero_()
[docs] def forward(self, input): dim = self.dim if dim < 0: dim += input.dim() input_shape = input.size() imm_shape = (prod(input_shape[:dim]), input_shape[dim], prod(input_shape[dim+1:])) input = input.view(imm_shape) mean = input.mean(1, keepdim=True) std = input.std(1, keepdim=True, unbiased=False) # Compute the output. if self.affine: output = (input - mean) * (_unsqueeze_ft(self.weight) / std) + _unsqueeze_ft(self.bias) else: output = (input - mean) / std return output.view(input_shape)