# Source code for jactorch.vision.conv

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : conv.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/04/2019
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from jacinle.utils.argument import get_2dshape
from jactorch.nn.cnn.functional import ConvBorderMode, padding_nd

__all__ = [
    'MaxPoolingKernelDef', 'custom_kernel', 'CustomKernel'
]


class MaxPoolingKernelDef:
    """Marker "kernel" that makes ``custom_kernel`` perform stride-1 max pooling.

    Instead of convolving with a weight tensor, ``custom_kernel`` runs
    ``F.max_pool2d`` with this object's ``kernel_size`` when it receives an
    instance of this class.

    Args:
        kernel_size: the pooling window size. It is normalized with
            ``get_2dshape`` (presumably an int ``s`` becomes ``(s, s)`` — TODO
            confirm). ``custom_kernel`` asserts that both entries are odd.
    """

    def __init__(self, kernel_size):
        normalized_size = get_2dshape(kernel_size)
        self.kernel_size = normalized_size
def custom_kernel(image, k, border_mode='zeros'):
    """Filter ``image`` with a custom 2D kernel (or stride-1 max pooling), keeping its size.

    Padding is requested from ``padding_nd`` in ``'same'`` mode with the given
    ``border_mode``. The result is reshaped back to exactly ``image.size()``.

    Args:
        image: a tensor with 2, 3 or 4 dimensions. Missing leading dimensions
            are added temporarily so the 2D torch ops receive a 4-D input.
        k: either a ``MaxPoolingKernelDef``, which selects ``F.max_pool2d`` with
            its ``kernel_size`` and ``stride=1``, or a convolution kernel. A
            convolution kernel may be a tensor or anything ``torch.tensor``
            accepts. Its shape is expanded as follows:
            ``(kh, kw) -> (1, 1, kh, kw)``; ``(o, kh, kw) -> (o, 1, kh, kw)``;
            4-D kernels are used as-is. Because every channel is convolved as
            a separate single-channel image, the kernel's second dimension must
            be 1. The final ``view`` back to ``image.size()`` only works when
            the kernel has a single output channel. Both spatial sizes must be
            odd.
        border_mode: anything ``ConvBorderMode.from_string`` accepts.

    Returns:
        A tensor with the same size as ``image``.

    Raises:
        ValueError: if the kernel does not have 2, 3 or 4 dimensions, or the
            image does not have 2, 3 or 4 dimensions.
        AssertionError: if a kernel spatial size is even.
    """
    # NOTE(review): CustomKernel defaults to 'zero' while this function defaults
    # to 'zeros'; confirm which spelling ConvBorderMode.from_string accepts.
    border_mode = ConvBorderMode.from_string(border_mode)
    use_max_pooling = isinstance(k, MaxPoolingKernelDef)

    if use_max_pooling:
        kernel_shape = k.kernel_size
        assert kernel_shape[0] % 2 == 1 and kernel_shape[1] % 2 == 1, \
            'Kernel size must be odd, got {}.'.format(tuple(kernel_shape))
    else:
        if not torch.is_tensor(k):
            k = torch.tensor(k, dtype=torch.float32)
        # Align the kernel with the image. Otherwise conv2d fails on a CPU/GPU
        # mismatch, or on a dtype mismatch (e.g. a float32 kernel with a
        # float64 image, or an integer tensor kernel with a float image).
        if image.is_floating_point():
            k = k.to(device=image.device, dtype=image.dtype)
        else:
            k = k.to(device=image.device)

        if k.dim() == 2:
            k = k.unsqueeze(0).unsqueeze(0)
        elif k.dim() == 3:
            k = k.unsqueeze(1)
        elif k.dim() != 4:
            raise ValueError('Unsupported kernel size: {}.'.format(k.size()))
        assert k.size(2) % 2 == 1 and k.size(3) % 2 == 1, \
            'Kernel size must be odd, got {}.'.format(tuple(k.shape[2:4]))
        kernel_shape = k.shape[2:4]

    image_dim = image.dim()
    image_size = image.size()
    if image_dim not in (2, 3, 4):
        raise ValueError('Unsupported image dim: {}.'.format(image_dim))
    # conv2d / max_pool2d expect a 4-D input, so add the missing leading dims.
    for _ in range(4 - image_dim):
        image = image.unsqueeze(0)

    # padding_nd returns the (possibly pre-padded) image together with the
    # padding the op itself should still apply. Presumably, non-zero border
    # modes pad explicitly here; verify in jactorch.nn.cnn.functional.
    image, extra_padding = padding_nd(image, kernel_shape, None, 'same', border_mode)

    if use_max_pooling:
        return F.max_pool2d(image, kernel_shape, stride=1, padding=extra_padding).view(image_size)

    # Fold the channel axis into the batch axis. Every channel then becomes its
    # own single-channel image, filtered by the same kernel.
    planes = image.contiguous().view((image.shape[0] * image.shape[1], 1) + image.shape[2:])
    return F.conv2d(planes, k, padding=extra_padding).view(image_size)
class CustomKernel(nn.Module):
    """``nn.Module`` wrapper that applies ``custom_kernel`` with a fixed kernel and border mode.

    Args:
        kernel: a ``MaxPoolingKernelDef``, a tensor, or anything
            ``torch.tensor`` accepts (converted to float32).
            - A tensor kernel is registered as the buffer ``kernel``, so it
              follows the module through ``.to()`` and is part of its
              ``state_dict``.
            - A ``MaxPoolingKernelDef`` is stored as a plain attribute.
        border_mode: parsed once here with ``ConvBorderMode.from_string``.
            NOTE(review): ``custom_kernel``'s own default is spelled
            ``'zeros'``; confirm which spelling the enum accepts.
    """

    def __init__(self, kernel, border_mode='zero'):
        super().__init__()
        if isinstance(kernel, MaxPoolingKernelDef):
            self.kernel = kernel
        else:
            kernel_tensor = kernel if torch.is_tensor(kernel) else torch.tensor(kernel, dtype=torch.float32)
            self.register_buffer('kernel', kernel_tensor)
        self.border_mode = ConvBorderMode.from_string(border_mode)

    def forward(self, input):
        """Return ``custom_kernel(input, self.kernel, self.border_mode)``."""
        return custom_kernel(input, self.kernel, self.border_mode)