Source code for jactorch.train.tb
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : tb.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 05/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import numpy as np
from PIL import Image
try:
from StringIO import StringIO as BytesIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
import tensorflow
# Compare the numeric (major, minor) version parts. Comparing the raw version
# strings is lexicographic, so e.g. '1.9.0' >= '1.14.0' would wrongly be True.
if tuple(int(part) for part in tensorflow.__version__.split('.')[:2]) >= (1, 14):
    # Newer TensorFlow: use the TF1-style API (tf.Summary, tf.summary.FileWriter,
    # tf.HistogramProto) from ``compat.v1`` and switch off eager execution,
    # which TF 2.x enables by default.
    import tensorflow.compat.v1 as tf
    if hasattr(tf, 'disable_eager_execution'):
        tf.disable_eager_execution()
else:
    # Older TF 1.x: the top-level module already is the TF1 API.
    import tensorflow as tf
from jacinle.utils.meter import GroupMeters
class TBLogger(object):
    """A simple wrapper for tensorboard.

    Scalars, images and histograms are written as TensorFlow ``Summary``
    protos through a ``tf.summary.FileWriter``.

    Source: https://raw.githubusercontent.com/SherlockLiao/pytorch-beginner/
    """

    def __init__(self, log_dir: str):
        """Creates a summary writer logging to ``log_dir``.

        Args:
            log_dir: the directory to save the logs.
        """
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a single scalar value.

        Args:
            tag: name of the scalar in TensorBoard.
            value: the value to log; it is converted with ``float()``, so any
                object supporting ``float()`` (e.g. numpy scalars) is accepted.
            step: global step associated with the value.
        """
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a batch of images, one summary value per image.

        Each image is PNG-encoded with PIL and stored under the tag
        ``'<tag>/<index>'``.

        Args:
            tag: base name of the images in TensorBoard.
            images: an iterable of image arrays accepted by
                ``PIL.Image.fromarray``; ``img.shape[0]`` is used as the height
                and ``img.shape[1]`` as the width.
            step: global step associated with the images.
        """
        img_summaries = []
        for i, img in enumerate(images):
            # PNG-encode the image into an in-memory buffer.
            with BytesIO() as buf:
                Image.fromarray(img).save(buf, format="png")
                encoded = buf.getvalue()
            img_sum = tf.Summary.Image(
                encoded_image_string=encoded,
                height=img.shape[0],
                width=img.shape[1],
            )
            img_summaries.append(
                tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)
            )
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    @staticmethod
    def _histogram_stats(values, bins=1000):
        """Compute the fields of a TensorBoard ``HistogramProto`` with numpy.

        Values are converted to ``float64`` first. Lists and other array-likes
        are therefore accepted, and the sum of squares of small integer dtypes
        (e.g. ``int8``) cannot overflow.

        Args:
            values: an array-like of real numbers.
            bins: forwarded to ``np.histogram``.

        Returns:
            A dict keyed by ``HistogramProto`` field name: ``min``, ``max``,
            ``num``, ``sum``, ``sum_squares``, ``bucket_limit`` and ``bucket``.

        Raises:
            ValueError: if ``values`` is empty.
        """
        values = np.asarray(values, dtype=np.float64)
        counts, bin_edges = np.histogram(values, bins=bins)
        return {
            'min': float(values.min()),
            'max': float(values.max()),
            'num': int(values.size),
            'sum': float(values.sum()),
            'sum_squares': float(np.square(values).sum()),
            # np.histogram returns one more edge than counts; dropping the
            # left edge of the first bin leaves exactly one upper limit per
            # bucket.
            'bucket_limit': bin_edges[1:].tolist(),
            'bucket': counts.tolist(),
        }

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of ``values``.

        Args:
            tag: name of the histogram in TensorBoard.
            values: an array-like of real numbers (lists and integer arrays
                are accepted; they are converted to ``float64``).
            step: global step associated with the histogram.
            bins: forwarded to ``np.histogram``.

        Raises:
            ValueError: if ``values`` is empty.
        """
        stats = self._histogram_stats(values, bins=bins)
        # The keys of ``stats`` are exactly the HistogramProto field names.
        hist = tf.HistogramProto(**stats)
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)

    def flush(self):
        """Flush pending summaries of the underlying writer."""
        self.writer.flush()
class TBGroupMeters(GroupMeters):
    """A group of meters that can be updated and logged to tensorboard.

    Every meter update is also forwarded to the given logger as a scalar
    summary.
    """

    def __init__(self, tb_logger):
        """Create an empty meter group bound to ``tb_logger``.

        Args:
            tb_logger: the logger receiving one ``scalar_summary`` call per
                updated meter; its ``flush`` is called by :meth:`flush`.
        """
        super().__init__()
        self._tb_logger = tb_logger

    def update(self, updates=None, value=None, n=1, prefix=None, **kwargs):
        """Update the meters and log their new values.

        Example:
            >>> meters.update(key, value)
            >>> meters.update({key1: value1, key2: value2})
            >>> meters.update(key1=value1, key2=value2)

        Args:
            updates: either a single meter name (used together with ``value``)
                or a dict mapping meter names to values. A dict passed here is
                never modified.
            value: the value for the single meter named by ``updates``.
            n: passed as ``n`` to each meter's ``update``.
            prefix: optional string prepended to every meter name.
            **kwargs: additional ``name=value`` updates.

        Raises:
            TypeError: if ``value`` is given without a meter name.
        """
        if updates is None:
            if value is not None:
                raise TypeError('TBGroupMeters.update() got a value without a meter name.')
            updates = {}
        elif value is not None:
            updates = {updates: value}
        # Merge into a fresh dict so that a caller-provided dict is not mutated.
        updates = dict(updates, **kwargs)

        for k, v in updates.items():
            if prefix is not None:
                k = prefix + k
            meter = self._meters[k]
            meter.update(v, n=n)
            # NOTE(review): assumes ``val`` is the meter's latest value and
            # ``tot_count`` its running sample count (defined in
            # jacinle.utils.meter); the count is used as the TensorBoard step.
            self._tb_logger.scalar_summary(k, meter.val, meter.tot_count)

    def flush(self):
        """Flush the underlying tensorboard logger."""
        self._tb_logger.flush()