Source code for jactorch.train.tb

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : tb.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 05/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import numpy as np
from PIL import Image
try:
    from StringIO import StringIO as BytesIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

import tensorflow


def _parse_tf_version(version):
    """Return the leading ``(major, minor)`` of a TensorFlow version string as ints.

    Non-numeric suffixes inside a component are ignored (e.g. ``'1.15.0rc1'``
    gives ``(1, 15)``); missing or non-numeric components count as ``0``.
    """
    numbers = []
    for piece in str(version).split('.')[:2]:
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Compare versions numerically: a plain string comparison would treat e.g.
# '1.9.0' as newer than '1.14.0' and then fail to import compat.v1.
if _parse_tf_version(tensorflow.__version__) >= (1, 14):
    import tensorflow.compat.v1 as tf

    # NOTE(review): presumably disabled because the v1 ``tf.summary.FileWriter``
    # used in this module does not work under eager execution (the TF2 default).
    if hasattr(tf, 'disable_eager_execution'):
        tf.disable_eager_execution()
else:
    import tensorflow as tf

from jacinle.utils.meter import GroupMeters


class TBLogger(object):
    """A simple wrapper that writes TensorBoard summaries via ``tf.summary.FileWriter``.

    Adapted from: https://raw.githubusercontent.com/SherlockLiao/pytorch-beginner/
    """

    def __init__(self, log_dir: str):
        """Creates a summary writer logging to ``log_dir``.

        Args:
            log_dir: the directory to save the logs.
        """
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a single scalar value.

        Args:
            tag: the name of the scalar.
            value: the value to log; anything accepted by ``float()``
                (Python/NumPy scalars, 0-d tensors, ...).
            step: the global step associated with the value.
        """
        # Cast explicitly: the protobuf float field may reject non-Python
        # numeric objects such as 0-d tensors.
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a sequence of images; the i-th image is tagged ``'<tag>/<i>'``.

        Args:
            tag: the base name of the images.
            images: an iterable of arrays accepted by ``PIL.Image.fromarray``;
                ``shape[0]`` and ``shape[1]`` are used as height and width.
            step: the global step associated with the images.
        """
        img_summaries = []
        for i, img in enumerate(images):
            # Encode the image as PNG in memory.
            buf = BytesIO()
            Image.fromarray(img).save(buf, format="png")
            img_sum = tf.Summary.Image(
                encoded_image_string=buf.getvalue(),
                height=img.shape[0],
                width=img.shape[1],
            )
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of ``values``.

        Args:
            tag: the name of the histogram.
            values: an array-like of numbers (converted with ``np.asarray``).
            step: the global step associated with the histogram.
            bins: the ``bins`` argument forwarded to ``np.histogram``.

        Raises:
            ValueError: if ``values`` is empty.
        """
        # Accept lists and other array-likes, not only ndarrays.
        values = np.asarray(values)
        if values.size == 0:
            raise ValueError('histo_summary() requires a non-empty array of values.')

        counts, bin_edges = np.histogram(values, bins=bins)

        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(values.size)
        hist.sum = float(np.sum(values))
        # Square in float64 so integer inputs cannot overflow.
        hist.sum_squares = float(np.sum(values.astype(np.float64) ** 2))

        # HistogramProto stores each bucket by its right edge, so the left
        # edge of the first bin is dropped.
        hist.bucket_limit.extend(bin_edges[1:].tolist())
        hist.bucket.extend(counts.tolist())

        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)

    def flush(self):
        """Flush pending summaries to disk."""
        self.writer.flush()

    def close(self):
        """Flush and close the underlying summary writer."""
        self.writer.close()
class TBGroupMeters(GroupMeters):
    """A group of meters that can be updated and logged to tensorboard.

    Every update is mirrored to the tensorboard logger as a scalar summary,
    using the meter's ``tot_count`` as the step.
    """

    def __init__(self, tb_logger):
        """Create the meter group.

        Args:
            tb_logger: a logger exposing ``scalar_summary(tag, value, step)``
                and ``flush()``, e.g. a ``TBLogger``.
        """
        super().__init__()
        self._tb_logger = tb_logger

    def update(self, updates=None, value=None, n=1, prefix=None, **kwargs):
        """Update the meters and log their new values.

        Example:
            >>> meters.update(key, value)
            >>> meters.update({key1: value1, key2: value2})
            >>> meters.update(key1=value1, key2=value2)

        Args:
            updates: a single key (when ``value`` is given) or a mapping from
                keys to values.
            value: the value for the single key ``updates``.
            n: forwarded as ``n`` to each meter's ``update``.
            prefix: if given, prepended to every key.
            **kwargs: additional key/value updates.

        Raises:
            TypeError: if ``value`` is given without a key.
        """
        if value is not None:
            if updates is None:
                raise TypeError('update() got a value but no key to update.')
            updates = {updates: value}
        else:
            # Copy so that merging ``kwargs`` below never mutates the caller's dict.
            updates = dict(updates) if updates is not None else {}
        updates.update(kwargs)

        for k, v in updates.items():
            if prefix is not None:
                k = prefix + k
            meter = self._meters[k]
            meter.update(v, n=n)
            self._tb_logger.scalar_summary(k, meter.val, meter.tot_count)

    def flush(self):
        """Flush the underlying tensorboard logger."""
        self._tb_logger.flush()