Source code for jactorch.nn.quickaccess
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : quickaccess.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/28/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch.nn as nn
from . import sync_batchnorm as sync_bn
from .simple import Identity
# Public helpers exported by ``from <this module> import *``.
# NOTE(review): 'get_batcnnorm' looks like a misspelling of 'get_batchnorm'; it is
# kept as-is because renaming an exported name would break existing callers.
__all__ = ['get_batcnnorm', 'get_dropout', 'get_activation']
[docs]
def get_batcnnorm(bn, nr_features=None, nr_dims=1):
    """Return a batch-normalization layer described by ``bn``.

    Args:
        bn: an ``nn.Module`` (returned unchanged), ``True`` / ``'async'`` for
            the regular ``torch.nn`` batch norm, or ``'sync'`` for the
            synchronized variant from the local ``sync_batchnorm`` package.
            Matching uses ``==``, so values equal to ``True`` also select
            the regular batch norm.
        nr_features: number of features, forwarded as the only positional
            argument of the layer constructor.
        nr_dims: 1, 2 or 3; picks the ``...{1,2,3}d`` class.

    Returns:
        The constructed (or passed-through) module.

    Raises:
        AssertionError: if ``nr_dims`` is outside ``[1, 3]``.
        ValueError: if ``bn`` is not a recognized specification.
    """
    if isinstance(bn, nn.Module):
        return bn

    assert 1 <= nr_dims <= 3

    # Choose the namespace holding the class and the class-name template.
    if bn in (True, 'async'):
        namespace, template = nn, 'BatchNorm{}d'
    elif bn == 'sync':
        namespace, template = sync_bn, 'SynchronizedBatchNorm{}d'
    else:
        raise ValueError('Unknown type of batch normalization: {}.'.format(bn))

    layer_cls = getattr(namespace, template.format(nr_dims))
    return layer_cls(nr_features)
[docs]
def get_dropout(dropout, nr_dims=1):
    """Return a dropout layer described by ``dropout``.

    Args:
        dropout: an ``nn.Module`` (returned unchanged), ``True`` (shorthand
            for a drop probability of 0.5), or the drop probability itself.
        nr_dims: 1 selects ``nn.Dropout`` built with ``inplace=True``; any
            other value selects ``nn.Dropout{nr_dims}d`` with its default
            ``inplace`` setting.

    Returns:
        The constructed (or passed-through) module.

    Raises:
        AttributeError: if ``torch.nn`` has no ``Dropout{nr_dims}d`` class.
    """
    if isinstance(dropout, nn.Module):
        return dropout

    prob = 0.5 if dropout is True else dropout

    if nr_dims != 1:
        layer_cls = getattr(nn, 'Dropout{}d'.format(nr_dims))
        return layer_cls(prob)

    return nn.Dropout(prob, inplace=True)
[docs]
def get_activation(act):
    """Return an activation-function module described by ``act``.

    Args:
        act: an ``nn.Module`` (returned unchanged) or a string. The names
            ``'identity'``, ``'relu'``, ``'sigmoid'`` and ``'tanh'`` are
            matched case-insensitively. Any other string must be the exact
            (case-sensitive) name of an ``nn.Module`` subclass in
            ``torch.nn`` (e.g. ``'ELU'``, ``'LeakyReLU'``). That class is
            instantiated with no arguments.

    Returns:
        An ``nn.Module`` instance.

    Raises:
        AssertionError: if ``act`` is neither a module nor a string.
        ValueError: if ``act`` does not name an ``nn.Module`` subclass in
            ``torch.nn``.
        TypeError: propagated from the constructor when the named class has
            required arguments (e.g. ``'Linear'``).
    """
    if isinstance(act, nn.Module):
        return act

    assert isinstance(act, str), 'Unknown type of activation: {}.'.format(act)
    act_lower = act.lower()
    if act_lower == 'identity':
        return Identity()
    elif act_lower == 'relu':
        return nn.ReLU(True)  # first positional argument is ``inplace``
    elif act_lower == 'sigmoid':
        return nn.Sigmoid()
    elif act_lower == 'tanh':
        return nn.Tanh()

    # Generic lookup in torch.nn. Only module classes are accepted, so
    # attributes such as the ``functional`` submodule are rejected. The class
    # is instantiated so this path, like every branch above, returns a module
    # instance rather than a class object.
    layer_cls = getattr(nn, act, None)
    if not (isinstance(layer_cls, type) and issubclass(layer_cls, nn.Module)):
        raise ValueError('Unknown activation function: {}.'.format(act))
    return layer_cls()