dawn-bench-models/tensorflow/SQuAD/visualization/compare_models.py
Deepak Narayanan b7e1e0fa0f First commit
2017-08-17 11:43:17 -07:00

244 lines
12 KiB
Python

import numpy as np
from collections import Counter
import string
import re
import argparse
import os
import json
import nltk
from matplotlib_venn import venn2
from matplotlib import pyplot as plt
class Question:
    """One question with its reference answers and two models' scored answers.

    Attributes:
        id: Identifier supplied by the caller.
        question_text: The question after ``normalize_answer``.
        question_tokens: ``nltk.word_tokenize`` output for ``question_text``.
        question_head_ngram: Three strings holding the first 0, 1 and 2
            tokens of the question joined by single spaces (index 0 is
            therefore always the empty string).
        ground_truth: Iterable of reference answer strings.
        model_names: Display names of the two models, in model order.
        em: Length-2 array of exact-match scores, one per model.
        f1: Length-2 array of token-level F1 scores, one per model.
        answer_text: The two model answers, in model order.
    """

    # Built once instead of on every normalize_answer() call.
    _ARTICLES = re.compile(r'\b(a|an|the)\b')
    _PUNCTUATION = frozenset(string.punctuation)

    def __init__(self, id, question_text, ground_truth, model_names):
        """Normalize and tokenize the question and zero both models' scores."""
        self.id = id
        self.question_text = self.normalize_answer(question_text)
        self.question_tokens = nltk.word_tokenize(self.question_text)
        # Prefixes of length 0, 1 and 2 so callers can index by n-gram size.
        self.question_head_ngram = [
            ' '.join(self.question_tokens[0:length]) for length in range(3)
        ]
        self.ground_truth = ground_truth
        self.model_names = model_names
        self.em = np.zeros(2)
        self.f1 = np.zeros(2)
        self.answer_text = []

    def add_answers(self, answer_model_1, answer_model_2):
        """Store both models' answers and (re)compute their scores."""
        # Replace rather than append: eval() reads only the first two
        # entries, so appending on a repeated call would silently re-score
        # the answers from the first call.
        self.answer_text = [answer_model_1, answer_model_2]
        self.eval()

    def eval(self):
        """Fill ``em`` and ``f1`` with each model's best score over all references."""
        for model_count in range(2):
            answer = self.answer_text[model_count]
            self.em[model_count] = self.metric_max_over_ground_truths(
                self.exact_match_score, answer, self.ground_truth)
            self.f1[model_count] = self.metric_max_over_ground_truths(
                self.f1_score, answer, self.ground_truth)

    def normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        lowered = s.lower()
        without_punctuation = ''.join(
            ch for ch in lowered if ch not in self._PUNCTUATION)
        without_articles = self._ARTICLES.sub(' ', without_punctuation)
        return ' '.join(without_articles.split())

    def f1_score(self, prediction, ground_truth):
        """Return the token-overlap F1 between two normalized strings.

        Returns 0 when the strings share no tokens.
        """
        prediction_tokens = self.normalize_answer(prediction).split()
        ground_truth_tokens = self.normalize_answer(ground_truth).split()
        # Multiset intersection: a repeated token counts once per matched pair.
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        return (2 * precision * recall) / (precision + recall)

    def exact_match_score(self, prediction, ground_truth):
        """Return True when both strings are equal after normalization."""
        return self.normalize_answer(prediction) == self.normalize_answer(ground_truth)

    def metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths):
        """Return the best ``metric_fn(prediction, truth)`` over all references.

        Returns 0 when ``ground_truths`` is empty instead of letting ``max``
        raise ``ValueError`` on an empty sequence.
        """
        scores_for_ground_truths = [
            metric_fn(prediction, ground_truth) for ground_truth in ground_truths
        ]
        if not scores_for_ground_truths:
            return 0
        return max(scores_for_ground_truths)
def safe_dict_access(in_dict, in_key, default_string='some junk string'):
    """Return ``in_dict[in_key]``, or ``default_string`` when the key is absent."""
    if in_key not in in_dict:
        return default_string
    return in_dict[in_key]
