#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : transforms.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 12/18/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import collections
import collections.abc
import numbers
import random

from PIL import Image
import torch
import torchvision.transforms.transforms as torch_transforms

from jacinle.utils.argument import get_2dshape
from jacinle.utils.enum import JacEnum
from jacinle.utils.defaults import defaults_manager

from .functional import image as fimage
from .functional import coor as fcoor
from .functional import bbox as fbbox
from .functional._utils import get_rotation_matrix, get_size_multiple_of
# Names exported by ``from <module> import *``.  Every name appears exactly
# once ("RandomRotation" belongs to the rotation group only).
__all__ = [
    "TransformDataTypes", "TransformGuide", "TransformBase", "TransformFunctionBase", "TransformFunctionBaseImageOnly",
    "Compose", "Lambda", "RandomApply", "RandomOrder", "RandomChoice",
    "ToTensor", "ToPILImage", "Normalize", "NormalizeCoordinates", "DenormalizeCoordinates",
    "Crop", "CenterCrop", "RandomCrop", "Pad", "PadMultipleOf",
    "HFlip", "VFlip", "RandomHorizontalFlip", "RandomVerticalFlip",
    "Resize", "ResizeMultipleOf", "RandomResizedCrop",
    "FiveCrop", "TenCrop",
    "Rotate", "RandomRotation",
    "LinearTransformation", "ColorJitter", "Grayscale", "RandomGrayscale",
]
# Module-level guide used when a transform is not given one explicitly.
# Each entry names a feed-dict key, gives it a 'type', and -- for 'coor' and
# 'bbox' -- a 'dep' list naming 'image' (presumably the image the coordinates
# refer to; see TransformGuide for the exact semantics).
default_transform_guide = TransformGuide({
    'image': dict(type='image'),
    'coor': dict(type='coor', dep=['image']),
    'bbox': dict(type='bbox', dep=['image']),
})

# Getter produced by the defaults manager; the lambda supplies
# ``default_transform_guide`` as the fallback value.
get_default_transform_guide = defaults_manager.gen_get_default(
    TransformGuide,
    lambda: default_transform_guide,
)
[docs]
class Compose(torch_transforms.Compose):
    """torchvision ``Compose`` that operates on a feed dict.

    The call accepts either a mapping (optionally extended by keyword
    arguments) or a single non-mapping object, which is routed to ``ezcall``.
    """

    def __call__(self, feed_dict=None, **kwargs):
        """Run the composed transforms.

        Args:
            feed_dict: a mapping of inputs, a non-mapping object (forwarded to
                ``ezcall`` together with ``kwargs``), or ``None``.
            **kwargs: extra entries merged into ``feed_dict``.  A non-empty
                ``feed_dict`` is updated in place; ``None`` or an empty mapping
                is replaced by a fresh dict.

        Returns:
            Whatever the parent ``Compose.__call__`` returns for the feed dict,
            or the result of ``ezcall`` for non-mapping input.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.
        if feed_dict is not None and not isinstance(feed_dict, collections.abc.Mapping):
            return self.ezcall(feed_dict, **kwargs)
        feed_dict = feed_dict or {}
        feed_dict.update(**kwargs)
        return super().__call__(feed_dict)

    # Reuse the single-object calling convention implemented by TransformBase.
    ezcall = TransformBase.ezcall
[docs]
class RandomApply(torch_transforms.RandomApply):
    """torchvision ``RandomApply`` that operates on a feed dict.

    The call accepts either a mapping (optionally extended by keyword
    arguments) or a single non-mapping object, which is routed to ``ezcall``.
    """

    def __call__(self, feed_dict=None, **kwargs):
        """Run the parent transform on a feed dict.

        Args:
            feed_dict: a mapping of inputs, a non-mapping object (forwarded to
                ``ezcall`` together with ``kwargs``), or ``None``.
            **kwargs: extra entries merged into ``feed_dict``.  A non-empty
                ``feed_dict`` is updated in place; ``None`` or an empty mapping
                is replaced by a fresh dict.

        Returns:
            Whatever the parent ``RandomApply.__call__`` returns, or the result
            of ``ezcall`` for non-mapping input.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.
        if feed_dict is not None and not isinstance(feed_dict, collections.abc.Mapping):
            return self.ezcall(feed_dict, **kwargs)
        feed_dict = feed_dict or {}
        feed_dict.update(**kwargs)
        return super().__call__(feed_dict)

    # Reuse the single-object calling convention implemented by TransformBase.
    ezcall = TransformBase.ezcall
