# dawn-bench-models/tensorflow/SQuAD/tree/trainer.py
#
# 37 lines
# 1.3 KiB
# Python
# Raw Permalink Normal View History
#
# 2017-08-17 12:43:17 -06:00
import tensorflow as tf
from tree.model import Model
class Trainer(object):
    """Build the training op for a ``Model`` and run single training steps.

    The constructor wires an Adagrad optimizer to the model's loss and
    exposes one ``train_op`` that depends on both the gradient update and
    the model's ``ema_op``.
    """

    def __init__(self, config, model):
        """Create the optimizer, gradients and train op for ``model``.

        Args:
            config: object providing ``init_lr``, which is passed to
                ``tf.train.AdagradOptimizer`` as its first argument.
            model: a ``Model`` instance providing ``get_loss()``,
                ``get_var_list()``, ``get_global_step()``, ``ema_op`` and
                ``summary``.

        Raises:
            AssertionError: if ``model`` is not a ``Model`` instance.
        """
        assert isinstance(model, Model), "model must be a Model instance"
        self.config = config
        self.model = model
        self.opt = tf.train.AdagradOptimizer(config.init_lr)
        self.loss = model.get_loss()
        self.var_list = model.get_var_list()
        self.global_step = model.get_global_step()
        self.ema_op = model.ema_op
        self.summary = model.summary
        self.grads = self.opt.compute_gradients(
            self.loss, var_list=self.var_list)
        opt_op = self.opt.apply_gradients(
            self.grads, global_step=self.global_step)

        # Define train op: the group op is created under a control
        # dependency on opt_op, so running train_op runs the optimizer
        # update as well as ema_op.
        with tf.control_dependencies([opt_op]):
            self.train_op = tf.group(self.ema_op)

    def get_train_op(self):
        """Return the train op built in ``__init__``."""
        return self.train_op

    def step(self, sess, batch, get_summary=False):
        """Run one training step on ``batch``.

        Args:
            sess: a ``tf.Session`` used to run the graph.
            batch: passed unchanged to ``model.get_feed_dict``.
            get_summary: when true, ``self.summary`` is fetched in the same
                ``sess.run`` call and returned; otherwise it is not fetched.

        Returns:
            A ``(loss, summary, train_op)`` tuple. ``summary`` is ``None``
            when ``get_summary`` is false. ``train_op`` is whatever
            ``sess.run`` returns for the train op fetch.

        Raises:
            AssertionError: if ``sess`` is not a ``tf.Session``.
        """
        assert isinstance(sess, tf.Session), "sess must be a tf.Session"
        # Second argument is passed as True -- presumably the "training"
        # flag; confirm against Model.get_feed_dict.
        feed_dict = self.model.get_feed_dict(batch, True)
        if get_summary:
            loss, summary, train_op = sess.run(
                [self.loss, self.summary, self.train_op],
                feed_dict=feed_dict)
        else:
            loss, train_op = sess.run(
                [self.loss, self.train_op], feed_dict=feed_dict)
            summary = None
        return loss, summary, train_op