Source code for jactorch.nn.cnn.functional

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : functional.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 04/05/2019
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import collections.abc
from itertools import repeat

import torch.nn.functional as F

from jacinle.utils.enum import JacEnum

# Public API of this module; helpers prefixed with an underscore (e.g. ``_format_tuple``) are private.
__all__ = ['ConvPaddingMode', 'ConvBorderMode', 'compute_padding_shape', 'padding_nd']


class ConvPaddingMode(JacEnum):
    """Strategy for choosing how much padding a convolution gets.

    Members, by string value:

    * ``'default'``    -- use the padding supplied by the caller.
    * ``'valid'``      -- no padding at all.
    * ``'same'``       -- ``kernel_size // 2`` per dimension (odd kernels only).
    * ``'tensorflow'`` -- reserved; not implemented.
    """

    DEFAULT = 'default'
    VALID = 'valid'
    SAME = 'same'
    TENSORFLOW = 'tensorflow'
class ConvBorderMode(JacEnum):
    """How the border of an input is filled when it is padded.

    Members, by string value: ``'zeros'``, ``'constant'``, ``'reflect'``,
    ``'replicate'``.
    """

    ZEROS = 'zeros'
    CONSTANT = 'constant'
    REFLECT = 'reflect'
    REPLICATE = 'replicate'

    @classmethod
    def from_string(cls, string):
        """Parse ``string`` into a ``ConvBorderMode``.

        This is a plain pass-through to ``JacEnum.from_string``; no aliases
        are added here.
        """
        return super().from_string(string)
def _format_tuple(val, arity): if isinstance(val, collections.Iterable): return tuple(val) return tuple(repeat(val, arity))
def compute_padding_shape(input_size, kernel_size, padding, mode):
    """Compute the per-dimension padding for a convolution.

    Args:
        input_size: the spatial size of the input; only its length (the number
            of spatial dimensions) is used.
        kernel_size: int or per-dimension sequence of kernel sizes; only used
            in ``ConvPaddingMode.SAME``.
        padding: int or per-dimension sequence; only used in
            ``ConvPaddingMode.DEFAULT``.
        mode: a ``ConvPaddingMode`` member or its string value.

    Returns:
        A tuple of padding amounts. In ``DEFAULT`` mode an iterable ``padding``
        is returned as given (converted to a tuple, length not checked);
        ``VALID`` gives all zeros; ``SAME`` gives ``k // 2`` per kernel size.

    Raises:
        ValueError: in ``SAME`` mode, if any kernel size is even.
        NotImplementedError: in ``TENSORFLOW`` mode.
    """
    mode = ConvPaddingMode.from_string(mode)
    nr_dims = len(input_size)

    if mode is ConvPaddingMode.DEFAULT:
        return _format_tuple(padding, nr_dims)
    if mode is ConvPaddingMode.VALID:
        return _format_tuple(0, nr_dims)
    if mode is ConvPaddingMode.SAME:
        kernel_size = _format_tuple(kernel_size, nr_dims)
        # Symmetric "same" padding only exists for odd kernels. Validate
        # explicitly: an ``assert`` would be stripped under ``python -O``.
        if not all(k % 2 == 1 for k in kernel_size):
            raise ValueError('SAME padding requires odd kernel sizes, got {}.'.format(kernel_size))
        return tuple(k // 2 for k in kernel_size)
    if mode is ConvPaddingMode.TENSORFLOW:
        raise NotImplementedError()
def padding_nd(input, kernel_size, padding, padding_mode, border_mode, use_pytorch_padding_mode=False):
    """Resolve the padding of an N-d convolution and optionally apply it explicitly.

    Args:
        input: the input tensor. Dimensions from index 2 onward are treated as
            spatial (presumably ``(batch, channels, *spatial)`` -- confirm with
            callers).
        kernel_size: int or per-spatial-dimension kernel sizes.
        padding: int or per-dimension padding (used in ``DEFAULT`` mode only).
        padding_mode: a ``ConvPaddingMode`` member or its string value.
        border_mode: a ``ConvBorderMode`` member or its string value.
        use_pytorch_padding_mode: if True, never pad here; instead return the
            border mode's string value alongside the padding.

    Returns:
        * ``use_pytorch_padding_mode=True``: ``(input, padding, border_mode.value)``.
        * ``ZEROS`` / ``CONSTANT`` border: ``(input, padding)`` with ``input``
          untouched; the padding is left to the caller (presumably the
          convolution's own zero padding).
        * ``REFLECT`` / ``REPLICATE`` border: ``(padded, conv_padding)`` where
          ``padded`` has been padded with ``F.pad`` and ``conv_padding`` is all
          zeros.

    Raises:
        ValueError: if explicit padding is needed and ``input`` is not 3D, 4D
            or 5D.
    """
    padding_mode = ConvPaddingMode.from_string(padding_mode)
    border_mode = ConvBorderMode.from_string(border_mode)
    padding = compute_padding_shape(input.size()[2:], kernel_size, padding, padding_mode)

    if use_pytorch_padding_mode:
        return input, padding, border_mode.value

    # 'zeros' is not a valid ``F.pad`` mode, so ZEROS must not reach ``F.pad``.
    # Like CONSTANT, hand the zero padding back to the caller unchanged.
    if border_mode in (ConvBorderMode.ZEROS, ConvBorderMode.CONSTANT):
        return input, padding

    if input.dim() not in (3, 4, 5):
        raise ValueError('Only 3D, 4D or 5D inputs are supported.')

    # ``F.pad`` takes (before, after) pairs starting from the LAST dimension,
    # so walk the spatial dimensions in reverse order.
    nr_spatial = input.dim() - 2
    pad = []
    for i in reversed(range(nr_spatial)):
        pad.extend((padding[i], padding[i]))
    padded = F.pad(input, tuple(pad), mode=border_mode.value)

    conv_padding = _format_tuple(0, nr_spatial)
    return padded, conv_padding