Source code for jactorch.transforms.bbox.transforms

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : transforms.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/03/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import random

import torch
import torchvision.transforms as transforms

import jactorch.transforms.image as jac_transforms

from . import functional as F

# Public API of this module: joint (image, bbox) counterparts of the
# torchvision transforms, plus bbox-specific helpers.
__all__ = [
    "Compose",
    "Lambda",
    "ToTensor",
    "NormalizeBbox",
    "DenormalizeBbox",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "Grayscale",
    "RandomGrayscale",
    "PadMultipleOf",
]


class Compose(transforms.Compose):
    """Chain joint (image, bbox) transforms.

    Every entry of ``self.transforms`` is called as ``step(img, bbox)`` and
    must return a two-element ``(img, bbox)`` result, which is fed to the
    next step.
    """

    def __call__(self, img, bbox):
        for step in self.transforms:
            # Unpack strictly: each step must yield exactly (img, bbox).
            img, bbox = step(img, bbox)
        return img, bbox
class Lambda(transforms.Lambda):
    """Wrap a user callable as a joint (image, bbox) transform.

    The stored callable ``self.lambd`` receives both ``img`` and ``bbox`` and
    its result is returned unchanged (it is expected to be an ``(img, bbox)``
    pair so it composes with ``Compose``).
    """

    def __call__(self, img, bbox):
        fn = self.lambd
        return fn(img, bbox)
class ToTensor(transforms.ToTensor):
    """Convert the image with torchvision's ``ToTensor`` and the bbox to a tensor.

    ``bbox`` may be a NumPy array, an existing ``torch.Tensor``, or a nested
    sequence of numbers. For NumPy input ``torch.as_tensor`` defers to
    ``torch.from_numpy`` (shared memory, same dtype), so array inputs behave
    exactly as before; other inputs no longer raise ``TypeError``.

    Returns:
        ``(img_tensor, bbox_tensor)``.
    """

    def __call__(self, img, bbox):
        # TODO(Jiayuan Mao @ 07/23): check whether bboxes are out of the image.
        return super().__call__(img), torch.as_tensor(bbox)
class NormalizeBbox(object):
    """Joint transform that delegates to ``F.normalize_bbox(img, bbox)``.

    NOTE(review): presumably rescales pixel bbox coordinates into [0, 1]
    relative to the image size -- confirm against ``functional``.
    """

    def __call__(self, img, bbox):
        result = F.normalize_bbox(img, bbox)
        return result
class DenormalizeBbox(object):
    """Joint transform that delegates to ``F.denormalize_bbox(img, bbox)``.

    NOTE(review): presumably the inverse of ``NormalizeBbox`` (normalized
    coordinates back to pixels) -- confirm against ``functional``.
    """

    def __call__(self, img, bbox):
        result = F.denormalize_bbox(img, bbox)
        return result
class Normalize(transforms.Normalize):
    """Apply torchvision's ``Normalize`` to the image; pass the bbox through untouched."""

    def __call__(self, img, bbox):
        normalized = super().__call__(img)
        return normalized, bbox
class Resize(transforms.Resize):
    """Resize the image with torchvision's ``Resize``; the bbox is returned as-is.

    Assumes bbox coordinates are normalized to [0, 1], so resizing the image
    does not change them. Pixel-space boxes must be normalized first.
    """

    def __call__(self, img, bbox):
        resized = super().__call__(img)
        return resized, bbox
class CenterCrop(transforms.CenterCrop):
    """Center-crop the image to ``self.size`` via ``F.center_crop``.

    The bbox is passed along so ``F.center_crop`` can adjust it to the crop
    (presumably; see ``functional``).
    """

    def __call__(self, img, bbox):
        size = self.size
        return F.center_crop(img, bbox, size)
class Pad(transforms.Pad):
    """Pad the image and bbox jointly via ``F.pad``.

    Uses the ``padding`` and ``fill`` values configured on the torchvision base
    class; ``F.pad`` returns the ``(img, bbox)`` pair that is returned here.
    """

    def __call__(self, img, bbox):
        padding, fill = self.padding, self.fill
        return F.pad(img, bbox, padding, fill)
