Source code for jaclearn.visualize.plot

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : plot.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import io as _io
import numpy as np

from jacinle.image.backend import cv2, Image, opencv_only, pil_only
from jacinle.utils.enum import JacEnum

__all__ = ['plot2opencv', 'plot2pil', 'heatmap2pil']


[docs] @opencv_only def plot2opencv(fig): """Convert a pyplot instance to image""" buf = _io.BytesIO() fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) buf.seek(0) rawbuf = np.frombuffer(buf.getvalue(), dtype='uint8') im = cv2.imdecode(rawbuf, cv2.IMREAD_COLOR) buf.close() return im
[docs] @pil_only def plot2pil(fig): canvas = fig.canvas canvas.draw() pil = Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb()) return pil
[docs] class HeatmapNormalization(JacEnum): NONE = 'none' RANGE = 'range' INSTANCE = 'instance'
[docs] @pil_only def heatmap2pil(heatmap, normalization='none', minval=0, maxval=1): heatmap = _to_numpy(heatmap) normalization = HeatmapNormalization.from_string(normalization) if normalization is HeatmapNormalization.NONE: pass elif normalization is HeatmapNormalization.MINMAX: heatmap = (heatmap - minval) / (maxval - minval) elif normalization is HeatmapNormalization.INSTANCE: minval, maxval = heatmap.min(), heatmap.max() heatmap = (heatmap - minval) / (maxval - minval) else: raise ValueError('Unknown heatmap normalization: {}.'.format(normalization)) return Image.fromarray((heatmap * 255).astype('uint8'))
def _to_numpy(obj): # NB(Jiayuan Mao @ 05/03): hack for pytorch tensors. if hasattr(obj, 'cpu'): obj = obj.cpu() return np.array(obj)