[docs]
class RandomOrder(torch_transforms.RandomOrder):
    """torchvision ``RandomOrder`` that operates on a feed dict.

    The call accepts either a mapping (optionally extended by keyword
    arguments) or a single non-mapping object, which is routed to ``ezcall``.
    """

    def __call__(self, feed_dict=None, **kwargs):
        """Run the parent transform on a feed dict.

        Args:
            feed_dict: a mapping of inputs, a non-mapping object (forwarded to
                ``ezcall`` together with ``kwargs``), or ``None``.
            **kwargs: extra entries merged into ``feed_dict``.  A non-empty
                ``feed_dict`` is updated in place; ``None`` or an empty mapping
                is replaced by a fresh dict.

        Returns:
            Whatever the parent ``RandomOrder.__call__`` returns, or the result
            of ``ezcall`` for non-mapping input.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.
        if feed_dict is not None and not isinstance(feed_dict, collections.abc.Mapping):
            return self.ezcall(feed_dict, **kwargs)
        feed_dict = feed_dict or {}
        feed_dict.update(**kwargs)
        return super().__call__(feed_dict)

    # Reuse the single-object calling convention implemented by TransformBase.
    ezcall = TransformBase.ezcall
[docs]
class RandomChoice(torch_transforms.RandomChoice):
    """torchvision ``RandomChoice`` that operates on a feed dict.

    The call accepts either a mapping (optionally extended by keyword
    arguments) or a single non-mapping object, which is routed to ``ezcall``.
    """

    def __call__(self, feed_dict=None, **kwargs):
        """Run the parent transform on a feed dict.

        Args:
            feed_dict: a mapping of inputs, a non-mapping object (forwarded to
                ``ezcall`` together with ``kwargs``), or ``None``.
            **kwargs: extra entries merged into ``feed_dict``.  A non-empty
                ``feed_dict`` is updated in place; ``None`` or an empty mapping
                is replaced by a fresh dict.

        Returns:
            Whatever the parent ``RandomChoice.__call__`` returns, or the
            result of ``ezcall`` for non-mapping input.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.
        if feed_dict is not None and not isinstance(feed_dict, collections.abc.Mapping):
            return self.ezcall(feed_dict, **kwargs)
        feed_dict = feed_dict or {}
        feed_dict.update(**kwargs)
        return super().__call__(feed_dict)

    # ``__call__`` dispatches to ``self.ezcall``; bind it here exactly as the
    # other container transforms in this module do, otherwise the non-mapping
    # path fails with AttributeError.
    ezcall = TransformBase.ezcall
[docs]
class Lambda(torch_transforms.Lambda):
    """torchvision ``Lambda`` that operates on a feed dict.

    The call accepts either a mapping (optionally extended by keyword
    arguments) or a single non-mapping object, which is routed to ``ezcall``.
    """

    def __call__(self, feed_dict=None, **kwargs):
        """Apply the wrapped callable to a feed dict.

        Args:
            feed_dict: a mapping of inputs, a non-mapping object (forwarded to
                ``ezcall`` together with ``kwargs``), or ``None``.
            **kwargs: extra entries merged into ``feed_dict``.  A non-empty
                ``feed_dict`` is updated in place; ``None`` or an empty mapping
                is replaced by a fresh dict.

        Returns:
            Whatever the parent ``Lambda.__call__`` returns, or the result of
            ``ezcall`` for non-mapping input.
        """
        # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``.
        if feed_dict is not None and not isinstance(feed_dict, collections.abc.Mapping):
            return self.ezcall(feed_dict, **kwargs)
        feed_dict = feed_dict or {}
        feed_dict.update(**kwargs)
        return super().__call__(feed_dict)

    # Reuse the single-object calling convention implemented by TransformBase.
    ezcall = TransformBase.ezcall
