# Source code for jactorch.nn.cnn.coord_conv
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : coord_conv.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 08/04/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""
CoordConv introduced in:
"An intriguing failing of convolutional neural networks and the CoordConv solution".
https://arxiv.org/pdf/1807.03247.pdf
Code is adapted from https://github.com/mkocabas/CoordConv-pytorch.
"""
import torch
import torch.nn as nn
import jactorch
# Names exported by `from ... import *`. `CoordConv` is an alias of `CoordConv2D`.
__all__ = ['CoordConv2D', 'CoordConv']
[docs]
class CoordConv2D(nn.Module):
    """2D convolution that receives pixel coordinates alongside the input features.

    The input is first augmented with coordinate channels by ``_AddCoords``;
    the augmented tensor is then passed through a regular ``nn.Conv2d`` whose
    input channel count is enlarged by the number of added channels.
    """

    def __init__(self, in_channels, out_channels, *args, use_radius=False, **kwargs):
        """Build the coordinate-augmenting layer and the wrapped convolution.

        Args:
            in_channels: channel count of the tensor given to ``forward``,
                i.e. before the coordinate channels are appended.
            out_channels: channel count produced by the convolution.
            *args: extra positional arguments forwarded to ``nn.Conv2d``.
            use_radius: forwarded to ``_AddCoords``.
            **kwargs: extra keyword arguments forwarded to ``nn.Conv2d``.
        """
        super().__init__()
        self.addcoords = _AddCoords(use_radius=use_radius)
        # The convolution consumes the original channels plus the coordinate ones.
        self.conv = nn.Conv2d(
            in_channels + self.addcoords.extra_channels, out_channels, *args, **kwargs
        )

    def forward(self, x):
        """Append coordinate channels to ``x`` and convolve the result."""
        return self.conv(self.addcoords(x))
# NB(Jiayuan Mao @ 01/16): `CoordConv` is kept as an alias for legacy reasons.
# Currently we only support 2D coord conv.
CoordConv = CoordConv2D
class _AddCoords(nn.Module):
def __init__(self, use_radius=False):
super().__init__()
self.use_radius = use_radius
self.extra_channels = 3 if self.use_radius else 2
def forward(self, input):
batch_size, _, h, w = input.size()
def gen(length):
return -1 + torch.arange(length, dtype=input.dtype, device=input.device) / (length - 1) * 2
results = [input]
with torch.no_grad():
x_coords = gen(w).view(1, 1, 1, w).expand((batch_size, 1, h, w))
y_coords = gen(h).view(1, 1, h, 1).expand((batch_size, 1, h, w))
results.extend([x_coords, y_coords])
if self.use_radius:
radius = torch.sqrt(torch.pow(x_coords - 0.5, 2) + torch.pow(y_coords - 0.5, 2))
results.append(radius)
return torch.cat(results, dim=1)