#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : residual.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/10/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

from .simple import AutoResetParametersMixin
from .rnn_utils import rnn_with_length
from .normalization import LayerNorm

__all__ = ['ResidualConvBlock', 'ResidualConvBottleneck', 'ResidualLinear', 'ResidualGRU']


def _conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class ResidualConvBlock(nn.Module, AutoResetParametersMixin):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = _conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = _conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
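
# Usage sketch (illustrative shapes): with the default identity shortcut the
# block preserves its input shape, e.g.
#
#   block = ResidualConvBlock(64, 64)
#   y = block(torch.randn(2, 64, 32, 32))   # -> (2, 64, 32, 32)
#
# When stride > 1 or the channel count changes, pass a `downsample` module
# (typically a strided 1x1 conv followed by BatchNorm) so that the shortcut
# matches the shape of `out`.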


class ResidualConvBottleneck(nn.Module, AutoResetParametersMixin):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
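
# Usage sketch (illustrative shapes): the bottleneck expands channels by
# `expansion` (4x), so the shortcut needs a projection whenever
# inplanes != planes * 4 or stride != 1, e.g.
#
#   downsample = nn.Sequential(
#       nn.Conv2d(64, 256, kernel_size=1, bias=False),
#       nn.BatchNorm2d(256),
#   )
#   block = ResidualConvBottleneck(64, 64, downsample=downsample)
#   y = block(torch.randn(2, 64, 32, 32))   # -> (2, 256, 32, 32)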


class ResidualLinear(nn.Module, AutoResetParametersMixin):
    def __init__(self, hidden_dim, norm1=None, norm2=None):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.norm1 = norm1
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.norm2 = norm2
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        f = self.linear1(input)
        if self.norm1 is not None:
            f = self.norm1(f)
        f = self.relu(f)

        f = self.linear2(f)
        if self.norm2 is not None:
            f = self.norm2(f)

        f = f + input
        f = self.relu(f)
        return f
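
# Usage sketch (illustrative sizes): a residual two-layer MLP that keeps the
# feature dimension fixed; the optional norms run after each linear layer,
# e.g. with PyTorch's built-in nn.LayerNorm:
#
#   block = ResidualLinear(128, norm1=nn.LayerNorm(128), norm2=nn.LayerNorm(128))
#   y = block(torch.randn(32, 128))   # -> (32, 128)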


class ResidualGRU(nn.Module, AutoResetParametersMixin):
    def __init__(self, hidden_dim, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False,
                 layer_norm=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Each direction of a bidirectional GRU gets half the hidden size, so
        # the concatenated output still has hidden_dim features.
        self.real_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        self.num_layers = num_layers
        self.real_num_layers = num_layers * 2 if bidirectional else num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        # One single-layer GRU per residual stage; forward() adds the shortcut
        # around each stage. Note that PyTorch applies `dropout` only between
        # stacked layers, so it has no effect on these single-layer GRUs.
        self.rnns = nn.ModuleList([
            nn.GRU(self.hidden_dim, self.real_hidden_dim, 1, bias=bias, batch_first=False, dropout=dropout,
                   bidirectional=bidirectional)
            for _ in range(num_layers)
        ])
        self.layer_norms = None
        if layer_norm:
            self.layer_norms = nn.ModuleList([
                LayerNorm(hidden_dim)
                for _ in range(num_layers)
            ])

    def forward(self, input, input_lengths, initial_states=None):
        if self.batch_first:
            input = input.transpose(0, 1)

        if initial_states is None:
            batch_size = input.size(1)
            # Each direction holds real_hidden_dim units, so the zero state
            # must be sized with real_hidden_dim (not hidden_dim).
            state_shape = (self.real_num_layers, batch_size, self.real_hidden_dim)
            initial_states = torch.zeros(state_shape, device=input.device)

        f = input
        for i in range(self.num_layers):
            f_input = f
            # A bidirectional stage consumes two state slices, one per direction.
            f_state = initial_states[2*i:2*i+2] if self.bidirectional else initial_states[i:i+1]
            # TODO(Jiayuan Mao @ 05/08): accelerate this by pre-sorting the sequences.
            f = rnn_with_length(self.rnns[i], f, input_lengths, initial_states=f_state,
                                batch_first=False, sorted=False)
            if self.layer_norms is not None:
                f = self.layer_norms[i](f)
            # Residual connection around the recurrent stage.
            f = f + f_input

        if self.batch_first:
            f = f.transpose(0, 1)
        return f
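
# Usage sketch (illustrative sizes; `input_lengths` holds the true length of
# each padded sequence). With bidirectional=True, hidden_dim must be even,
# since each direction gets hidden_dim // 2 units:
#
#   rnn = ResidualGRU(hidden_dim=128, num_layers=2, bidirectional=True, layer_norm=True)
#   seqs = torch.randn(16, 4, 128)        # (time, batch, hidden_dim)
#   lengths = torch.tensor([16, 12, 9, 5])
#   out = rnn(seqs, lengths)              # -> (16, 4, 128)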