[docs]
class ToTensor(TransformFunctionBase):
    # Docstring and repr are taken verbatim from torchvision's ToTensor.
    __doc__ = torch_transforms.ToTensor.__doc__
    __repr__ = torch_transforms.ToTensor.__repr__

    def call_image(self, img):
        """Convert the image with ``fimage.to_tensor``."""
        return fimage.to_tensor(img)

    def call_coor(self, img, coor):
        """Pass ``coor`` through ``fcoor.refresh_valid`` and wrap it in a tensor."""
        refreshed = fcoor.refresh_valid(img, coor)
        return torch.tensor(refreshed)

    def call_bbox(self, img, bbox):
        """Pass ``bbox`` through ``fbbox.refresh_valid`` and wrap it in a tensor."""
        refreshed = fbbox.refresh_valid(img, bbox)
        return torch.tensor(refreshed)
[docs]
class ToPILImage(TransformFunctionBaseImageOnly):
    # Docstring and repr are taken verbatim from torchvision's ToPILImage;
    # the repr reads ``self.mode``.
    __doc__ = torch_transforms.ToPILImage.__doc__
    __repr__ = torch_transforms.ToPILImage.__repr__

    def __init__(self, mode=None, tg=None):
        """Store the target ``mode``; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.mode = mode

    def call_image(self, img):
        """Convert ``img`` via ``fimage.to_pil_image`` using the stored mode."""
        target_mode = self.mode
        return fimage.to_pil_image(img, target_mode)
[docs]
class Normalize(TransformFunctionBaseImageOnly):
    # Docstring and repr are taken verbatim from torchvision's Normalize;
    # the repr reads ``self.mean`` and ``self.std``.
    __doc__ = torch_transforms.Normalize.__doc__
    __repr__ = torch_transforms.Normalize.__repr__

    def __init__(self, mean, std, tg=None):
        """Store ``mean`` and ``std``; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.mean, self.std = mean, std

    def call_image(self, tensor):
        """Normalize ``tensor`` via ``fimage.normalize`` with the stored statistics."""
        return fimage.normalize(tensor, self.mean, self.std)
[docs]
class NormalizeCoordinates(TransformFunctionBase):
    """Normalize coordinates and boxes relative to their image.

    The image itself is passed through untouched; the actual arithmetic lives
    in ``fcoor.normalize_coor`` and ``fbbox.normalize_bbox``.
    """

    def call_image(self, img):
        """Return the image unchanged."""
        return img

    def call_coor(self, img, coor):
        """Delegate to ``fcoor.normalize_coor``."""
        normalized = fcoor.normalize_coor(img, coor)
        return normalized

    def call_bbox(self, img, bbox):
        """Delegate to ``fbbox.normalize_bbox``."""
        normalized = fbbox.normalize_bbox(img, bbox)
        return normalized
[docs]
class DenormalizeCoordinates(TransformFunctionBase):
    """Denormalize coordinates and boxes relative to their image.

    The image itself is passed through untouched; the actual arithmetic lives
    in ``fcoor.denormalize_coor`` and ``fbbox.denormalize_box``.
    """

    def call_image(self, img):
        """Return the image unchanged."""
        return img

    def call_coor(self, img, coor):
        """Delegate to ``fcoor.denormalize_coor``."""
        restored = fcoor.denormalize_coor(img, coor)
        return restored

    def call_bbox(self, img, bbox):
        """Delegate to ``fbbox.denormalize_box``."""
        # NOTE(review): the helper is spelled ``denormalize_box`` while its
        # counterpart is ``normalize_bbox`` -- confirm the name against the
        # functional bbox module.
        restored = fbbox.denormalize_box(img, bbox)
        return restored
