#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : imgproc.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import math
import functools
import numpy as np
from jacinle.utils.argument import get_2dshape
from jacinle.utils.enum import JacEnum
from . import backend
# Names exported by a star-import of this module. The underscore-prefixed
# helpers and the ShuffleType enum defined below are intentionally not listed.
__all__ = [
'resize', 'resize_wh', 'resize_scale', 'resize_scale_wh', 'resize_minmax',
'crop', 'center_crop', 'leftup_crop',
'dimshuffle',
'clip', 'clip_decorator',
'grayscale',
'brightness', 'contrast', 'saturation'
]
def _get_crop2d_rest(img, target_shape):
    """Return how much of ``img`` is left over, per axis, after removing a
    window of ``target_shape`` from its first two dimensions.

    Args:
        img: array-like exposing ``.shape``; only the first two axes are used.
        target_shape: requested window shape, normalised through
            ``get_2dshape`` (presumably an int or a 2-item shape -- confirm).

    Returns:
        A 2-tuple ``(rest_axis0, rest_axis1)`` of non-negative differences.

    Raises:
        AssertionError: if the window is larger than the image on either axis.
    """
    available = img.shape[:2]
    wanted = get_2dshape(target_shape)
    rest_axis0 = available[0] - wanted[0]
    rest_axis1 = available[1] - wanted[1]
    # The window must fit inside the image on both axes.
    assert rest_axis0 >= 0 and rest_axis1 >= 0
    return rest_axis0, rest_axis1
def _crop2d(img, start, size):
    """Slice a 2-D window out of the first two axes of ``img``.

    Args:
        img: indexable array; trailing axes (if any) are kept whole.
        start: ``(axis0, axis1)`` index of the window's first element.
        size: ``(axis0, axis1)`` extent of the window.

    Returns:
        ``img[start[0]:start[0] + size[0], start[1]:start[1] + size[1]]``.
        Standard slice semantics apply, so a window running past the end of
        the array is silently truncated.
    """
    first_row, first_col = start[0], start[1]
    rows = slice(first_row, first_row + size[0])
    cols = slice(first_col, first_col + size[1])
    return img[rows, cols]
[docs]
def resize(img, size, interpolation='LINEAR'):
    """Resize ``img`` to ``size``, given in (axis 0, axis 1) order.

    Args:
        img: image handed through to ``backend.resize``.
        size: target shape, normalised through ``get_2dshape``
            (presumably height/width -- confirm against ``get_2dshape``).
        interpolation: interpolation mode forwarded to the backend.

    Returns:
        Whatever ``backend.resize`` returns for the swapped size.
    """
    shape_2d = get_2dshape(size)
    # The backend is called with the two entries in reversed order.
    swapped = (shape_2d[1], shape_2d[0])
    return backend.resize(img, swapped, interpolation=interpolation)
[docs]
def resize_wh(img, size_wh, interpolation='LINEAR'):
    """Resize ``img`` to ``size_wh``, passed to the backend without reordering.

    Args:
        img: image handed through to ``backend.resize``.
        size_wh: target size, normalised through ``get_2dshape``
            (presumably width/height, as the name suggests -- confirm).
        interpolation: interpolation mode forwarded to the backend.

    Returns:
        Whatever ``backend.resize`` returns.
    """
    target = get_2dshape(size_wh)
    return backend.resize(img, target, interpolation=interpolation)
[docs]
def resize_scale(img, scale, interpolation='LINEAR'):
    """Resize ``img`` by per-axis scale factors.

    Args:
        img: image exposing ``.shape``; its first two axes are scaled.
        scale: scale factor(s), normalised through
            ``get_2dshape(scale, type=float)`` into an (axis 0, axis 1) pair.
        interpolation: interpolation mode forwarded to ``resize``.

    Returns:
        The result of ``resize`` for the scaled shape; each scaled dimension
        is rounded up with ``math.ceil``.
    """
    factors = get_2dshape(scale, type=float)
    scaled_axis0 = math.ceil(img.shape[0] * factors[0])
    scaled_axis1 = math.ceil(img.shape[1] * factors[1])
    return resize(img, (scaled_axis0, scaled_axis1), interpolation=interpolation)
[docs]
def resize_scale_wh(img, scale_wh, interpolation='LINEAR'):
    """Resize ``img`` by scale factors given in the reverse order of
    ``resize_scale`` (presumably width first, then height -- confirm).

    Args:
        img: image forwarded to ``resize_scale``.
        scale_wh: scale factor(s), normalised through
            ``get_2dshape(scale_wh, type=float)``.
        interpolation: interpolation mode forwarded to ``resize_scale``.

    Returns:
        The result of ``resize_scale`` called with the two factors swapped.
    """
    factors = get_2dshape(scale_wh, type=float)
    swapped = (factors[1], factors[0])
    return resize_scale(img, swapped, interpolation=interpolation)
