# Source code for jaclearn.embedding.visualize_tb
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : visualize_tb.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/22/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os.path as osp
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
__all__ = ['visualize_word_embedding_tb']
[docs]
def visualize_word_embedding_tb(emb, log_dir):
    """Export a word embedding so it can be browsed in the TensorBoard projector.

    Three artifacts are written into ``log_dir``: a ``metadata.tsv`` file with
    one word per line, the projector configuration plus graph summary, and a
    ``model.ckpt`` checkpoint holding the embedding as a variable named
    ``embedding``.

    Note that this function calls ``tf.reset_default_graph()``, so any graph
    previously built in the default graph is discarded.

    Args:
        emb: either a tuple whose first item is the embedding matrix and whose
            second item is a ``word -> row index`` mapping, or a mapping from
            word to its embedding vector.
        log_dir: directory that receives the output files. It is created if it
            does not exist yet.
    """
    # Function-scope import: the module itself only imports ``os.path``.
    import os

    # https://stackoverflow.com/questions/41258391/tensorboard-embedding-example#answer-42676076
    if isinstance(emb, tuple):  # embedding, word2idx
        word2idx = emb[1]
        # Order the words by row index so that line i of the metadata file
        # labels row i of the matrix.
        words = sorted(word2idx.keys(), key=lambda w: word2idx[w])
        embedding = np.array(emb[0])
    else:
        # Snapshot the keys once so the metadata and the stacked matrix are
        # guaranteed to share the same ordering.
        words = list(emb.keys())
        embedding = np.stack([emb[key] for key in words])

    # Make sure the target directory exists before anything is written to it.
    os.makedirs(log_dir, exist_ok=True)

    # setup a TensorFlow session
    tf.reset_default_graph()
    embedding_var = tf.Variable(embedding, name='embedding')

    # Write explicitly as UTF-8 so non-ASCII words do not depend on the
    # platform's default encoding.
    with open(osp.join(log_dir, 'metadata.tsv'), 'w', encoding='utf-8') as f:
        for w in words:
            f.write(w + '\n')

    # create a TensorFlow summary writer
    summary_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())
    try:
        config = projector.ProjectorConfig()
        embedding_conf = config.embeddings.add()
        embedding_conf.tensor_name = embedding_var.name
        # Relative path: the metadata file sits next to the projector config.
        embedding_conf.metadata_path = 'metadata.tsv'
        projector.visualize_embeddings(summary_writer, config)
    finally:
        # Close the writer so pending events are flushed and the file handle
        # is released even if the projector setup fails.
        summary_writer.close()

    # save the model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(embedding_var.initializer)
        saver.save(sess, osp.join(log_dir, "model.ckpt"))