[docs]
class Crop(TransformFunctionBase):
    """Crop a fixed rectangle from the image and adjust coordinates and boxes.

    The rectangle is described by ``x``, ``y``, ``w`` and ``h``, which are
    handed as-is to the functional crop helpers.
    """

    def __init__(self, x, y, w, h, tg=None):
        """Remember the crop rectangle; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.x, self.y = x, y
        self.w, self.h = w, h

    def _rect(self):
        """Return the stored rectangle as an ``(x, y, w, h)`` tuple."""
        return self.x, self.y, self.w, self.h

    def call_image(self, img):
        """Crop the image via ``fimage.crop``."""
        return fimage.crop(img, *self._rect())

    def call_coor(self, img, coor):
        """Adjust coordinates via ``fcoor.crop`` (the image is not consulted)."""
        return fcoor.crop(coor, *self._rect())

    def call_bbox(self, img, bbox):
        """Adjust boxes via ``fbbox.crop`` (the image is not consulted)."""
        return fbbox.crop(bbox, *self._rect())
[docs]
class CenterCrop(TransformBase):
    # Docstring and repr are taken verbatim from torchvision's CenterCrop;
    # that docstring defines ``size`` as ``(h, w)``.
    __doc__ = torch_transforms.CenterCrop.__doc__
    __repr__ = torch_transforms.CenterCrop.__repr__

    def __init__(self, size, tg=None):
        """Store the crop size as a 2-tuple; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.size = get_2dshape(size)

    def call_feed_dict(self, feed_dict):
        """Crop the central ``size`` region of every entry in ``feed_dict``.

        Returns:
            The feed dict produced by a ``Crop`` centred on the image.
        """
        img = self._get_image(feed_dict)
        w, h = img.size
        # ``self.size`` is (height, width), matching the borrowed torchvision
        # docstring and the way RandomCrop in this module reads it.  Unpacking
        # it as (width, height) transposed non-square crops.
        th, tw = self.size
        x = int(round((w - tw) / 2.))
        y = int(round((h - th) / 2.))
        return Crop(x, y, tw, th, tg=self.transform_guide)(feed_dict)
[docs]
class RandomCrop(TransformBase):
    # Docstring and repr are taken verbatim from torchvision's RandomCrop;
    # the repr reads ``self.size`` and ``self.padding``.
    __doc__ = torch_transforms.RandomCrop.__doc__
    __repr__ = torch_transforms.RandomCrop.__repr__

    def __init__(self, size, padding=0, pad_if_needed=False, tg=None):
        """Store the crop configuration.

        Args:
            size: crop size; converted to a 2-tuple read as ``(h, w)``.
            padding: padding applied before cropping (number or tuple accepted
                by ``Pad``).
            pad_if_needed: pad the image when it is smaller than ``size``.
            tg: transform guide forwarded to the base class.
        """
        super().__init__(tg)
        self.size = get_2dshape(size)
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    def call_feed_dict(self, feed_dict):
        """Optionally pad, then crop a random ``size`` window.

        Returns:
            The feed dict produced by the final ``Crop``.
        """
        padding = self.padding
        # A scalar keeps the ``> 0`` rule; a sequence (which ``Pad`` accepts)
        # is applied when any component is positive.  Comparing a tuple with
        # ``0`` directly raises TypeError.
        if padding is None:
            apply_padding = False
        elif isinstance(padding, numbers.Number):
            apply_padding = padding > 0
        else:
            apply_padding = any(p > 0 for p in padding)
        if apply_padding:
            feed_dict = Pad(padding, tg=self.transform_guide)(feed_dict)

        img = self._get_image(feed_dict)
        # pad the width if needed (PIL ``size`` is (width, height))
        if self.pad_if_needed and img.size[0] < self.size[1]:
            feed_dict = Pad((int((1 + self.size[1] - img.size[0]) / 2), 0), tg=self.transform_guide)(feed_dict)
            img = self._get_image(feed_dict)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            feed_dict = Pad((0, int((1 + self.size[0] - img.size[1]) / 2)), tg=self.transform_guide)(feed_dict)
            # Re-read the padded image: sampling the crop window from the
            # pre-padding image would use a height smaller than the target
            # and make the random offset range empty.
            img = self._get_image(feed_dict)

        # get_params returns (top, left, height, width); Crop takes (x, y, w, h).
        i, j, h, w = torch_transforms.RandomCrop.get_params(img, self.size)
        return Crop(j, i, w, h, tg=self.transform_guide)(feed_dict)