def aggregate_metrics(questions):
    """Print each model's exact-match and F1 percentages over all questions.

    Args:
        questions: Mapping of question id to an object exposing ``em`` and
            ``f1`` (indexable by model position 0 and 1) and ``model_names``.

    An empty mapping prints a short notice and returns; previously it
    divided by zero and then raised ``IndexError`` while looking up the
    model names.
    """
    print('\nAggregate Scores:')
    if not questions:
        print('No questions to score')
        return

    scored = list(questions.values())
    total = len(scored)
    # Every question carries the same pair of names; read them from the first.
    model_names = scored[0].model_names

    for model_count in range(2):
        exact_match = 100.0 * np.sum(np.array([q.em[model_count] for q in scored])) / total
        f1_score = 100.0 * np.sum(np.array([q.f1[model_count] for q in scored])) / total
        print('Model {0} EM = {1:.2f}'.format(model_names[model_count], exact_match))
        print('Model {0} F1 = {1:.2f}'.format(model_names[model_count], f1_score))
def venn_diagram(questions, output_dir):
    """Report and plot how the two models' exact-match answers overlap.

    Prints the size of each group, saves ``venn_diagram.png`` into
    ``output_dir`` and returns the underlying id lists.

    Args:
        questions: Mapping of question id to an object exposing ``em``
            (indexable by model position 0 and 1) and ``model_names``.
        output_dir: Directory the image is written to.

    Returns:
        A 5-tuple of question-id lists: correct for model 1, correct for
        model 2, correct for both, correct for model 1 only, correct for
        model 2 only.
    """
    em_model1_ids = [qid for qid in questions if questions[qid].em[0] == 1]
    em_model2_ids = [qid for qid in questions if questions[qid].em[1] == 1]
    model_names = questions[list(questions.keys())[0]].model_names

    # Build each id set once and derive all three regions from them.
    model1_set = set(em_model1_ids)
    model2_set = set(em_model2_ids)
    correct_model1 = em_model1_ids
    correct_model2 = em_model2_ids
    correct_model1_and_model2 = list(model1_set & model2_set)
    correct_model1_and_not_model2 = list(model1_set - model2_set)
    correct_model2_and_not_model1 = list(model2_set - model1_set)

    print('\nVenn diagram')
    print('{0} answers correctly = {1}'.format(model_names[0], len(correct_model1)))
    print('{0} answers correctly = {1}'.format(model_names[1], len(correct_model2)))
    print('Both answer correctly = {0}'.format(len(correct_model1_and_model2)))
    print('{0} correct & {1} incorrect = {2}'.format(model_names[0], model_names[1], len(correct_model1_and_not_model2)))
    print('{0} correct & {1} incorrect = {2}'.format(model_names[1], model_names[0], len(correct_model2_and_not_model1)))

    plt.clf()
    # NOTE(review): venn2 draws two sets, so the third label ('Both correct')
    # presumably has no effect - confirm against the matplotlib_venn docs.
    venn2(
        subsets=(len(correct_model1_and_not_model2), len(correct_model2_and_not_model1), len(correct_model1_and_model2)),
        set_labels=('{0} correct'.format(model_names[0]), '{0} correct'.format(model_names[1]), 'Both correct'),
        set_colors=('r', 'b'),
        alpha=0.3,
        normalize_to=1
    )
    plt.savefig(os.path.join(output_dir, 'venn_diagram.png'))
    plt.close()
    return correct_model1, correct_model2, correct_model1_and_model2, correct_model1_and_not_model2, correct_model2_and_not_model1
def get_head_ngrams(questions, num_grams):
    """Collect every question's leading n-gram of size ``num_grams``.

    Returns one entry per question, in the mapping's iteration order;
    duplicates are kept.
    """
    return [entry.question_head_ngram[num_grams] for entry in questions.values()]
