Source code for jaclearn.visualize.imgrid

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : imgrid.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List


# Names exported by `from ... import *`.
__all__ = [
    'image_grid',
    'auto_image_grid_mplib',
]


def image_grid(all_images, grid_desc):
    """
    Create an image grid given the description.

    The description is a list of axis descs, each of the format ``%d[h|v]``.
    Axes are listed from outermost to innermost. If the leading number n is a
    positive integer, every n images (or sub-grids) along that axis are
    concatenated horizontally (``h``) or vertically (``v``). At most one axis
    desc may be a bare ``h`` or ``v``; the number of items along that axis is
    then inferred from the total number of images.

    Example: ``image_grid(images, ['2v', 'h'])`` lays 6 images out as 2 rows
    of 3 images each.

    Args:
        all_images: A list of images. Should be np.ndarray of shape (h, w, c).
        grid_desc: The grid description.

    Returns:
        A single big image created.
    """
    axes_info = []
    auto_infer_dim = len(all_images)
    for d in grid_desc:
        if d in ('h', 'v'):
            # Length to be inferred once all explicit lengths are known.
            axes_info.append((None, d))
        else:
            assert d.endswith('h') or d.endswith('v'), d
            length, axis = int(d[:-1]), d[-1]
            assert auto_infer_dim % length == 0, 'Length of all_images should be divided by axes_info, ' \
                'got {} and {}'.format(len(all_images), grid_desc)
            axes_info.append((length, axis))
            auto_infer_dim //= length

    # Only one axis may be inferred: with several, each would silently receive
    # the whole remaining count and the resulting layout would be ambiguous.
    nr_auto = sum(1 for length, _ in axes_info if length is None)
    assert nr_auto <= 1, 'At most one axis desc can be a bare h or v, got {}'.format(grid_desc)
    axes_info = [
        (auto_infer_dim if length is None else length, axis)
        for length, axis in axes_info
    ]

    n = len(axes_info)

    def stack(i, sequence):
        # A single item is returned as-is; otherwise concatenate along axis i's direction.
        if len(sequence) == 1:
            return sequence[0]
        if axes_info[i][1] == 'h':
            return np.hstack(sequence)
        return np.vstack(sequence)

    def recursive_concat(i, sequence):
        # The innermost axis concatenates the raw images directly.
        if i == n - 1:
            return stack(i, sequence)
        # Split the sequence into equal chunks, build each sub-grid, then join them.
        nr_parts = axes_info[i][0]
        length = len(sequence) // nr_parts
        parts = [
            recursive_concat(i + 1, sequence[j * length:(j + 1) * length])
            for j in range(nr_parts)
        ]
        return stack(i, parts)

    return recursive_concat(0, all_images)
def auto_image_grid_mplib(images: List[np.ndarray], images_title: Optional[List[str]] = None, global_title: Optional[str] = None, show: bool = True):
    """
    Automatically create a grid for the images.

    For image counts listed in ``AUTO_IMAGE_GRID_DESC`` the (rows, cols) layout
    is taken from that table; otherwise images are placed row by row, 5 per row.

    Args:
        images: a list of images. Should be np.ndarray of shape (h, w, c).
        images_title: the title for each image. If shorter than ``images``,
            the remaining images are drawn without a title.
        global_title: the title for the whole image grid (drawn in bold).
        show: whether to show the image grid using plt.show().

    Returns:
        the figure object.
    """
    n = len(images)
    if n in AUTO_IMAGE_GRID_DESC:
        nr_rows, nr_cols = AUTO_IMAGE_GRID_DESC[n]
    else:
        nr_cols = 5
        nr_rows = (n + nr_cols - 1) // nr_cols

    # squeeze=False keeps `axes` a 2-D array even for single-row layouts
    # (e.g. n = 1, 2, 3, 5). With the default squeeze=True those return a bare
    # Axes or a 1-D array, and the 2-D indexing below would fail.
    fig, axes = plt.subplots(nr_rows, nr_cols, figsize=(nr_cols * 4, nr_rows * 4), squeeze=False)

    # Pad the titles so every image is drawn; zip() stops at the shorter input.
    titles = list(images_title) if images_title is not None else []
    titles += [None] * (n - len(titles))

    for i, (image, title) in enumerate(zip(images, titles)):
        ax = axes[i // nr_cols, i % nr_cols]
        ax.imshow(image)
        if title is not None:
            ax.set_title(title)

    # Hide ticks and frames on every cell, including unused trailing cells of the grid.
    for ax in axes.flat:
        ax.axis('off')

    if global_title is not None:
        # Use bold font for the global title
        fig.suptitle(global_title, fontweight='bold')
    fig.tight_layout()
    if show:
        plt.show()
    return fig
# Preferred grid layout for small image counts, used by auto_image_grid_mplib:
# number of images -> (nr_rows, nr_cols).
AUTO_IMAGE_GRID_DESC = {
    1: (1, 1),
    2: (1, 2),
    3: (1, 3),
    4: (2, 2),
    5: (1, 5),
    6: (2, 3),
    7: (2, 4),
    8: (2, 4),
    9: (2, 5),
    10: (2, 5),
}