[docs]
class Pad(TransformFunctionBase):
    # Docstring and repr are taken verbatim from torchvision's Pad; the repr
    # reads ``self.padding``, ``self.fill`` and ``self.padding_mode``.
    __doc__ = torch_transforms.Pad.__doc__
    __repr__ = torch_transforms.Pad.__repr__

    def __init__(self, padding, mode='constant', fill=0, tg=None):
        """Validate and store the padding configuration.

        Args:
            padding: a number, or a tuple of 2 or 4 elements.
            mode: one of 'constant', 'edge', 'reflect', 'symmetric'.
            fill: a number, string or tuple.
            tg: transform guide forwarded to the base class.

        Raises:
            AssertionError: if an argument has an unsupported type or ``mode``
                is not one of the listed values.
            ValueError: if ``padding`` is a sequence whose length is not 2 or 4.
        """
        super().__init__(tg)
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert mode in ['constant', 'edge', 'reflect', 'symmetric']
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding = padding
        self.fill = fill
        self.padding_mode = mode

    def call_image(self, img):
        """Pad the image via ``fimage.pad``."""
        return fimage.pad(img, self.padding, self.padding_mode, self.fill)

    def call_coor(self, img, coor):
        """Adjust coordinates via ``fcoor.pad`` (the image is not consulted)."""
        return fcoor.pad(coor, self.padding)

    def call_bbox(self, img, bbox):
        """Adjust boxes via ``fbbox.pad`` (the image is not consulted)."""
        return fbbox.pad(bbox, self.padding)
[docs]
class PadMultipleOf(TransformBase):
    """Pad the image up to the size computed by ``get_size_multiple_of``.

    The target height and width come from
    ``get_size_multiple_of(h, w, multiple, residual)``; nothing happens when
    the image already has that size.
    """

    def __init__(self, multiple, residual=0, mode='constant', fill=0, tg=None):
        """Store the size rule and the ``Pad`` options; ``tg`` goes to the base class."""
        super().__init__(tg)
        self.multiple, self.residual = multiple, residual
        self.mode, self.fill = mode, fill

    def call_feed_dict(self, feed_dict):
        """Return ``feed_dict`` padded to the target size (or unchanged)."""
        image = self._get_image(feed_dict)
        cur_h, cur_w = image.height, image.width
        new_h, new_w = get_size_multiple_of(cur_h, cur_w, self.multiple, self.residual)
        if cur_h == new_h and cur_w == new_w:
            return feed_dict
        # All extra pixels go into the last two components of the 4-tuple
        # (presumably right and bottom, following the torchvision padding
        # order -- confirm against ``fimage.pad``).
        extra = (0, 0, new_w - cur_w, new_h - cur_h)
        padder = Pad(extra, mode=self.mode, fill=self.fill, tg=self.transform_guide)
        return padder(feed_dict)
[docs]
class HFlip(TransformFunctionBase):
    """Horizontal flip of the image together with its coordinates and boxes."""

    def call_image(self, img):
        """Flip the image via ``fimage.hflip``."""
        flipped = fimage.hflip(img)
        return flipped

    def call_coor(self, img, coor):
        """Flip coordinates via ``fcoor.hflip``."""
        flipped = fcoor.hflip(img, coor)
        return flipped

    def call_bbox(self, img, bbox):
        """Flip boxes via ``fbbox.hflip``."""
        flipped = fbbox.hflip(img, bbox)
        return flipped
[docs]
class VFlip(TransformFunctionBase):
    """Vertical flip of the image together with its coordinates and boxes."""

    def call_image(self, img):
        """Flip the image via ``fimage.vflip``."""
        flipped = fimage.vflip(img)
        return flipped

    def call_coor(self, img, coor):
        """Flip coordinates via ``fcoor.vflip``."""
        flipped = fcoor.vflip(img, coor)
        return flipped

    def call_bbox(self, img, bbox):
        """Flip boxes via ``fbbox.vflip``."""
        flipped = fbbox.vflip(img, bbox)
        return flipped
