dawn-bench-models/tensorflow/SQuAD/tree/main.py

188 lines
6.5 KiB
Python
Raw Permalink Normal View History

2017-08-17 12:43:17 -06:00
import argparse
import json
import math
import os
import shutil
from pprint import pprint
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from tree.evaluator import AccuracyEvaluator2, Evaluator
from tree.graph_handler import GraphHandler
from tree.model import Model
from tree.trainer import Trainer
from tree.read_data import load_metadata, read_data, get_squad_data_filter, update_config
def main(config):
    """Prepare the output directories, then run the routine selected by ``config.mode``.

    Supported modes are ``'train'``, ``'test'`` and ``'forward'``.

    Raises:
        ValueError: if ``config.mode`` is none of the supported modes.
    """
    set_dirs(config)
    mode = config.mode
    # Checked in this order; the first matching mode wins.
    dispatch = (('train', _train), ('test', _test), ('forward', _forward))
    for mode_name, routine in dispatch:
        if mode == mode_name:
            routine(config)
            return
    raise ValueError("invalid value for 'mode': {}".format(mode))
def _config_draft(config):
if config.draft:
config.num_steps = 10
config.eval_period = 10
config.log_period = 1
config.save_period = 10
config.eval_num_batches = 1
def _train(config):
    """Train the model and periodically evaluate it on the dev split.

    Steps:
      1. Read the 'train' and 'dev' splits and pass them to ``update_config``.
      2. Apply the draft overrides from ``_config_draft``.
      3. Build ``config.emb_mat``.
      4. Construct the model, trainer, evaluator and graph handler.
      5. Run the training loop.

    Inside the loop:
      - Summaries are written every ``config.log_period`` steps.
      - ``graph_handler.save`` is called every ``config.save_period`` steps.
      - The dev split is evaluated every ``config.eval_period`` steps.
      - Training stops once ``config.early_stop`` consecutive evaluations fail to
        beat the best dev accuracy seen so far.

    After the loop, one more save is made unless the last step already saved.

    Args:
        config: mutable configuration object. Many attributes are read from it;
            ``config.emb_mat`` is set here.
    """
    # Disabled call, kept for reference:
    # load_metadata(config, 'train') # this updates the config file according to metadata file
    data_filter = get_squad_data_filter(config)
    train_data = read_data(config, 'train', config.load, data_filter=data_filter)
    dev_data = read_data(config, 'dev', True, data_filter=data_filter)
    update_config(config, [train_data, dev_data])

    _config_draft(config)

    # Embedding matrix: row ``idx`` is the pretrained vector for word index ``idx``
    # when one exists; otherwise it is a draw from a multivariate normal with zero
    # mean and identity covariance of size config.word_emb_size.
    word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
    word2idx_dict = train_data.shared['word2idx']
    idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
    print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
    # NOTE(review): assumes every pretrained vector has length config.word_emb_size, so that
    # np.array stacks a 2-D (word_vocab_size, word_emb_size) matrix -- TODO confirm.
    emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
                        else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
                        for idx in range(config.word_vocab_size)])
    config.emb_mat = emb_mat

    # construct model graph and variables (using default graph)
    pprint(config.__flags, indent=2)
    model = Model(config)
    trainer = Trainer(config, model)
    evaluator = AccuracyEvaluator2(config, model)
    graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading /saving

    # Variables
    sess = tf.Session()
    graph_handler.initialize(sess)

    # begin training
    # A falsy config.num_steps means: derive the step count from num_epochs passes over the train set.
    num_steps = config.num_steps or int(config.num_epochs * train_data.num_examples / config.batch_size)
    max_acc = 0  # best dev accuracy seen so far
    noupdate_count = 0  # consecutive evaluations that did not improve on max_acc
    global_step = 0
    for _, batch in tqdm(train_data.get_batches(config.batch_size, num_batches=num_steps, shuffle=True), total=num_steps):
        global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step
        get_summary = global_step % config.log_period == 0
        loss, summary, train_op = trainer.step(sess, batch, get_summary=get_summary)
        if get_summary:
            graph_handler.add_summary(summary, global_step)

        # Occasional evaluation and saving
        if global_step % config.save_period == 0:
            graph_handler.save(sess, global_step=global_step)
        if global_step % config.eval_period == 0:
            # Evaluate the whole dev split, capped at config.eval_num_batches batches when that is positive.
            num_batches = math.ceil(dev_data.num_examples / config.batch_size)
            if 0 < config.eval_num_batches < num_batches:
                num_batches = config.eval_num_batches
            e = evaluator.get_evaluation_from_batches(
                sess, tqdm(dev_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
            graph_handler.add_summaries(e.summaries, global_step)
            if e.acc > max_acc:
                max_acc = e.acc
                noupdate_count = 0
            else:
                noupdate_count += 1
                # The counter is >= 1 here, so a non-positive early_stop never stops training.
                # Breaking also skips dump_eval for this last evaluation.
                if noupdate_count == config.early_stop:
                    break
            if config.dump_eval:
                graph_handler.dump_eval(e)
    # Final save, skipped when the last step already saved (or when no step ran, since global_step stays 0).
    if global_step % config.save_period != 0:
        graph_handler.save(sess, global_step=global_step)
def _test(config):
    """Evaluate the model on the 'test' split with ``AccuracyEvaluator2`` and print the result.

    The whole split is evaluated, capped at ``config.eval_num_batches`` batches
    when that value is positive. The evaluation is passed to
    ``graph_handler.dump_eval`` when ``config.dump_eval`` is set.
    """
    split = read_data(config, 'test', True)
    update_config(config, [split])
    _config_draft(config)
    pprint(config.__flags, indent=2)

    model = Model(config)
    evaluator = AccuracyEvaluator2(config, model)
    graph_handler = GraphHandler(config)  # owns all graph tensors/variables, including loading and saving
    sess = tf.Session()
    graph_handler.initialize(sess)

    batch_size = config.batch_size
    num_batches = math.ceil(split.num_examples / batch_size)
    batch_cap = config.eval_num_batches
    if batch_cap > 0:
        # A positive cap only ever shrinks the batch count.
        num_batches = min(num_batches, batch_cap)
    batch_iter = split.get_batches(batch_size, num_batches=num_batches)
    evaluation = evaluator.get_evaluation_from_batches(sess, tqdm(batch_iter, total=num_batches))
    print(evaluation)
    if config.dump_eval:
        graph_handler.dump_eval(evaluation)
def _forward(config):
    """Run the model over the 'forward' split with ``Evaluator`` and print the result.

    Follows the same steps as the 'test' mode:
      - The split is passed to ``update_config`` before the model is built.
      - At most ``config.eval_num_batches`` batches are evaluated when that value
        is positive.
      - The evaluation is passed to ``graph_handler.dump_eval`` when
        ``config.dump_eval`` is set.
    """
    forward_data = read_data(config, 'forward', True)
    # The train and test modes both call update_config right after read_data.
    # Do the same here so the model is built from an equally prepared config.
    update_config(config, [forward_data])
    _config_draft(config)
    # Same attribute the other modes print (`__flags`; `__flag` was a typo).
    pprint(config.__flags, indent=2)
    model = Model(config)
    evaluator = Evaluator(config, model)
    graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading /saving
    sess = tf.Session()
    graph_handler.initialize(sess)
    num_batches = math.ceil(forward_data.num_examples / config.batch_size)
    if 0 < config.eval_num_batches < num_batches:
        num_batches = config.eval_num_batches
    e = evaluator.get_evaluation_from_batches(
        sess, tqdm(forward_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
    print(e)
    if config.dump_eval:
        graph_handler.dump_eval(e)
def set_dirs(config):
    """Set the output sub-directory paths on ``config`` and create any that are missing.

    Unless ``config.load`` is truthy, an existing ``config.out_dir`` is removed
    (recursively) first.

    Sets ``config.save_dir``, ``config.log_dir`` and ``config.eval_dir`` to the
    ``save``, ``log`` and ``eval`` sub-directories of ``config.out_dir``. Each is
    created if it does not already exist.
    """
    # create directories
    if not config.load and os.path.exists(config.out_dir):
        shutil.rmtree(config.out_dir)

    config.save_dir = os.path.join(config.out_dir, "save")
    config.log_dir = os.path.join(config.out_dir, "log")
    config.eval_dir = os.path.join(config.out_dir, "eval")

    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)
    # Check and create each sub-directory independently. Testing one path while
    # creating another left log_dir missing and could raise on an existing eval_dir.
    for sub_dir in (config.save_dir, config.log_dir, config.eval_dir):
        if not os.path.exists(sub_dir):
            os.mkdir(sub_dir)
def _get_args():
parser = argparse.ArgumentParser()
parser.add_argument("config_path")
return parser.parse_args()
class Config(object):
    """Attribute-style container: each constructor keyword becomes an instance attribute."""

    def __init__(self, **entries):
        vars(self).update(entries)
def _run():
    """Load the JSON config named on the command line and pass it to ``main``.

    The JSON object's keys become attributes of a ``Config`` instance.
    """
    args = _get_args()
    # Parse and close the config file before dispatching, so its handle is not
    # held open for the whole (potentially long) run inside main().
    with open(args.config_path, 'r') as fh:
        entries = json.load(fh)
    config = Config(**entries)
    main(config)
if __name__ == "__main__":
    # Script entry point: _run parses the CLI, loads the JSON config and calls main().
    _run()