# dawn-bench-models/tensorflow/SQuAD/basic/visualizer.py

# 140 lines
# 4.6 KiB
# Python
# Raw Permalink Normal View History

# 2017-08-17 12:43:17 -06:00
import shutil
from collections import OrderedDict
import http.server
import socketserver
import argparse
import json
import os
import numpy as np
from tqdm import tqdm
import pickle
import gzip
from jinja2 import Environment, FileSystemLoader
from squad.utils import get_best_span, get_best_span_wy
def bool_(string):
    """Convert the exact strings ``'True'`` or ``'False'`` to a ``bool``.

    Args:
        string: Value to convert; the comparison is case-sensitive.

    Returns:
        ``True`` for ``'True'`` and ``False`` for ``'False'``.

    Raises:
        ValueError: If ``string`` is anything else.  ``ValueError`` derives
            from ``Exception``, so callers that catch ``Exception`` keep
            working, while the message now names the rejected value.
    """
    if string == 'True':
        return True
    if string == 'False':
        return False
    raise ValueError("expected 'True' or 'False', got {!r}".format(string))
def get_args():
    """Parse the visualizer's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # (flag, type, default) for each option that takes a value, in the order
    # they are registered.  ``--open`` is kept as the text 'True'/'False'.
    valued_options = (
        ("--model_name", str, 'basic'),
        ("--data_type", str, 'dev'),
        ("--step", int, 5000),
        ("--template_name", str, "visualizer.html"),
        ("--num_per_page", int, 100),
        ("--data_dir", str, "data/squad"),
        ("--port", int, 8000),
        ("--host", str, "0.0.0.0"),
        ("--open", str, 'False'),
        ("--run_id", str, "0"),
    )

    cli = argparse.ArgumentParser()
    for flag, kind, fallback in valued_options:
        cli.add_argument(flag, type=kind, default=fallback)
    # Boolean switch: present -> True, absent -> False.
    cli.add_argument("-w", "--wy", action='store_true')
    return cli.parse_args()
def _decode(decoder, sent):
return " ".join(decoder[idx] for idx in sent)
def _make_html_dir():
    """Create and return the first unused ``/tmp/list_results<N>`` directory."""
    _id = 0
    while True:
        html_dir = "/tmp/list_results%d" % _id
        try:
            # Let mkdir itself report an existing path, so there is no gap
            # between checking for the directory and creating it.
            os.mkdir(html_dir)
        except FileExistsError:
            _id += 1
        else:
            return html_dir


def _load_json(path):
    """Announce and parse the JSON file at ``path``, closing it afterwards."""
    print("loading {}".format(path))
    with open(path, 'r') as json_file:
        return json.load(json_file)


def _serve_directory(directory, host, port, open_browser):
    """Serve ``directory`` over HTTP on ``host:port`` until interrupted."""
    # SimpleHTTPRequestHandler serves the process's current working directory.
    os.chdir(directory)

    # Overriding to suppress log message
    class MyHandler(http.server.SimpleHTTPRequestHandler):
        def log_message(self, format, *args):
            pass

    httpd = socketserver.TCPServer((host, port), MyHandler)
    try:
        if open_browser:
            # Relies on an ``open`` command being available on PATH.
            os.system("open http://%s:%d" % (host, port))
        print("serving at %s:%d" % (host, port))
        httpd.serve_forever()
    finally:
        # Release the listening socket even when serving is interrupted
        # (e.g. by Ctrl-C).
        httpd.server_close()


def accuracy2_visualizer(args):
    """Render evaluation results to paged HTML files and serve them over HTTP.

    Steps:
      1. Load the gzipped pickle
         ``out/<model_name>/<run_id>/eval/<data_type>-<step>.pklz``.
      2. Create a fresh ``/tmp/list_results<N>`` output directory.
      3. Load ``data_<data_type>.json`` and ``shared_<data_type>.json`` from
         ``args.data_dir``.
      4. Build one row per evaluated example and render the rows through the
         Jinja2 template ``args.template_name`` (looked up in the
         ``templates`` directory next to this file), ``args.num_per_page``
         rows per HTML file.
      5. Change into the output directory and serve it on
         ``args.host:args.port``.

    Args:
        args: Namespace providing ``model_name``, ``data_type``,
            ``num_per_page``, ``data_dir``, ``run_id``, ``step``,
            ``template_name``, ``wy``, ``port``, ``host`` and ``open``
            (the string ``'True'`` requests opening the page).

    Returns:
        Nothing; the call blocks in ``serve_forever()`` until interrupted.
    """
    model_name = args.model_name
    data_type = args.data_type
    num_per_page = args.num_per_page
    data_dir = args.data_dir
    run_id = args.run_id.zfill(2)
    step = args.step

    eval_path = os.path.join("out", model_name, run_id, "eval",
                             "{}-{}.pklz".format(data_type, str(step).zfill(6)))
    print("loading {}".format(eval_path))
    # NOTE: unpickling can execute arbitrary code; only load eval files that
    # come from a trusted source.
    with gzip.open(eval_path, 'r') as eval_file:
        eval_ = pickle.load(eval_file)

    html_dir = _make_html_dir()

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    templates_dir = os.path.join(cur_dir, 'templates')
    env = Environment(loader=FileSystemLoader(templates_dir))
    # Expose zip/reversed to the template.
    env.globals.update(zip=zip, reversed=reversed)
    template = env.get_template(args.template_name)

    data = _load_json(os.path.join(data_dir, "data_{}.json".format(data_type)))
    shared = _load_json(os.path.join(data_dir, "shared_{}.json".format(data_type)))

    columns = [eval_[key] for key in ('idxs', 'y', 'yp', 'yp2', 'wyp')]
    # zip() stops at the shortest column, so that length is the real number
    # of rows; using it for the last-page test guarantees the trailing
    # partial page is always written.
    num_examples = min(len(column) for column in columns)

    rows = []
    for i, (idx, yi, ypi, yp2i, wypi) in tqdm(enumerate(zip(*columns)), total=num_examples):
        id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss'))
        # rx is a two-level index into shared['x'].
        x = shared['x'][rx[0]][rx[1]]
        ques = [" ".join(q)]
        para = [list(sent) for sent in x]
        span, score = get_best_span_wy(wypi, 0.5) if args.wy else get_best_span(ypi, yp2i)
        ap = get_segment(para, span)
        row = {
            'id': id_,
            'title': "Hello world!",
            'ques': ques,
            'para': para,
            # presumably the start/stop of the first gold answer -- TODO confirm
            'y': yi[0][0],
            'y2': yi[0][1],
            'yp': wypi if args.wy else ypi,
            'yp2': wypi if args.wy else yp2i,
            'a': answers,
            'ap': ap,
            'score': score,
        }
        rows.append(row)

        # A page file is named after the index of its first row.
        if i % num_per_page == 0:
            html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8))

        # Flush when the page is full or the final example has been added.
        if (i + 1) % num_per_page == 0 or (i + 1) == num_examples:
            var_dict = {'title': "Accuracy Visualization",
                        'rows': rows}
            with open(html_path, "wb") as f:
                f.write(template.render(**var_dict).encode('UTF-8'))
            rows = []

    _serve_directory(html_dir, args.host, args.port, args.open == 'True')
def get_segment(para, span):
    """Return the space-joined words of ``para`` selected by ``span``.

    The sentence is chosen by ``span[0][0]``; the words are sliced from
    ``span[0][1]`` up to (not including) ``span[1][1]``.  ``span[1][0]`` is
    not consulted.
    """
    start, stop = span[0], span[1]
    sentence = para[start[0]]
    return " ".join(sentence[start[1]:stop[1]])
if __name__ == "__main__":
    # Script entry point: parse the command-line flags, then build the HTML
    # pages and serve them.
    ARGS = get_args()
    accuracy2_visualizer(ARGS)