[docs]
class RandomHorizontalFlip(TransformBase):
    # Docstring and repr are taken verbatim from torchvision's
    # RandomHorizontalFlip; the repr reads ``self.p``.
    __doc__ = torch_transforms.RandomHorizontalFlip.__doc__
    __repr__ = torch_transforms.RandomHorizontalFlip.__repr__

    def __init__(self, p=0.5, tg=None):
        """Store the flip probability ``p``; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.p = p

    def call_feed_dict(self, feed_dict):
        """Apply ``HFlip`` when a fresh ``random.random()`` draw is below ``p``."""
        if not random.random() < self.p:
            return feed_dict
        flipper = HFlip(tg=self.transform_guide)
        return flipper(feed_dict)
[docs]
class RandomVerticalFlip(TransformBase):
    # Docstring and repr are taken verbatim from torchvision's
    # RandomVerticalFlip; the repr reads ``self.p``.
    __doc__ = torch_transforms.RandomVerticalFlip.__doc__
    __repr__ = torch_transforms.RandomVerticalFlip.__repr__

    def __init__(self, p=0.5, tg=None):
        """Store the flip probability ``p``; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.p = p

    def call_feed_dict(self, feed_dict):
        """Apply ``VFlip`` when a fresh ``random.random()`` draw is below ``p``."""
        if not random.random() < self.p:
            return feed_dict
        flipper = VFlip(tg=self.transform_guide)
        return flipper(feed_dict)
[docs]
class Resize(TransformFunctionBase):
    # Docstring and repr are taken verbatim from torchvision's Resize; the
    # repr reads ``self.size`` and ``self.interpolation``.
    __doc__ = torch_transforms.Resize.__doc__
    __repr__ = torch_transforms.Resize.__repr__

    def __init__(self, size, interpolation=Image.BILINEAR, tg=None):
        """Store the target size (as a 2-tuple) and the interpolation mode."""
        super().__init__(tg)
        self.size = get_2dshape(size)
        self.interpolation = interpolation

    def call_image(self, img):
        """Resize the image via ``fimage.resize``."""
        target, method = self.size, self.interpolation
        return fimage.resize(img, target, method)

    def call_coor(self, img, coor):
        """Rescale coordinates via ``fcoor.resize``."""
        target = self.size
        return fcoor.resize(img, coor, target)

    def call_bbox(self, img, bbox):
        """Rescale boxes via ``fbbox.resize``."""
        target = self.size
        return fbbox.resize(img, bbox, target)
[docs]
class ResizeMultipleOf(TransformBase):
    """Resize the image to the size computed by ``get_size_multiple_of``.

    The target height and width come from
    ``get_size_multiple_of(h, w, multiple, residual)``; nothing happens when
    the image already has that size.
    """

    def __init__(self, multiple, residual=0, interpolation=Image.NEAREST, tg=None):
        """Store the size rule and interpolation mode; ``tg`` goes to the base class."""
        super().__init__(tg)
        self.multiple, self.residual = multiple, residual
        self.interpolation = interpolation

    def call_feed_dict(self, feed_dict):
        """Return ``feed_dict`` resized to the target size (or unchanged)."""
        image = self._get_image(feed_dict)
        cur_h, cur_w = image.height, image.width
        new_h, new_w = get_size_multiple_of(cur_h, cur_w, self.multiple, self.residual)
        if cur_h == new_h and cur_w == new_w:
            return feed_dict
        resizer = Resize((new_h, new_w), interpolation=self.interpolation, tg=self.transform_guide)
        return resizer(feed_dict)