def get_head_ngram_frequencies(questions, head_ngrams, num_grams):
    """Count how many of ``questions`` start with each n-gram in ``head_ngrams``.

    Every n-gram in ``head_ngrams`` appears in the result, with 0 when no
    question starts with it. A question whose leading n-gram is not listed
    in ``head_ngrams`` raises ``KeyError``.
    """
    frequencies = dict.fromkeys(head_ngrams, 0)
    for entry in questions.values():
        frequencies[entry.question_head_ngram[num_grams]] += 1
    return frequencies
def get_head_ngram_statistics(questions, correct_model1, correct_model2, correct_model1_and_model2, correct_model1_and_not_model2, correct_model2_and_not_model1, output_dir, num_grams=2, top_count=25):
    """Plot per-model exact-match rates for the most frequent question openings.

    Ranks the leading n-grams of all questions by frequency, keeps at most
    ``top_count`` of them, and saves a horizontal bar chart
    (``ngram_stats_<num_grams>.png`` in ``output_dir``) showing, for each
    n-gram, the percentage of its questions each model answered correctly.

    Args:
        questions: Mapping of question id to question object.
        correct_model1: Ids of questions model 1 answered correctly.
        correct_model2: Ids of questions model 2 answered correctly.
        correct_model1_and_model2: Accepted for interface compatibility;
            not used by the plot.
        correct_model1_and_not_model2: Accepted for interface
            compatibility; not used by the plot.
        correct_model2_and_not_model1: Accepted for interface
            compatibility; not used by the plot.
        output_dir: Directory the image is written to.
        num_grams: Index into each question's ``question_head_ngram``.
        top_count: Maximum number of n-grams to plot.
    """
    head_ngrams = get_head_ngrams(questions, num_grams)

    # Head n-gram frequencies (hnf) overall and per model.
    hnf_all = get_head_ngram_frequencies(questions, head_ngrams, num_grams)
    hnf_correct_model1 = get_head_ngram_frequencies({qid: questions[qid] for qid in correct_model1}, head_ngrams, num_grams)
    hnf_correct_model2 = get_head_ngram_frequencies({qid: questions[qid] for qid in correct_model2}, head_ngrams, num_grams)

    sorted_ngrams_all = sorted(hnf_all.items(), key=lambda item: item[1], reverse=True)
    top_ngrams = [ngram for ngram, _ in sorted_ngrams_all[0:top_count]]
    # There may be fewer distinct n-grams than top_count; size everything
    # from what actually exists so bar positions and values stay aligned.
    num_bars = len(top_ngrams)

    counts_total = np.array([hnf_all[ngram] for ngram in top_ngrams], dtype=float)
    counts_model1 = np.array([hnf_correct_model1[ngram] for ngram in top_ngrams], dtype=float)
    counts_model2 = np.array([hnf_correct_model2[ngram] for ngram in top_ngrams], dtype=float)
    top_ngrams_with_counts = [
        '{0} ({1})'.format(ngram, hnf_all[ngram]) for ngram in top_ngrams
    ]

    # No plt.clf() here: plt.subplots() opens its own figure, so clearing
    # first would only create an extra figure that is never closed.
    fig, ax = plt.subplots(figsize=(6, 10))
    ylocs = list(range(num_bars))
    # Most frequent n-gram is drawn at the top; model 2 bars sit just above
    # model 1 bars (offset 0.4 == bar height).
    counts_model1_percent = 100 * counts_model1 / counts_total
    plt.barh([num_bars - x for x in ylocs], counts_model1_percent, height=0.4, alpha=0.5, color='#EE3224', label=top_ngrams)
    counts_model2_percent = 100 * counts_model2 / counts_total
    plt.barh([num_bars - x + 0.4 for x in ylocs], counts_model2_percent, height=0.4, alpha=0.5, color='#2432EE', label=top_ngrams)
    ax.set_yticks([num_bars - x + 0.4 for x in ylocs])
    ax.set_yticklabels(top_ngrams_with_counts)
    ax.set_ylim([0.5, num_bars + 1])
    ax.set_xlim([0, 100])
    plt.subplots_adjust(left=0.28, right=0.9, top=0.9, bottom=0.1)
    plt.xlabel('Percentage of questions with correct answers')
    plt.ylabel('Top N-grams')
    plt.savefig(os.path.join(output_dir, 'ngram_stats_{0}.png'.format(num_grams)))
    plt.close(fig)
