# Source: dawn-bench-models/tensorflow/SQuAD/tree/graph_handler.py
#
# Captured from a repository web view (55 lines, 1.9 KiB, Python;
# view options: Raw / Normal View / History).
#
# History timestamp: 2017-08-17 12:43:17 -06:00
import json
from json import encoder
import os
import tensorflow as tf
from tree.evaluator import Evaluation
from my.utils import short_floats
class GraphHandler(object):
    """Manage checkpoint save/restore, summary writing and eval dumps.

    Wraps a ``tf.train.Saver`` created at construction time and, when
    ``config.mode == 'train'``, a ``tf.summary.FileWriter`` created by
    ``initialize``.

    Config attributes read by this class: ``save_dir``, ``model_name``,
    ``load``, ``load_step``, ``mode``, ``log_dir`` and ``eval_dir``.
    """

    def __init__(self, config):
        """Create the saver and derive the checkpoint path prefix.

        Args:
            config: run configuration; this method reads ``save_dir`` and
                ``model_name``.
        """
        self.config = config
        # A Saver with no var_list covers the saveable variables that exist in
        # the default graph when it is built, so presumably the model graph
        # must already be constructed here (TODO confirm at the call site).
        self.saver = tf.train.Saver()
        # Only set by initialize() when config.mode == 'train'.
        self.writer = None
        # Checkpoint prefix: <save_dir>/<model_name>.
        self.save_path = os.path.join(config.save_dir, config.model_name)

    def initialize(self, sess):
        """Restore or initialize variables, then open the summary writer.

        Args:
            sess: the session to restore into or initialize.

        When ``config.load`` is truthy, variables are restored via ``_load``.
        Otherwise every global variable is initialized. A ``FileWriter``
        logging the default graph to ``config.log_dir`` is created only in
        ``'train'`` mode.
        """
        if self.config.load:
            self._load(sess)
        else:
            sess.run(tf.global_variables_initializer())

        if self.config.mode == 'train':
            self.writer = tf.summary.FileWriter(self.config.log_dir, graph=tf.get_default_graph())

    def save(self, sess, global_step=None):
        """Write a checkpoint under the ``<save_dir>/<model_name>`` prefix.

        Args:
            sess: the session whose variables are saved.
            global_step: if given, ``Saver.save`` appends it to the prefix.

        Returns:
            Whatever ``tf.train.Saver.save`` returns (the checkpoint path
            prefix it used).
        """
        return self.saver.save(sess, self.save_path, global_step=global_step)

    def _load(self, sess):
        """Restore variables into ``sess`` from a checkpoint.

        If ``config.load_step > 0``, the checkpoint
        ``<save_dir>/<model_name>-<load_step>`` is restored. This matches the
        naming ``save`` uses when given that ``global_step``. Otherwise the
        checkpoint recorded in ``save_dir``'s checkpoint state is used.

        Raises:
            AssertionError: if no checkpoint state exists in ``save_dir``.
        """
        config = self.config
        if config.load_step > 0:
            save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
        else:
            save_dir = config.save_dir
            checkpoint = tf.train.get_checkpoint_state(save_dir)
            # Raised explicitly rather than via ``assert`` so the check still
            # runs under ``python -O``. The AssertionError type is kept so
            # existing callers see the same exception.
            if checkpoint is None:
                raise AssertionError("cannot load checkpoint at {}".format(save_dir))
            save_path = checkpoint.model_checkpoint_path
        print("Loading saved model from {}".format(save_path))
        self.saver.restore(sess, save_path)

    def add_summary(self, summary, global_step):
        """Append one summary to the event log at ``global_step``.

        Requires ``initialize`` to have run in ``'train'`` mode. Otherwise
        ``self.writer`` is still ``None`` and this raises ``AttributeError``.
        """
        self.writer.add_summary(summary, global_step)

    def add_summaries(self, summaries, global_step):
        """Append each summary in ``summaries`` at the same ``global_step``."""
        for summary in summaries:
            self.add_summary(summary, global_step)

    def dump_eval(self, e, precision=2):
        """Write evaluation ``e`` as JSON to ``<eval_dir>/<data_type>-<step>.json``.

        The global step is zero-padded to at least six digits in the file
        name. ``eval_dir`` must already exist; ``open`` does not create it.

        Args:
            e: an ``Evaluation``; its ``dict`` is passed through
                ``short_floats`` before being written.
            precision: forwarded to ``short_floats`` (presumably the number of
                decimal digits kept; confirm against ``my.utils``).

        Raises:
            AssertionError: if ``e`` is not an ``Evaluation``.
        """
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if not isinstance(e, Evaluation):
            raise AssertionError("expected an Evaluation, got {}".format(type(e).__name__))
        path = os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
        with open(path, 'w') as fh:
            json.dump(short_floats(e.dict, precision), fh)