[docs]
def resize_minmax(img, min_dim, max_dim=None, interpolation='LINEAR'):
    """Uniformly rescale ``img`` so its shorter side reaches ``min_dim`` unless
    that would push the longer side past ``max_dim``.

    Args:
        img: image exposing ``.shape``; the first two axes are its sides.
        min_dim: target length for the shorter side.
        max_dim: upper bound for the longer side; defaults to ``min_dim``.
            The two bounds are reordered if given the wrong way round.
        interpolation: interpolation mode forwarded to ``resize_scale``.

    Returns:
        The result of ``resize_scale`` with the single chosen scale factor.
    """
    if max_dim is None:
        max_dim = min_dim
    lower, upper = min(min_dim, max_dim), max(min_dim, max_dim)

    side_a, side_b = img.shape[:2]
    short_side, long_side = sorted((side_a, side_b))

    # Use the smaller of the two candidate factors so both bounds are respected.
    fit_short = lower / short_side
    fit_long = upper / long_side
    factor = min(fit_long, fit_short)
    return resize_scale(img, factor, interpolation=interpolation)
[docs]
def crop(image, l, t, w, h, extra_crop=None):
    """Cut the window ``image[t:t + h, l:l + w]`` out of ``image``, filling
    with zeros any part of the window that lies outside the image.

    Args:
        image: numpy array; axis 0 is indexed by ``t``/``h`` and axis 1 by
            ``l``/``w``. Trailing axes are carried over unchanged.
        l: index of the window's first column (may be negative or fractional;
            floored to an int).
        t: index of the window's first row (may be negative or fractional;
            floored to an int).
        w: window width (rounded to the nearest int).
        h: window height (rounded to the nearest int).
        extra_crop: optional factor by which the window's width and height are
            multiplied before cropping. The origin is shifted by half of the
            size change so the window stays roughly centred. ``None`` and
            ``1`` leave the window unchanged.

    Returns:
        A new array of shape ``(h, w) + image.shape[2:]`` and dtype
        ``image.dtype``. A window that does not overlap the image at all
        yields an all-zero array.
    """
    if extra_crop is not None and extra_crop != 1:
        new_w, new_h = round(w * extra_crop), round(h * extra_crop)
        l -= (new_w - w) // 2
        t -= (new_h - h) // 2
        w, h = new_w, new_h

    im_h, im_w = image.shape[0:2]
    w, h = int(round(w)), int(round(h))
    l, t = int(math.floor(l)), int(math.floor(t))
    # The requested range is image[t:t + h, l:l + w] from here on.

    # Intersection of the requested window with the image, in image coordinates.
    src_l, src_t = max(l, 0), max(t, 0)
    src_r, src_b = min(l + w, im_w), min(t + h, im_h)
    copy_w, copy_h = src_r - src_l, src_b - src_t

    result = np.zeros(shape=(h, w) + image.shape[2:], dtype=image.dtype)

    # A non-positive extent means the window misses the image completely.
    # Slicing with a negative extent would select mismatched regions and make
    # the assignment fail, so only copy when there is a real overlap.
    if copy_w > 0 and copy_h > 0:
        # Offset of the overlap inside the output window.
        dst_l, dst_t = src_l - l, src_t - t
        result[dst_t:dst_t + copy_h, dst_l:dst_l + copy_w] = image[src_t:src_b, src_l:src_r]
    return result
