140 lines
No EOL
4.6 KiB
Python
140 lines
No EOL
4.6 KiB
Python
import shutil
|
|
from collections import OrderedDict
|
|
import http.server
|
|
import socketserver
|
|
import argparse
|
|
import json
|
|
import os
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import pickle
|
|
import gzip
|
|
|
|
from jinja2 import Environment, FileSystemLoader
|
|
|
|
from squad.utils import get_best_span, get_best_span_wy
|
|
|
|
|
|
def bool_(string):
    """Convert the literal strings 'True' / 'False' to a bool.

    Can be used as an argparse ``type=`` converter: argparse catches the
    ValueError raised for any other input and reports it as a usage error.

    Args:
        string: text to convert; must be exactly 'True' or 'False'.

    Returns:
        True for 'True', False for 'False'.

    Raises:
        ValueError: for any other input. ValueError is a subclass of
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    if string == 'True':
        return True
    if string == 'False':
        return False
    raise ValueError("expected 'True' or 'False', got {!r}".format(string))
|
|
|
|
def get_args():
    """Parse the visualizer's command-line options from sys.argv.

    Every option except ``-w/--wy`` takes a value. ``--open`` is kept as the
    raw string ('True' enables opening a browser). ``-w/--wy`` is a flag that
    defaults to False.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # (flag, converter, default) -- declared in the same order they appear in --help
    valued_options = [
        ("--model_name", str, 'basic'),
        ("--data_type", str, 'dev'),
        ("--step", int, 5000),
        ("--template_name", str, "visualizer.html"),
        ("--num_per_page", int, 100),
        ("--data_dir", str, "data/squad"),
        ("--port", int, 8000),
        ("--host", str, "0.0.0.0"),
        ("--open", str, 'False'),
        ("--run_id", str, "0"),
    ]
    for flag, converter, default_value in valued_options:
        parser.add_argument(flag, type=converter, default=default_value)
    parser.add_argument("-w", "--wy", action='store_true')
    return parser.parse_args()
|
|
|
|
|
|
def _decode(decoder, sent):
|
|
return " ".join(decoder[idx] for idx in sent)
|
|
|
|
|
|
def _load_eval(eval_path):
    """Print a progress line, then return the object unpickled from the gzip file at ``eval_path``."""
    print("loading {}".format(eval_path))
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # eval files produced by your own runs.
    with gzip.open(eval_path, 'r') as f:
        return pickle.load(f)


def _load_json(path):
    """Print a progress line, then return the JSON document parsed from ``path``."""
    print("loading {}".format(path))
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _make_html_dir(prefix="/tmp/list_results"):
    """Create and return ``prefix + N`` for the first N (counting from 0) that does not exist yet."""
    suffix = 0
    while True:
        html_dir = "%s%d" % (prefix, suffix)
        try:
            # mkdir either creates the directory or fails atomically, so there
            # is no window between checking and creating (unlike exists()+mkdir).
            os.mkdir(html_dir)
        except FileExistsError:
            suffix += 1
        else:
            return html_dir


def _serve_directory(directory, host, port, open_browser):
    """Serve the files in ``directory`` over HTTP at ``host:port`` until interrupted.

    Changes the process working directory to ``directory``, because
    SimpleHTTPRequestHandler serves from the current directory. If
    ``open_browser`` is true, the URL is opened in the default web browser
    before serving starts. The listening socket is closed when serving stops,
    including on KeyboardInterrupt.
    """
    import webbrowser

    class QuietHandler(http.server.SimpleHTTPRequestHandler):
        # Overriding to suppress the per-request log message.
        def log_message(self, format, *args):
            pass

    class ReusableTCPServer(socketserver.TCPServer):
        # Allows an immediate restart on the same port instead of failing with
        # "Address already in use" while the previous socket is in TIME_WAIT.
        allow_reuse_address = True

    os.chdir(directory)
    httpd = ReusableTCPServer((host, port), QuietHandler)
    try:
        if open_browser:
            # webbrowser avoids building a shell command from the host string.
            webbrowser.open("http://%s:%d" % (host, port))
        print("serving at %s:%d" % (host, port))
        httpd.serve_forever()
    finally:
        httpd.server_close()