[docs]
class RandomResizedCrop(TransformBase):
    # Docstring and repr are taken verbatim from torchvision's
    # RandomResizedCrop; the repr reads ``size``, ``scale``, ``ratio`` and
    # ``interpolation``.
    __doc__ = torch_transforms.RandomResizedCrop.__doc__
    __repr__ = torch_transforms.RandomResizedCrop.__repr__

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR, tg=None):
        """Store the crop/resize configuration.

        Args:
            size: output size; a scalar ``s`` becomes ``(s, s)``, a tuple or
                list is used as given.
            scale: range handed to torchvision's ``get_params``.
            ratio: aspect-ratio range handed to torchvision's ``get_params``.
            interpolation: interpolation mode used by the final ``Resize``.
            tg: transform guide forwarded to the base class.
        """
        super().__init__(tg)
        # Wrapping an already two-dimensional size in ``(size, size)`` would
        # nest it; only scalars are duplicated.
        if isinstance(size, (tuple, list)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    def call_feed_dict(self, feed_dict):
        """Crop a randomly sampled window and resize it to ``size``.

        Returns:
            The feed dict after ``Crop`` followed by ``Resize``.
        """
        img = self._get_image(feed_dict)
        # get_params returns (top, left, height, width); Crop takes (x, y, w, h).
        i, j, h, w = torch_transforms.RandomResizedCrop.get_params(img, self.scale, self.ratio)
        feed_dict = Crop(j, i, w, h, tg=self.transform_guide)(feed_dict)
        feed_dict = Resize(self.size, self.interpolation, tg=self.transform_guide)(feed_dict)
        return feed_dict
[docs]
class FiveCrop(TransformFunctionBase):
    # Docstring and repr are taken verbatim from torchvision's FiveCrop; the
    # repr reads ``self.size``.  Only the image handler is defined here.
    __doc__ = torch_transforms.FiveCrop.__doc__
    __repr__ = torch_transforms.FiveCrop.__repr__

    def __init__(self, size, tg=None):
        """Store the crop size as a 2-tuple; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.size = get_2dshape(size)

    def call_image(self, img):
        """Produce the crops via ``fimage.five_crop``."""
        crop_size = self.size
        return fimage.five_crop(img, crop_size)
[docs]
class TenCrop(TransformFunctionBase):
    # Docstring and repr are taken verbatim from torchvision's TenCrop.  Only
    # the image handler is defined here.
    __doc__ = torch_transforms.TenCrop.__doc__
    __repr__ = torch_transforms.TenCrop.__repr__

    def __init__(self, size, tg=None):
        """Store the crop size as a 2-tuple; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.size = get_2dshape(size)
        # The borrowed torchvision ``__repr__`` formats ``self.vertical_flip``;
        # without this attribute ``repr()`` raises AttributeError.
        # ``call_image`` never passes a flip flag, so record False.
        # NOTE(review): assumes ``fimage.ten_crop`` defaults to horizontal
        # flipping like torchvision's ``ten_crop`` -- confirm.
        self.vertical_flip = False

    def call_image(self, img):
        """Produce the crops via ``fimage.ten_crop``."""
        return fimage.ten_crop(img, self.size)
class _AffineHelper(TransformFunctionBase):
    """Internal helper that applies one rotation to image, coordinates and boxes.

    The image is rotated with the settings read from ``owner``; coordinates
    and boxes are mapped through ``matrix`` by the functional affine helpers.
    """

    def __init__(self, owner, matrix, tg):
        """Keep the owning transform and the matrix; ``tg`` goes to the base class."""
        super().__init__(tg)
        self.owner, self.matrix = owner, matrix

    def call_image(self, img):
        """Rotate ``img`` with the owner's angle and rotation options."""
        cfg = self.owner
        angle = cfg.angle
        options = dict(
            resample=cfg.resample,
            expand=cfg.expand,
            center=cfg.center,
            translate=cfg.translate,
        )
        return img.rotate(angle, **options)

    def call_coor(self, img, coor):
        """Map coordinates through the matrix via ``fcoor.affine``."""
        return fcoor.affine(coor, self.matrix)

    def call_bbox(self, img, bbox):
        """Map boxes through the matrix via ``fbbox.affine``."""
        return fbbox.affine(bbox, self.matrix)
[docs]
class Rotate(TransformBase):
    """Rotate the image by a fixed angle and transform coordinates and boxes.

    The rotation matrix and an optional follow-up crop come from
    ``get_rotation_matrix``; the rotation itself is carried out by
    ``_AffineHelper``.
    """

    def __init__(self, angle, resample=False, crop=False, expand=False, center=None, translate=None, tg=None):
        """Store the rotation settings; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.angle = angle
        self.resample, self.crop, self.expand = resample, crop, expand
        self.center, self.translate = center, translate

    def call_feed_dict(self, feed_dict):
        """Rotate every entry, then apply the extra crop if one was returned."""
        image = self._get_image(feed_dict)
        matrix, crop_rect = get_rotation_matrix(
            image, self.angle, self.crop, self.expand, self.center, self.translate
        )
        rotated = _AffineHelper(self, matrix, tg=self.transform_guide)(feed_dict)
        if crop_rect is None:
            return rotated
        # ``crop_rect`` is unpacked straight into Crop's (x, y, w, h) arguments.
        return Crop(*crop_rect, tg=self.transform_guide)(rotated)
[docs]
class RandomRotation(TransformBase):
    # Docstring is taken verbatim from torchvision's RandomRotation; the repr
    # is defined locally because this class has an extra ``crop`` option.
    __doc__ = torch_transforms.RandomRotation.__doc__

    def __init__(self, degrees, resample=False, crop=False, expand=False, center=None, translate=None, tg=None):
        """Validate ``degrees`` and store the rotation settings.

        Args:
            degrees: a non-negative number ``d`` (stored as ``(-d, d)``) or a
                sequence of length 2 (stored as given).
            resample, crop, expand, center, translate: forwarded to ``Rotate``.
            tg: transform guide forwarded to the base class.

        Raises:
            ValueError: if a numeric ``degrees`` is negative, or a sequence
                does not have exactly two elements.
        """
        super().__init__(tg)
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = (-degrees, degrees)
        elif len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        self.degrees = degrees
        self.crop = crop
        self.resample = resample
        self.expand = expand
        self.center = center
        self.translate = translate

    def call_feed_dict(self, feed_dict):
        """Sample an angle with torchvision's ``get_params`` and apply ``Rotate``."""
        sampled = torch_transforms.RandomRotation.get_params(self.degrees)
        rotation = Rotate(
            sampled, self.resample, self.crop, self.expand, self.center, self.translate,
            tg=self.transform_guide
        )
        return rotation(feed_dict)

    def __repr__(self):
        """Return ``ClassName(degrees=..., resample=..., crop=..., expand=...[, center=...][, translate=...])``."""
        parts = [
            'degrees={0}'.format(self.degrees),
            'resample={0}'.format(self.resample),
            'crop={0}'.format(self.crop),
            'expand={0}'.format(self.expand),
        ]
        # ``center`` and ``translate`` are only shown when they were set.
        if self.center is not None:
            parts.append('center={0}'.format(self.center))
        if self.translate is not None:
            parts.append('translate={0}'.format(self.translate))
        return self.__class__.__name__ + '(' + ', '.join(parts) + ')'
[docs]
class ColorJitter(TransformFunctionBaseImageOnly):
    """Image-only wrapper around torchvision's ``ColorJitter``.

    A new torchvision ``ColorJitter`` is built from the stored settings on
    every call.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, tg=None):
        """Store the jitter settings; ``tg`` is forwarded to the base class."""
        super().__init__(tg)
        self.brightness, self.contrast = brightness, contrast
        self.saturation, self.hue = saturation, hue

    def call_image(self, img):
        """Apply a freshly constructed torchvision ``ColorJitter`` to ``img``."""
        jitter = torch_transforms.ColorJitter(
            self.brightness, self.contrast, self.saturation, self.hue
        )
        return jitter(img)
[docs]
class Grayscale(TransformFunctionBaseImageOnly):
    # Docstring and repr are taken verbatim from torchvision's Grayscale; the
    # repr reads ``self.num_output_channels``.
    __doc__ = torch_transforms.Grayscale.__doc__
    __repr__ = torch_transforms.Grayscale.__repr__

    def __init__(self, num_output_channels=1, tg=None):
        """Store the number of output channels; ``tg`` goes to the base class."""
        super().__init__(tg)
        self.num_output_channels = num_output_channels

    def call_image(self, img):
        """Convert ``img`` via ``fimage.to_grayscale``."""
        channels = self.num_output_channels
        return fimage.to_grayscale(img, num_output_channels=channels)
[docs]
class RandomGrayscale(TransformFunctionBaseImageOnly):
    # Docstring and repr are taken verbatim from torchvision's
    # RandomGrayscale; the repr reads ``self.p``.
    __doc__ = torch_transforms.RandomGrayscale.__doc__
    __repr__ = torch_transforms.RandomGrayscale.__repr__

    def __init__(self, p=0.1, tg=None):
        """Store the conversion probability ``p``; ``tg`` goes to the base class."""
        super().__init__(tg)
        self.p = p

    def call_image(self, img):
        """Convert ``img`` to grayscale when a ``random.random()`` draw is below ``p``.

        The channel count of the result follows the input: 1 for mode ``'L'``,
        otherwise 3.
        """
        if img.mode == 'L':
            channels = 1
        else:
            channels = 3
        if not random.random() < self.p:
            return img
        return fimage.to_grayscale(img, num_output_channels=channels)