[docs]
def center_crop(img, target_shape):
    """Crop a window of ``target_shape`` from the middle of ``img``.

    Args:
        img: image whose first two axes are cropped.
        target_shape: window shape, normalised through ``get_2dshape``.

    Returns:
        The window whose origin is half of the leftover margin on each axis
        (integer division, so an odd margin puts the extra element after the
        window).

    Raises:
        AssertionError: from ``_get_crop2d_rest`` when the window does not fit.
    """
    window = get_2dshape(target_shape)
    margin = _get_crop2d_rest(img, window)
    origin = (margin[0] // 2, margin[1] // 2)
    return _crop2d(img, origin, window)
[docs]
def leftup_crop(img, target_shape):
    """Crop a window of ``target_shape`` anchored at index ``(0, 0)`` of ``img``.

    Args:
        img: image whose first two axes are cropped.
        target_shape: window shape, normalised through ``get_2dshape``.

    Returns:
        The window produced by ``_crop2d``. No fit check is made here.
    """
    origin = (0, 0)
    window = get_2dshape(target_shape)
    return _crop2d(img, origin, window)
[docs]
class ShuffleType(JacEnum):
    """Target axis layouts understood by ``dimshuffle``."""

    # Move the channel axis in front of the two spatial axes.
    CHANNEL_FIRST = 'channel_first'
    # Move the channel axis behind the two spatial axes.
    CHANNEL_LAST = 'channel_last'
[docs]
def dimshuffle(img, shuffle_type):
    """Move the channel axis of ``img`` to the layout named by ``shuffle_type``.

    Args:
        img: array with 2, 3 or 4 dimensions. A 4-D input keeps axis 0 in
            place (presumably a batch axis -- confirm with callers).
        shuffle_type: a ``ShuffleType`` or a value accepted by
            ``ShuffleType.from_string``.

    Returns:
        ``img`` itself when it is 2-D; otherwise ``np.transpose`` of ``img``
        with the last axis moved to the front of the spatial axes
        (``CHANNEL_FIRST``) or the reverse (any other value).

    Raises:
        AssertionError: if ``img`` does not have 2, 3 or 4 dimensions.
    """
    shuffle_type = ShuffleType.from_string(shuffle_type)
    ndim = len(img.shape)
    assert ndim in (2, 3, 4), 'Image should be of dims 2, 3 or 4'

    # A 2-D image has no channel axis to move.
    if ndim == 2:
        return img

    to_channel_first = shuffle_type == ShuffleType.CHANNEL_FIRST
    if ndim == 3:
        axes = (2, 0, 1) if to_channel_first else (1, 2, 0)
    else:
        # 4-D: axis 0 stays where it is.
        axes = (0, 3, 1, 2) if to_channel_first else (0, 2, 3, 1)
    return np.transpose(img, axes)
[docs]
def clip(img):
    """Clamp every element of ``img`` to the closed range ``[0, 255]``.

    Args:
        img: array-like accepted by ``np.maximum`` / ``np.minimum``.

    Returns:
        A new array with values below 0 raised to 0 and values above 255
        lowered to 255.
    """
    floored = np.maximum(0, img)
    return np.minimum(255, floored)
[docs]
def clip_decorator(func):
    """Decorate ``func`` so that its return value is passed through ``clip``.

    Args:
        func: callable returning an image.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``) that
        forwards all arguments unchanged and clips the result.
    """
    @functools.wraps(func)
    def clipped(*args, **kwargs):
        return clip(func(*args, **kwargs))

    return clipped
[docs]
def grayscale(img):
    """Collapse a 3-channel image to one channel by a weighted channel sum.

    Channels 0, 1 and 2 are weighted by 0.114, 0.587 and 0.299 respectively
    (NOTE(review): this looks like luma weights for B, G, R channel order --
    confirm the expected channel layout with callers).

    Args:
        img: array of shape ``(H, W, 3)``.

    Returns:
        Array of shape ``(H, W, 1)`` holding the weighted sum per pixel.

    Raises:
        AssertionError: if ``img`` is not 3-D with exactly 3 channels last.
    """
    assert len(img.shape) == 3 and img.shape[2] == 3
    weights = np.array([0.114, 0.587, 0.299]).reshape(1, 1, 3)
    weighted = img * weights
    # keepdims keeps the channel axis so the result still broadcasts against img.
    return weighted.sum(axis=2, keepdims=True)
[docs]
@clip_decorator
def brightness(img, alpha):
    """Scale every value of ``img`` by ``alpha``.

    The ``clip_decorator`` wrapper then passes the product through ``clip``.

    Args:
        img: image to scale.
        alpha: multiplicative factor.

    Returns:
        The scaled image (before clipping by the decorator).
    """
    scaled = img * alpha
    return scaled
[docs]
@clip_decorator
def contrast(img, alpha):
    """Blend ``img`` with a flat image holding its mean gray level.

    The result is ``img * alpha + mean_gray * (1 - alpha)``, where
    ``mean_gray`` is the mean of ``grayscale(img)``. The ``clip_decorator``
    wrapper then passes it through ``clip``.

    Args:
        img: image accepted by ``grayscale``.
        alpha: blend weight of the original image.

    Returns:
        The blended image (before clipping by the decorator).
    """
    flat_gray = grayscale(img)
    # Overwrite every pixel with the global mean, keeping shape and dtype.
    flat_gray.fill(flat_gray.mean())
    blended = img * alpha + flat_gray * (1 - alpha)
    return blended
[docs]
@clip_decorator
def saturation(img, alpha):
    """Blend ``img`` with its own grayscale version.

    The result is ``img * alpha + grayscale(img) * (1 - alpha)``. The
    ``clip_decorator`` wrapper then passes it through ``clip``.

    Args:
        img: image accepted by ``grayscale``.
        alpha: blend weight of the original image.

    Returns:
        The blended image (before clipping by the decorator).
    """
    gray = grayscale(img)
    colour_part = img * alpha
    gray_part = gray * (1 - alpha)
    return colour_part + gray_part