def accuracy2_visualizer(args):
    """Render a model's evaluation predictions as paged HTML and serve them.

    Inputs read:
      - ``out/<model_name>/<run_id>/eval/<data_type>-<step>.pklz``, with
        run_id zero-padded to 2 digits and step to 6.
      - ``data_<data_type>.json`` and ``shared_<data_type>.json`` from
        ``args.data_dir``.

    Rendering: ``args.num_per_page`` examples go on each page, using the
    Jinja2 template ``templates/<args.template_name>`` next to this file.

    Output: the pages are written into a fresh ``/tmp/list_resultsN``
    directory. Each page file is named after its first row index,
    zero-padded to 8 digits.

    Serving: that directory is served over HTTP at ``args.host:args.port``.
    The function does not return until the server is interrupted.

    Args:
        args: namespace as produced by ``get_args()``.

    Raises:
        ValueError: if ``args.num_per_page`` is less than 1.
    """
    model_name = args.model_name
    data_type = args.data_type
    num_per_page = args.num_per_page
    data_dir = args.data_dir
    run_id = args.run_id.zfill(2)
    step = args.step
    # Validate up front instead of failing (e.g. with ZeroDivisionError)
    # only after all the inputs have been loaded.
    if num_per_page < 1:
        raise ValueError("num_per_page must be >= 1, got {}".format(num_per_page))

    eval_path = os.path.join("out", model_name, run_id, "eval",
                             "{}-{}.pklz".format(data_type, str(step).zfill(6)))
    eval_ = _load_eval(eval_path)

    html_dir = _make_html_dir()

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    templates_dir = os.path.join(cur_dir, 'templates')
    env = Environment(loader=FileSystemLoader(templates_dir))
    # Make Python's zip/reversed callable from inside the template.
    env.globals.update(zip=zip, reversed=reversed)
    template = env.get_template(args.template_name)

    data = _load_json(os.path.join(data_dir, "data_{}.json".format(data_type)))
    shared = _load_json(os.path.join(data_dir, "shared_{}.json".format(data_type)))

    columns = [eval_[key] for key in ('idxs', 'y', 'yp', 'yp2', 'wyp')]
    # zip() stops at the shortest column, so this is the number of rows
    # actually produced. Using it for both the progress total and the
    # last-page check guarantees the final partial page is always written.
    num_rows = min(len(column) for column in columns)

    rows = []
    for i, (idx, yi, ypi, yp2i, wypi) in tqdm(enumerate(zip(*columns)), total=num_rows):
        id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss'))
        # rx is a pair of indices selecting this example's paragraph in shared['x'].
        x = shared['x'][rx[0]][rx[1]]
        para = [list(sent) for sent in x]
        span, score = get_best_span_wy(wypi, 0.5) if args.wy else get_best_span(ypi, yp2i)
        rows.append({
            'id': id_,
            'title': "Hello world!",
            'ques': [" ".join(q)],
            'para': para,
            'y': yi[0][0],
            'y2': yi[0][1],
            'yp': wypi if args.wy else ypi,
            'yp2': wypi if args.wy else yp2i,
            'a': answers,
            'ap': get_segment(para, span),
            'score': score,
        })

        if len(rows) == num_per_page or i + 1 == num_rows:
            page_start = i + 1 - len(rows)
            html_path = os.path.join(html_dir, "%s.html" % str(page_start).zfill(8))
            var_dict = {'title': "Accuracy Visualization",
                        'rows': rows}
            with open(html_path, "wb") as f:
                f.write(template.render(**var_dict).encode('UTF-8'))
            rows = []

    _serve_directory(html_dir, args.host, args.port, args.open == 'True')
|
|
|
|
|
|
def get_segment(para, span):
    """Return the words of ``para`` covered by ``span``, joined with single spaces.

    Args:
        para: list of sentences, each a sequence of word strings.
        span: ``((start_sent, start_word), (end_sent, end_word))``, where
            ``end_word`` is exclusive within ``end_sent``.

    Returns:
        The covered words as one space-separated string.

    Behavior:
        A span inside one sentence returns that sentence's slice.
        A span that crosses sentences also works: it takes the rest of the
        start sentence, every sentence in between, and the head of the end
        sentence up to ``end_word``.
        If ``end_sent`` is not after ``start_sent``, the start sentence is
        sliced ``[start_word:end_word]``, as before.
    """
    start_sent, start_word = span[0][0], span[0][1]
    end_sent, end_word = span[1][0], span[1][1]
    if end_sent <= start_sent:
        words = para[start_sent][start_word:end_word]
    else:
        words = list(para[start_sent][start_word:])
        for sent in para[start_sent + 1:end_sent]:
            words.extend(sent)
        words.extend(para[end_sent][:end_word])
    return " ".join(words)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then build and serve the pages.
    ARGS = get_args()
    accuracy2_visualizer(ARGS)