class RandomCrop(transforms.RandomCrop):
    """Randomly crop the image to ``self.size`` and crop the bbox consistently.

    When ``self.padding`` is set, image and bbox are first padded together
    with ``F.pad`` (which returns an ``(img, bbox)`` pair, as in the ``Pad``
    transform), then a crop window is sampled with torchvision's
    ``get_params`` and applied via ``F.crop``.

    Returns:
        ``(img, bbox)`` after the optional padding and the crop.
    """

    def __call__(self, img, bbox):
        padding = self.padding
        # Older torchvision defaults ``padding`` to 0, newer ones to ``None``,
        # and it may also be a sequence; a bare ``padding > 0`` raises
        # TypeError for the latter two.
        if padding is not None and (isinstance(padding, (tuple, list)) or padding > 0):
            # Unpack both: the bbox must follow the padded image, and ``img``
            # must stay an image (not the pair) for ``get_params`` below.
            img, bbox = F.pad(img, bbox, padding)
        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, bbox, i, j, h, w)
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontally flip image and bbox together with probability ``p``.

    ``p`` is the probability configured on the torchvision base class. For
    torchvision releases whose ``RandomHorizontalFlip`` has no ``p`` attribute,
    it falls back to 0.5, the value previously hard-coded here.

    Returns:
        ``F.hflip(img, bbox)`` when flipped, otherwise ``(img, bbox)`` unchanged.
    """

    def __call__(self, img, bbox):
        # Honour the user-configured probability instead of always using 0.5.
        if random.random() < getattr(self, 'p', 0.5):
            return F.hflip(img, bbox)
        return img, bbox
class RandomVerticalFlip(transforms.RandomVerticalFlip):
    """Vertically flip image and bbox together with probability ``p``.

    ``p`` is the probability configured on the torchvision base class. For
    torchvision releases whose ``RandomVerticalFlip`` has no ``p`` attribute,
    it falls back to 0.5, the value previously hard-coded here.

    Returns:
        ``F.vflip(img, bbox)`` when flipped, otherwise ``(img, bbox)`` unchanged.
    """

    def __call__(self, img, bbox):
        # Honour the user-configured probability instead of always using 0.5.
        if random.random() < getattr(self, 'p', 0.5):
            return F.vflip(img, bbox)
        return img, bbox
class RandomResizedCrop(transforms.RandomResizedCrop):
    """Sample a random crop (by ``scale``/``ratio``), crop and resize image and bbox.

    The crop window comes from torchvision's ``get_params``; cropping and
    resizing to ``self.size`` with ``self.interpolation`` is done jointly by
    ``F.resized_crop``.
    """

    def __call__(self, img, bbox):
        top, left, height, width = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(
            img, bbox, top, left, height, width, self.size, self.interpolation
        )
class Grayscale(transforms.Grayscale):
    """Convert the image to grayscale via torchvision; the bbox is unchanged."""

    def __call__(self, img, bbox):
        gray = super().__call__(img)
        return gray, bbox
class RandomGrayscale(transforms.RandomGrayscale):
    """Randomly convert the image to grayscale via torchvision; the bbox is unchanged."""

    def __call__(self, img, bbox):
        maybe_gray = super().__call__(img)
        return maybe_gray, bbox
class LinearTransformation(transforms.LinearTransformation):
    """Apply torchvision's ``LinearTransformation`` to the image tensor; bbox unchanged."""

    def __call__(self, tensor, bbox):
        transformed = super().__call__(tensor)
        return transformed, bbox
class ColorJitter(transforms.ColorJitter):
    """Apply torchvision's ``ColorJitter`` to the image; the bbox is unchanged."""

    def __call__(self, img, bbox):
        jittered = super().__call__(img)
        return jittered, bbox
class RandomRotation(transforms.RandomRotation):
    """Joint rotation transform, currently restricted to a (0, 0) degree range.

    NOTE(review): the assertion only admits ``degrees == (0, 0)``, so the
    sampled angle is always 0; presumably bbox rotation is not supported yet.
    Being an ``assert``, this guard is stripped under ``python -O``.
    NOTE(review): newer torchvision releases replaced ``resample`` with
    ``interpolation`` -- verify the attribute exists for the installed version.
    """

    def __call__(self, img, bbox):
        assert self.degrees[0] == self.degrees[1] == 0
        sampled_angle = self.get_params(self.degrees)
        resample, expand, center = self.resample, self.expand, self.center
        return F.rotate(img, bbox, sampled_angle, resample, expand, center)
class PadMultipleOf(jac_transforms.PadMultipleOf):
    """Joint transform delegating to ``F.pad_multiple_of`` with ``self.multiple``.

    Presumably pads the image so its spatial size is a multiple of
    ``self.multiple`` and adjusts the coordinates accordingly -- confirm in
    ``functional``. The second argument is named ``coor`` (not ``bbox``) and is
    kept that way for keyword-argument compatibility.
    """

    def __call__(self, img, coor):
        multiple = self.multiple
        return F.pad_multiple_of(img, coor, multiple)