""" Official evaluation script for v1.1 of the SQuAD dataset. """
|
||
|
from __future__ import print_function
|
||
|
from collections import Counter
|
||
|
import string
|
||
|
import re
|
||
|
import argparse
|
||
|
import json
|
||
|
import sys
|
||
|
|
||
|
|
||
|
def normalize_answer(s):
|
||
|
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||
|
def remove_articles(text):
|
||
|
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||
|
|
||
|
def white_space_fix(text):
|
||
|
return ' '.join(text.split())
|
||
|
|
||
|
def remove_punc(text):
|
||
|
exclude = set(string.punctuation)
|
||
|
return ''.join(ch for ch in text if ch not in exclude)
|
||
|
|
||
|
def lower(text):
|
||
|
return text.lower()
|
||
|
|
||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
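
# Illustrative example: normalize_answer("The Cat!") lowercases the text,
# strips punctuation, drops the article "the", and collapses whitespace,
# returning "cat".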


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
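
# Worked example: for the prediction "on the mat" and the ground truth
# "the mat outside", the normalized token lists are ["on", "mat"] and
# ["mat", "outside"]; the single shared token gives precision = recall = 0.5,
# so f1_score returns 0.5.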


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    message = 'Unanswered question ' + qa['id'] + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}
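
# Note on the expected inputs: `dataset` is the list stored under the "data"
# key of the SQuAD JSON (articles -> paragraphs -> "qas" entries), and
# `predictions` maps each qa["id"] to a single predicted answer string.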


if __name__ == '__main__':
    expected_version = '1.1'
    parser = argparse.ArgumentParser(
        description='Evaluation for SQuAD ' + expected_version)
    parser.add_argument('dataset_file', help='Dataset file')
    parser.add_argument('prediction_file', help='Prediction File')
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if (dataset_json['version'] != expected_version):
            print('Evaluation expects v-' + expected_version +
                  ', but got dataset with v-' + dataset_json['version'],
                  file=sys.stderr)
        dataset = dataset_json['data']
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions)))
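
# Example invocation (file names are illustrative):
#   python evaluate-v1.1.py dev-v1.1.json predictions.json
# The prediction file is a JSON object mapping question ids to answer strings;
# the script prints a JSON object of the form
# {"exact_match": <percentage>, "f1": <percentage>} to stdout.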