Source code for jactorch.nn.quickaccess

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : quickaccess.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/28/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch.nn as nn

from . import sync_batchnorm as sync_bn
from .simple import Identity

# Names exported by ``from ... import *``.
# NOTE(review): 'get_batcnnorm' looks like a misspelling of "batchnorm", but it
# matches the function name defined in this module, so it must stay as-is.
__all__ = ['get_batcnnorm', 'get_dropout', 'get_activation']


def get_batcnnorm(bn, nr_features=None, nr_dims=1):
    """Resolve a batch-normalization spec into a module.

    Args:
        bn: an ``nn.Module`` (returned unchanged), ``True`` or ``'async'``
            (selects ``nn.BatchNorm{nr_dims}d``), or ``'sync'`` (selects
            ``SynchronizedBatchNorm{nr_dims}d`` from the sibling
            ``sync_batchnorm`` module).
        nr_features: passed as the single positional argument to the
            selected class's constructor.
        nr_dims: dimensionality suffix of the class name; must be 1, 2 or 3.

    Returns:
        ``bn`` itself when it is already a module, otherwise a newly
        constructed normalization module.

    Raises:
        AssertionError: if ``nr_dims`` is outside ``[1, 3]``.
        ValueError: if ``bn`` is none of the recognized values.
    """
    # Ready-made modules pass straight through (before nr_dims is checked).
    if isinstance(bn, nn.Module):
        return bn

    assert 1 <= nr_dims <= 3

    # Choose where the class lives and how its name is spelled ...
    if bn in (True, 'async'):
        namespace, template = nn, 'BatchNorm{}d'
    elif bn == 'sync':
        namespace, template = sync_bn, 'SynchronizedBatchNorm{}d'
    else:
        raise ValueError('Unknown type of batch normalization: {}.'.format(bn))

    # ... then look it up by name and build it.
    layer_cls = getattr(namespace, template.format(nr_dims))
    return layer_cls(nr_features)
def get_dropout(dropout, nr_dims=1):
    """Resolve a dropout spec into a module.

    Args:
        dropout: an ``nn.Module`` (returned unchanged), the literal ``True``
            (replaced by a probability of 0.5), or any other value, which is
            forwarded as-is as the first constructor argument.
        nr_dims: 1 selects ``nn.Dropout``; any other value ``k`` selects the
            ``torch.nn`` attribute named ``Dropout{k}d``.

    Returns:
        ``dropout`` itself when it is already a module, otherwise a newly
        constructed dropout module.
    """
    if isinstance(dropout, nn.Module):
        return dropout

    # Only the literal True is a shorthand; other values are used verbatim.
    prob = 0.5 if dropout is True else dropout

    if nr_dims == 1:
        # Only this plain variant is constructed with in-place mode enabled.
        return nn.Dropout(prob, inplace=True)

    layer_cls = getattr(nn, 'Dropout{}d'.format(nr_dims))
    return layer_cls(prob)
def get_activation(act):
    """Resolve an activation spec into a module instance.

    Args:
        act: an ``nn.Module`` (returned unchanged) or a string. The names
            ``'identity'``, ``'relu'``, ``'sigmoid'`` and ``'tanh'`` are
            matched case-insensitively. Any other string is looked up
            verbatim (case-sensitive) as a class in ``torch.nn`` and
            instantiated with no arguments.

    Returns:
        An ``nn.Module`` instance. ``'relu'`` yields ``nn.ReLU`` constructed
        with in-place mode enabled.

    Raises:
        AssertionError: if ``act`` is neither a module nor a string.
        ValueError: if the name is not one of the built-in shortcuts and
            does not name an ``nn.Module`` subclass in ``torch.nn``.
        TypeError: propagated from the class constructor when the named
            class cannot be built without arguments.
    """
    if isinstance(act, nn.Module):
        return act

    assert isinstance(act, str), 'Unknown type of activation: {}.'.format(act)
    act_lower = act.lower()
    if act_lower == 'identity':
        return Identity()
    elif act_lower == 'relu':
        return nn.ReLU(True)
    elif act_lower == 'sigmoid':
        return nn.Sigmoid()
    elif act_lower == 'tanh':
        return nn.Tanh()

    # Fallback: treat the string as the name of a class in torch.nn.
    module_cls = getattr(nn, act, None)
    if not (isinstance(module_cls, type) and issubclass(module_cls, nn.Module)):
        # Covers both missing attributes and non-layer attributes
        # (submodules, helper functions) that happen to live in torch.nn.
        raise ValueError('Unknown activation function: {}.'.format(act))

    # Instantiate so this branch returns a module like every other branch;
    # previously the bare class was returned here.
    return module_cls()