233 lines
9.2 KiB
Python
233 lines
9.2 KiB
Python
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import shutil
|
|
from pprint import pprint
|
|
|
|
import tensorflow as tf
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
from basic.evaluator import ForwardEvaluator, MultiGPUF1Evaluator
|
|
from basic.graph_handler import GraphHandler
|
|
from basic.model import get_multi_gpu_models
|
|
from basic.trainer import MultiGPUTrainer
|
|
from basic.read_data import read_data, get_squad_data_filter, update_config
|
|
from my.tensorflow import get_num_params
|
|
|
|
|
|
def main(config):
|
|
set_dirs(config)
|
|
with tf.device(config.device):
|
|
if config.mode == 'train':
|
|
_train(config)
|
|
elif config.mode == 'test':
|
|
_test(config)
|
|
elif config.mode == 'forward':
|
|
_forward(config)
|
|
else:
|
|
raise ValueError("invalid value for 'mode': {}".format(config.mode))
|
|
|
|
|
|
def set_dirs(config):
|
|
# create directories
|
|
assert config.load or config.mode == 'train', "config.load must be True if not training"
|
|
if not config.load and os.path.exists(config.out_dir):
|
|
shutil.rmtree(config.out_dir)
|
|
|
|
config.save_dir = os.path.join(config.out_dir, "save")
|
|
config.log_dir = os.path.join(config.out_dir, "log")
|
|
config.eval_dir = os.path.join(config.out_dir, "eval")
|
|
config.answer_dir = os.path.join(config.out_dir, "answer")
|
|
if not os.path.exists(config.out_dir):
|
|
os.makedirs(config.out_dir)
|
|
if not os.path.exists(config.save_dir):
|
|
os.mkdir(config.save_dir)
|
|
if not os.path.exists(config.log_dir):
|
|
os.mkdir(config.log_dir)
|
|
if not os.path.exists(config.answer_dir):
|
|
os.mkdir(config.answer_dir)
|
|
if not os.path.exists(config.eval_dir):
|
|
os.mkdir(config.eval_dir)
|
|
|
|
|
|
def _config_debug(config):
|
|
if config.debug:
|
|
config.num_steps = 2
|
|
config.eval_period = 1
|
|
config.log_period = 1
|
|
config.save_period = 1
|
|
config.val_num_batches = 2
|
|
config.test_num_batches = 2
|
|
|
|
|
|
def _train(config):
|
|
data_filter = get_squad_data_filter(config)
|
|
train_data = read_data(config, 'train', config.load, data_filter=data_filter)
|
|
dev_data = read_data(config, 'dev', True, data_filter=data_filter)
|
|
update_config(config, [train_data, dev_data])
|
|
|
|
_config_debug(config)
|
|
|
|
word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
|
|
word2idx_dict = train_data.shared['word2idx']
|
|
idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
|
|
emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
|
|
else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
|
|
for idx in range(config.word_vocab_size)])
|
|
config.emb_mat = emb_mat
|
|
|
|
# construct model graph and variables (using default graph)
|
|
pprint(config.__flags, indent=2)
|
|
models = get_multi_gpu_models(config)
|
|
model = models[0]
|
|
print("num params: {}".format(get_num_params()))
|
|
trainer = MultiGPUTrainer(config, models)
|
|
evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=model.tensor_dict if config.vis else None)
|
|
graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving
|
|
|
|
# Variables
|
|
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
|
graph_handler.initialize(sess)
|
|
|
|
# Begin training
|
|
num_steps = config.num_steps or int(math.ceil(train_data.num_examples / (config.batch_size * config.num_gpus))) * config.num_epochs
|
|
global_step = 0
|
|
for batches in tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus,
|
|
num_steps=num_steps, shuffle=True, cluster=config.cluster), total=num_steps):
|
|
global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step
|
|
get_summary = global_step % config.log_period == 0
|
|
loss, summary, train_op = trainer.step(sess, batches, get_summary=get_summary)
|
|
if get_summary:
|
|
graph_handler.add_summary(summary, global_step)
|
|
|
|
# occasional saving
|
|
if global_step % config.save_period == 0:
|
|
graph_handler.save(sess, global_step=global_step)
|
|
|
|
if not config.eval:
|
|
continue
|
|
# Occasional evaluation
|
|
if global_step % config.eval_period == 0:
|
|
num_steps = math.ceil(dev_data.num_examples / (config.batch_size * config.num_gpus))
|
|
if 0 < config.val_num_batches < num_steps:
|
|
num_steps = config.val_num_batches
|
|
e_train = evaluator.get_evaluation_from_batches(
|
|
sess, tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps)
|
|
)
|
|
graph_handler.add_summaries(e_train.summaries, global_step)
|
|
e_dev = evaluator.get_evaluation_from_batches(
|
|
sess, tqdm(dev_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps))
|
|
graph_handler.add_summaries(e_dev.summaries, global_step)
|
|
|
|
if config.dump_eval:
|
|
graph_handler.dump_eval(e_dev)
|
|
if config.dump_answer:
|
|
graph_handler.dump_answer(e_dev)
|
|
if global_step % config.save_period != 0:
|
|
graph_handler.save(sess, global_step=global_step)
|
|
|
|
|
|
def _test(config):
|
|
test_data = read_data(config, 'test', True)
|
|
update_config(config, [test_data])
|
|
|
|
_config_debug(config)
|
|
|
|
if config.use_glove_for_unk:
|
|
word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
|
|
new_word2idx_dict = test_data.shared['new_word2idx']
|
|
idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
|
|
new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
|
|
config.new_emb_mat = new_emb_mat
|
|
|
|
pprint(config.__flags, indent=2)
|
|
models = get_multi_gpu_models(config)
|
|
model = models[0]
|
|
evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=models[0].tensor_dict if config.vis else None)
|
|
graph_handler = GraphHandler(config, model)
|
|
|
|
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
|
graph_handler.initialize(sess)
|
|
num_steps = math.ceil(test_data.num_examples / (config.batch_size * config.num_gpus))
|
|
if 0 < config.test_num_batches < num_steps:
|
|
num_steps = config.test_num_batches
|
|
|
|
e = None
|
|
for multi_batch in tqdm(test_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps, cluster=config.cluster), total=num_steps):
|
|
ei = evaluator.get_evaluation(sess, multi_batch)
|
|
e = ei if e is None else e + ei
|
|
if config.vis:
|
|
eval_subdir = os.path.join(config.eval_dir, "{}-{}".format(ei.data_type, str(ei.global_step).zfill(6)))
|
|
if not os.path.exists(eval_subdir):
|
|
os.mkdir(eval_subdir)
|
|
path = os.path.join(eval_subdir, str(ei.idxs[0]).zfill(8))
|
|
graph_handler.dump_eval(ei, path=path)
|
|
|
|
print(e)
|
|
if config.dump_answer:
|
|
print("dumping answer ...")
|
|
graph_handler.dump_answer(e)
|
|
if config.dump_eval:
|
|
print("dumping eval ...")
|
|
graph_handler.dump_eval(e)
|
|
|
|
|
|
def _forward(config):
|
|
assert config.load
|
|
test_data = read_data(config, config.forward_name, True)
|
|
update_config(config, [test_data])
|
|
|
|
_config_debug(config)
|
|
|
|
if config.use_glove_for_unk:
|
|
word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
|
|
new_word2idx_dict = test_data.shared['new_word2idx']
|
|
idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
|
|
new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
|
|
config.new_emb_mat = new_emb_mat
|
|
|
|
pprint(config.__flags, indent=2)
|
|
models = get_multi_gpu_models(config)
|
|
model = models[0]
|
|
print("num params: {}".format(get_num_params()))
|
|
evaluator = ForwardEvaluator(config, model)
|
|
graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving
|
|
|
|
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
|
graph_handler.initialize(sess)
|
|
|
|
num_batches = math.ceil(test_data.num_examples / config.batch_size)
|
|
if 0 < config.test_num_batches < num_batches:
|
|
num_batches = config.test_num_batches
|
|
e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
|
|
print(e)
|
|
if config.dump_answer:
|
|
print("dumping answer ...")
|
|
graph_handler.dump_answer(e, path=config.answer_path)
|
|
if config.dump_eval:
|
|
print("dumping eval ...")
|
|
graph_handler.dump_eval(e, path=config.eval_path)
|
|
|
|
|
|
def _get_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("config_path")
|
|
return parser.parse_args()
|
|
|
|
|
|
class Config(object):
|
|
def __init__(self, **entries):
|
|
self.__dict__.update(entries)
|
|
|
|
|
|
def _run():
|
|
args = _get_args()
|
|
with open(args.config_path, 'r') as fh:
|
|
config = Config(**json.load(fh))
|
|
main(config)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
_run()
|