import argparse
import json
import os
import re
import string
from collections import Counter

import nltk
import numpy as np
from matplotlib import pyplot as plt
from matplotlib_venn import venn2


class Question:
    """A single question with its ground-truth answers and one predicted answer per model."""

    def __init__(self, id, question_text, ground_truth, model_names):
        self.id = id
        self.question_text = self.normalize_answer(question_text)
        self.question_head_ngram = []
        self.question_tokens = nltk.word_tokenize(self.question_text)
        # question_head_ngram[n] holds the first n tokens of the question joined by spaces
        # (index 0 is the empty string, 1 the head unigram, 2 the head bigram).
        for nc in range(3):
            self.question_head_ngram.append(' '.join(self.question_tokens[0:nc]))
        self.ground_truth = ground_truth
        self.model_names = model_names
        self.em = np.zeros(2)
        self.f1 = np.zeros(2)
        self.answer_text = []

    def add_answers(self, answer_model_1, answer_model_2):
        self.answer_text.append(answer_model_1)
        self.answer_text.append(answer_model_2)
        self.eval()

    def eval(self):
        for model_count in range(2):
            self.em[model_count] = self.metric_max_over_ground_truths(
                self.exact_match_score, self.answer_text[model_count], self.ground_truth)
            self.f1[model_count] = self.metric_max_over_ground_truths(
                self.f1_score, self.answer_text[model_count], self.ground_truth)

    def normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def f1_score(self, prediction, ground_truth):
        prediction_tokens = self.normalize_answer(prediction).split()
        ground_truth_tokens = self.normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def exact_match_score(self, prediction, ground_truth):
        return self.normalize_answer(prediction) == self.normalize_answer(ground_truth)

    def metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths):
        scores_for_ground_truths = []
        for ground_truth in ground_truths:
            score = metric_fn(prediction, ground_truth)
            scores_for_ground_truths.append(score)
        if not scores_for_ground_truths:
            # Guard against questions that list no answers.
            return 0
        return max(scores_for_ground_truths)
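# Illustrative example (assumed inputs, not part of the pipeline): for a prediction of
# "The Eiffel Tower." against a ground truth of "Eiffel Tower", normalize_answer() lowercases,
# strips punctuation and the leading article, so both sides become "eiffel tower";
# exact_match_score() is then True and f1_score() is 1.0. For "Eiffel Tower in Paris" against
# the same ground truth, EM is False while F1 = 2 * (2/4) * (2/2) / ((2/4) + (2/2)), about 0.67.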
def safe_dict_access(in_dict, in_key, default_string='some junk string'):
    """Return in_dict[in_key]; fall back to a placeholder that scores 0 EM/F1 if the id is missing."""
    if in_key in in_dict:
        return in_dict[in_key]
    else:
        return default_string


def aggregate_metrics(questions):
    total = len(questions)
    exact_match = np.zeros(2)
    f1_scores = np.zeros(2)
    for mc in range(2):
        exact_match[mc] = 100 * np.sum(np.array([questions[x].em[mc] for x in questions])) / total
        f1_scores[mc] = 100 * np.sum(np.array([questions[x].f1[mc] for x in questions])) / total
    model_names = questions[list(questions.keys())[0]].model_names
    print('\nAggregate Scores:')
    for model_count in range(2):
        print('Model {0} EM = {1:.2f}'.format(model_names[model_count], exact_match[model_count]))
        print('Model {0} F1 = {1:.2f}'.format(model_names[model_count], f1_scores[model_count]))


def venn_diagram(questions, output_dir):
    em_model1_ids = [x for x in questions if questions[x].em[0] == 1]
    em_model2_ids = [x for x in questions if questions[x].em[1] == 1]
    model_names = questions[list(questions.keys())[0]].model_names
    print('\nVenn diagram')
    correct_model1 = em_model1_ids
    correct_model2 = em_model2_ids
    correct_model1_and_model2 = list(set(em_model1_ids).intersection(set(em_model2_ids)))
    correct_model1_and_not_model2 = list(set(em_model1_ids) - set(em_model2_ids))
    correct_model2_and_not_model1 = list(set(em_model2_ids) - set(em_model1_ids))
    print('{0} answers correctly = {1}'.format(model_names[0], len(correct_model1)))
    print('{0} answers correctly = {1}'.format(model_names[1], len(correct_model2)))
    print('Both answer correctly = {0}'.format(len(correct_model1_and_model2)))
    print('{0} correct & {1} incorrect = {2}'.format(
        model_names[0], model_names[1], len(correct_model1_and_not_model2)))
    print('{0} correct & {1} incorrect = {2}'.format(
        model_names[1], model_names[0], len(correct_model2_and_not_model1)))
    plt.clf()
    # venn2 takes the region sizes as (only model 1, only model 2, both) and exactly two set labels.
    venn2(
        subsets=(len(correct_model1_and_not_model2),
                 len(correct_model2_and_not_model1),
                 len(correct_model1_and_model2)),
        set_labels=('{0} correct'.format(model_names[0]),
                    '{0} correct'.format(model_names[1])),
        set_colors=('r', 'b'),
        alpha=0.3,
        normalize_to=1
    )
    plt.savefig(os.path.join(output_dir, 'venn_diagram.png'))
    plt.close()
    return (correct_model1, correct_model2, correct_model1_and_model2,
            correct_model1_and_not_model2, correct_model2_and_not_model1)


def get_head_ngrams(questions, num_grams):
    head_ngrams = []
    for question in questions.values():
        head_ngrams.append(question.question_head_ngram[num_grams])
    return head_ngrams


def get_head_ngram_frequencies(questions, head_ngrams, num_grams):
    head_ngram_frequencies = {}
    for current_ngram in head_ngrams:
        head_ngram_frequencies[current_ngram] = 0
    for question in questions.values():
        head_ngram_frequencies[question.question_head_ngram[num_grams]] += 1
    return head_ngram_frequencies
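# Illustrative example (assumed values): with num_grams=2 the frequency dicts map question-opening
# bigrams to counts, e.g. {'what is': 1203, 'how many': 512, ...}; get_head_ngram_statistics() below
# uses them to plot, for each frequent opening n-gram, the percentage of those questions that each
# model answers with an exact match.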
def get_head_ngram_statistics(questions, correct_model1, correct_model2, correct_model1_and_model2,
                              correct_model1_and_not_model2, correct_model2_and_not_model1,
                              output_dir, num_grams=2, top_count=25):
    # Head ngram statistics
    head_ngrams = get_head_ngrams(questions, num_grams)

    # Get head_ngram_frequencies (hnf) for all questions and for each correctness subset
    hnf_all = get_head_ngram_frequencies(questions, head_ngrams, num_grams)
    hnf_correct_model1 = get_head_ngram_frequencies(
        {qid: questions[qid] for qid in correct_model1}, head_ngrams, num_grams)
    hnf_correct_model2 = get_head_ngram_frequencies(
        {qid: questions[qid] for qid in correct_model2}, head_ngrams, num_grams)
    hnf_correct_model1_and_model2 = get_head_ngram_frequencies(
        {qid: questions[qid] for qid in correct_model1_and_model2}, head_ngrams, num_grams)
    hnf_correct_model1_and_not_model2 = get_head_ngram_frequencies(
        {qid: questions[qid] for qid in correct_model1_and_not_model2}, head_ngrams, num_grams)
    hnf_correct_model2_and_not_model1 = get_head_ngram_frequencies(
        {qid: questions[qid] for qid in correct_model2_and_not_model1}, head_ngrams, num_grams)

    sorted_bigrams_all = sorted(hnf_all.items(), key=lambda x: x[1], reverse=True)
    top_bigrams = [x[0] for x in sorted_bigrams_all[0:top_count]]
    counts_total = [hnf_all[x] for x in top_bigrams]
    counts_model1 = [hnf_correct_model1[x] for x in top_bigrams]
    counts_model2 = [hnf_correct_model2[x] for x in top_bigrams]
    counts_model1_and_model2 = [hnf_correct_model1_and_model2[x] for x in top_bigrams]
    counts_model1_and_not_model2 = [hnf_correct_model1_and_not_model2[x] for x in top_bigrams]
    counts_model2_and_not_model1 = [hnf_correct_model2_and_not_model1[x] for x in top_bigrams]

    top_bigrams_with_counts = []
    for cc in range(len(top_bigrams)):
        top_bigrams_with_counts.append('{0} ({1})'.format(top_bigrams[cc], counts_total[cc]))

    # Horizontal bar chart: per head n-gram, percentage of questions each model answers exactly right
    plt.clf()
    fig, ax = plt.subplots(figsize=(6, 10))
    ylocs = list(range(top_count))
    counts_model1_percent = 100 * np.array(counts_model1) / np.array(counts_total)
    plt.barh([top_count - x for x in ylocs], counts_model1_percent,
             height=0.4, alpha=0.5, color='#EE3224', label=top_bigrams)
    counts_model2_percent = 100 * np.array(counts_model2) / np.array(counts_total)
    plt.barh([top_count - x + 0.4 for x in ylocs], counts_model2_percent,
             height=0.4, alpha=0.5, color='#2432EE', label=top_bigrams)
    ax.set_yticks([top_count - x + 0.4 for x in ylocs])
    ax.set_yticklabels(top_bigrams_with_counts)
    ax.set_ylim([0.5, top_count + 1])
    ax.set_xlim([0, 100])
    plt.subplots_adjust(left=0.28, right=0.9, top=0.9, bottom=0.1)
    plt.xlabel('Percentage of questions with correct answers')
    plt.ylabel('Top N-grams')
    plt.savefig(os.path.join(output_dir, 'ngram_stats_{0}.png'.format(num_grams)))
    plt.close()


def read_json(filename):
    with open(filename) as filepoint:
        data = json.load(filepoint)
    return data


def compare_models(dataset_file, predictions_m1_file, predictions_m2_file, output_dir,
                   name_m1='Model 1', name_m2='Model 2'):
    """Compare two models' predictions on a SQuAD-style dataset and write plots to output_dir."""
    dataset = read_json(dataset_file)['data']
    predictions_m1 = read_json(predictions_m1_file)
    predictions_m2 = read_json(predictions_m2_file)

    # Read in data
    total = 0
    questions = {}
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                current_question = Question(
                    id=qa['id'],
                    question_text=qa['question'],
                    ground_truth=list(map(lambda x: x['text'], qa['answers'])),
                    model_names=[name_m1, name_m2])
                current_question.add_answers(
                    answer_model_1=safe_dict_access(predictions_m1, qa['id']),
                    answer_model_2=safe_dict_access(predictions_m2, qa['id']))
                questions[current_question.id] = current_question
                total += 1
    print('Read in {0} questions'.format(total))

    # Aggregate scores
    aggregate_metrics(questions)

    # Venn diagram
    (correct_model1, correct_model2, correct_model1_and_model2,
     correct_model1_and_not_model2, correct_model2_and_not_model1) = venn_diagram(
        questions, output_dir=output_dir)

    # Head Unigram statistics
    get_head_ngram_statistics(questions, correct_model1, correct_model2, correct_model1_and_model2,
                              correct_model1_and_not_model2, correct_model2_and_not_model1,
                              output_dir, num_grams=1, top_count=10)

    # Head Bigram statistics
    get_head_ngram_statistics(questions, correct_model1, correct_model2, correct_model1_and_model2,
                              correct_model1_and_not_model2, correct_model2_and_not_model1,
                              output_dir, num_grams=2, top_count=10)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compare two QA models')
    parser.add_argument('-dataset', action='store', dest='dataset', required=True,
                        help='Dataset file')
    parser.add_argument('-model1', action='store', dest='predictions_m1', required=True,
                        help='Prediction file for model 1')
    parser.add_argument('-model2', action='store', dest='predictions_m2', required=True,
                        help='Prediction file for model 2')
    parser.add_argument('-name1', action='store', dest='name_m1', help='Name for model 1')
    parser.add_argument('-name2', action='store', dest='name_m2', help='Name for model 2')
    parser.add_argument('-output', action='store', dest='output_dir', default='.',
                        help='Output directory for visualizations (default: current directory)')
    results = parser.parse_args()
    if results.name_m1 is not None and results.name_m2 is not None:
        compare_models(dataset_file=results.dataset,
                       predictions_m1_file=results.predictions_m1,
                       predictions_m2_file=results.predictions_m2,
                       output_dir=results.output_dir,
                       name_m1=results.name_m1,
                       name_m2=results.name_m2)
    else:
        compare_models(dataset_file=results.dataset,
                       predictions_m1_file=results.predictions_m1,
                       predictions_m2_file=results.predictions_m2,
                       output_dir=results.output_dir)
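# Example invocation (hypothetical file and model names; the dataset is expected to be
# SQuAD-v1.1-style JSON and each prediction file a JSON object mapping question id to
# answer string):
#
#   python compare_qa_models.py -dataset dev-v1.1.json \
#       -model1 predictions_model1.json -model2 predictions_model2.json \
#       -name1 ModelA -name2 ModelB -output ./plots
#
# Note: nltk.word_tokenize requires the NLTK 'punkt' tokenizer data
# (e.g. python -m nltk.downloader punkt).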