def read_json(filename):
    """Load and return the JSON document stored in ``filename``.

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the platform's default text encoding.
    """
    import io  # local import: io.open accepts ``encoding`` on Python 2 and 3

    with io.open(filename, encoding='utf-8') as filepoint:
        return json.load(filepoint)
def compare_models(dataset_file, predictions_m1_file, predictions_m2_file, output_dir, name_m1='Model 1', name_m2='Model 2'):
    """Score two models' predictions on a dataset and write comparison plots.

    Reads the dataset and both prediction files, scores every question,
    prints aggregate metrics, then saves a Venn diagram and head unigram /
    bigram charts into ``output_dir``.

    Args:
        dataset_file: JSON file with a top-level ``data`` list of articles,
            each holding ``paragraphs`` -> ``qas`` -> ``id`` / ``question``
            / ``answers[*].text``.
        predictions_m1_file: JSON file mapping question id to model 1's answer.
        predictions_m2_file: JSON file mapping question id to model 2's answer.
        output_dir: Directory for the images; ``None`` means the current
            directory, and a missing directory is created.
        name_m1: Display name for model 1.
        name_m2: Display name for model 2.
    """
    dataset = read_json(dataset_file)['data']
    predictions_m1 = read_json(predictions_m1_file)
    predictions_m2 = read_json(predictions_m2_file)

    # Read in data
    total = 0
    questions = {}
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                qid = qa['id']
                current_question = Question(
                    id=qid,
                    question_text=qa['question'],
                    ground_truth=[answer['text'] for answer in qa['answers']],
                    model_names=[name_m1, name_m2],
                )
                # A question missing from a prediction file is scored
                # against safe_dict_access's placeholder answer.
                current_question.add_answers(
                    answer_model_1=safe_dict_access(predictions_m1, qid),
                    answer_model_2=safe_dict_access(predictions_m2, qid),
                )
                questions[current_question.id] = current_question
                total += 1
    print('Read in {0} questions'.format(total))
    if not questions:
        # Nothing to score or plot; the reporting steps below need at
        # least one question.
        return

    # Aggregate scores
    aggregate_metrics(questions)

    # Make sure the plots have somewhere to go before drawing them.
    if output_dir is None:
        output_dir = '.'
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Venn diagram
    correct_model1, correct_model2, correct_model1_and_model2, correct_model1_and_not_model2, correct_model2_and_not_model1 = venn_diagram(questions, output_dir=output_dir)

    # Head unigram statistics, then head bigram statistics
    for num_grams in (1, 2):
        get_head_ngram_statistics(questions, correct_model1, correct_model2, correct_model1_and_model2, correct_model1_and_not_model2,
                                  correct_model2_and_not_model1, output_dir, num_grams=num_grams, top_count=10)
if __name__ == '__main__':
    # Command-line entry point: parse the file locations and optional
    # display names, then run the comparison.
    parser = argparse.ArgumentParser(description='Compare two QA models')
    parser.add_argument('-dataset', action='store', dest='dataset', required=True, help='Dataset file')
    parser.add_argument('-model1', action='store', dest='predictions_m1', required=True, help='Prediction file for model 1')
    parser.add_argument('-model2', action='store', dest='predictions_m2', required=True, help='Prediction file for model 2')
    # Each name has its own default, so supplying only one of them is
    # honoured instead of both being silently discarded.
    parser.add_argument('-name1', action='store', dest='name_m1', default='Model 1', help='Name for model 1')
    parser.add_argument('-name2', action='store', dest='name_m2', default='Model 2', help='Name for model 2')
    # Default to the current directory; a None output directory cannot be
    # joined into a file path.
    parser.add_argument('-output', action='store', dest='output_dir', default='.', help='Output directory for visualizations')
    results = parser.parse_args()
    compare_models(
        dataset_file=results.dataset,
        predictions_m1_file=results.predictions_m1,
        predictions_m2_file=results.predictions_m2,
        output_dir=results.output_dir,
        name_m1=results.name_m1,
        name_m2=results.name_m2,
    )