#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : transforms.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/03/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import random
import torch
import torchvision.transforms as transforms
import jactorch.transforms.image as jac_transforms
from . import functional as F
# Public, coordinate-aware transforms exported by this module. Every name listed
# here must actually be defined in this module; otherwise `from ... import *`
# raises AttributeError.
__all__ = [
    "Compose", "Lambda", "ToTensor", "NormalizeCoor", "DenormalizeCoor", "Normalize",
    "Resize", "CenterCrop", "Pad", "RandomCrop", "RandomHorizontalFlip",
    "RandomVerticalFlip", "RandomResizedCrop", "ColorJitter", "RandomRotation",
    "Grayscale", "RandomGrayscale", "PadMultipleOf",
]
class Compose(transforms.Compose):
    """Chain several (img, coor) transforms, feeding each one's output to the next."""

    def __call__(self, img, coor):
        # Each step must return an (img, coor) pair; the pair is unpacked after
        # every step so a malformed return fails immediately.
        current_img, current_coor = img, coor
        for step in self.transforms:
            current_img, current_coor = step(current_img, current_coor)
        return current_img, current_coor
class Lambda(transforms.Lambda):
    """Wrap a user callable that takes ``(img, coor)`` and returns whatever it returns."""

    def __call__(self, img, coor):
        fn = self.lambd
        return fn(img, coor)
class ToTensor(transforms.ToTensor):
    """Convert the image with torchvision's ToTensor and the coordinates with torch.from_numpy.

    The pair is first passed through ``F.refresh_valid`` (presumably refreshes
    per-coordinate validity information -- TODO confirm in the functional
    module). ``coor`` must be a NumPy array at that point, since it is handed
    to ``torch.from_numpy``.
    """

    def __call__(self, img, coor):
        img, coor = F.refresh_valid(img, coor)
        img_tensor = super().__call__(img)
        coor_tensor = torch.from_numpy(coor)
        return img_tensor, coor_tensor
class NormalizeCoor(object):
    """Delegate to ``F.normalize_coor`` on the (img, coor) pair and return its result."""

    def __call__(self, img, coor):
        result = F.normalize_coor(img, coor)
        return result
class DenormalizeCoor(object):
    """Delegate to ``F.denormalize_coor`` on the (img, coor) pair and return its result."""

    def __call__(self, img, coor):
        result = F.denormalize_coor(img, coor)
        return result
class Normalize(transforms.Normalize):
    """Normalize the image with torchvision's Normalize; coordinates pass through untouched."""

    def __call__(self, img, coor):
        normalized = super().__call__(img)
        return normalized, coor
class Resize(transforms.Resize):
    """Resize the image with torchvision's Resize; coordinates are returned unchanged.

    This is only correct when the coordinates are 0/1-normalized, so that they
    do not depend on the image size.
    """

    def __call__(self, img, coor):
        resized = super().__call__(img)
        return resized, coor
class CenterCrop(transforms.CenterCrop):
    """Center-crop the (img, coor) pair to ``self.size`` via ``F.center_crop``."""

    def __call__(self, img, coor):
        output_size = self.size
        return F.center_crop(img, coor, output_size)
class Pad(jac_transforms.Pad):
    """Pad the (img, coor) pair via ``F.pad`` using the configured padding and fill."""

    def __call__(self, img, coor):
        padding, fill = self.padding, self.fill
        return F.pad(img, coor, padding, fill)
class RandomCrop(transforms.RandomCrop):
    """Randomly crop the (img, coor) pair to ``self.size``.

    If padding is configured, the pair is padded first with ``F.pad`` and the
    crop window is sampled on the padded image.
    """

    def __call__(self, img, coor):
        padding = self.padding
        # Older torchvision defaults padding to 0, newer releases to None; a
        # plain `padding > 0` raises TypeError for None (or a sequence).
        if isinstance(padding, int):
            needs_padding = padding > 0
        else:
            # NOTE(review): assumes F.pad accepts sequence padding like
            # torchvision's Pad does -- confirm in the functional module.
            needs_padding = padding is not None
        if needs_padding:
            # F.pad returns an (img, coor) pair; unpack it so the crop is
            # sampled on the padded image and the shifted coordinates are kept.
            img, coor = F.pad(img, coor, padding)
        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, coor, i, j, h, w)
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontally flip the (img, coor) pair via ``F.hflip`` with probability ``p``.

    ``p`` is read from the torchvision base class when it defines one (newer
    releases accept ``p`` in the constructor). Otherwise the historical fixed
    probability of 0.5 is used.
    """

    def __call__(self, img, coor):
        # A hard-coded 0.5 would silently ignore a user-supplied `p`.
        p = getattr(self, 'p', 0.5)
        if random.random() < p:
            return F.hflip(img, coor)
        return img, coor
class RandomVerticalFlip(transforms.RandomVerticalFlip):
    """Vertically flip the (img, coor) pair via ``F.vflip`` with probability ``p``.

    ``p`` is read from the torchvision base class when it defines one (newer
    releases accept ``p`` in the constructor). Otherwise the historical fixed
    probability of 0.5 is used.
    """

    def __call__(self, img, coor):
        # A hard-coded 0.5 would silently ignore a user-supplied `p`.
        p = getattr(self, 'p', 0.5)
        if random.random() < p:
            return F.vflip(img, coor)
        return img, coor
class RandomResizedCrop(transforms.RandomResizedCrop):
    """Sample a crop window with torchvision's ``get_params`` and resize it via ``F.resized_crop``."""

    def __call__(self, img, coor):
        top, left, height, width = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(
            img, coor, top, left, height, width,
            self.size, self.interpolation,
        )
class Grayscale(transforms.Grayscale):
    """Apply torchvision's Grayscale to the image; coordinates pass through untouched."""

    def __call__(self, img, coor):
        gray = super().__call__(img)
        return gray, coor
class RandomGrayscale(transforms.RandomGrayscale):
    """Apply torchvision's RandomGrayscale to the image; coordinates pass through untouched."""

    def __call__(self, img, coor):
        maybe_gray = super().__call__(img)
        return maybe_gray, coor
class ColorJitter(transforms.ColorJitter):
    """Apply torchvision's ColorJitter to the image; coordinates pass through untouched."""

    def __call__(self, img, coor):
        jittered = super().__call__(img)
        return jittered, coor
class RandomRotation(transforms.RandomRotation):
    """Rotate the (img, coor) pair by a randomly sampled angle via ``F.rotate``.

    Args:
        angle: forwarded positionally to torchvision's RandomRotation as its
            ``degrees`` argument (the range the angle is sampled from).
        crop: stored as ``self.crop`` and forwarded to ``F.rotate``.
        **kwargs: forwarded to torchvision's RandomRotation.
    """

    def __init__(self, angle, crop=False, **kwargs):
        super().__init__(angle, **kwargs)
        self.crop = crop

    def __call__(self, img, coor):
        sampled_angle = self.get_params(self.degrees)
        # NOTE(review): relies on the base class exposing `resample`, which later
        # torchvision releases renamed to `interpolation` -- confirm the pinned version.
        return F.rotate(
            img, coor, sampled_angle,
            self.resample, self.crop, self.expand, self.center,
        )
class PadMultipleOf(jac_transforms.PadMultipleOf):
    """Pad the (img, coor) pair via ``F.pad_multiple_of`` using the configured multiple, mode and fill."""

    def __call__(self, img, coor):
        return F.pad_multiple_of(
            img, coor, self.multiple,
            mode=self.mode, fill=self.fill,
        )