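"""Augment a SQuAD-style JSON dataset with CoreNLP annotations.

Usage: python <script> <in.json> <out.json> <corenlp_url> <corenlp_port>

The script makes two passes over the data: the first normalizes unicode in
contexts and answers and repairs the answer_start offsets the cleanup breaks;
the second adds sentence splits, tokens, constituency and dependency parses,
and maps each answer's character span to (sentence, word) index spans.
"""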
import json
import sys

from tqdm import tqdm

from my.corenlp_interface import CoreNLPInterface

in_path = sys.argv[1]
out_path = sys.argv[2]
url = sys.argv[3]
port = int(sys.argv[4])
data = json.load(open(in_path, 'r'))
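
# CoreNLPInterface is the repo-local wrapper (my/corenlp_interface.py) around a
# running CoreNLP server; this script relies on its split_doc, split_sent,
# get_const, and get_dep methods.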
h = CoreNLPInterface(url, port)


def find_all(a_str, sub):
    """Yield the start index of every non-overlapping occurrence of sub in a_str."""
    start = 0
    while True:
        start = a_str.find(sub, start)
        if start == -1:
            return
        yield start
        start += len(sub)  # use start += 1 to find overlapping matches


def to_hex(s):
    """Render each character of s as its hex code point (handy for debugging)."""
    return " ".join(map(hex, map(ord, s)))


def handle_nobreak(cand, text):
    """Reconcile two strings that may differ only by non-breaking spaces (U+00A0)."""
    if cand == text:
        return cand
    if cand.replace(u'\u00A0', ' ') == text:
        return cand
    if cand == text.replace(u'\u00A0', ' '):
        return text
    raise Exception("{} '{}' {} '{}'".format(cand, to_hex(cand), text, to_hex(text)))


# First pass: resolve unicode complications. Newlines are stripped from each
# context and non-breaking spaces are normalized to plain spaces; answer_start
# offsets that the cleanup invalidated are then re-derived.
wrong_loc_count = 0
loc_diffs = []

for article in data['data']:
    for para in article['paragraphs']:
        # strip newlines and normalize non-breaking spaces in the context
        para['context'] = para['context'].replace(u'\u000A', '')
        para['context'] = para['context'].replace(u'\u00A0', ' ')
        context = para['context']
        for qa in para['qas']:
            for answer in qa['answers']:
                answer['text'] = answer['text'].replace(u'\u00A0', ' ')
                text = answer['text']
                answer_start = answer['answer_start']
                if context[answer_start:answer_start + len(text)] == text:
                    # offset still valid; just strip leading whitespace if any
                    if text.lstrip() == text:
                        pass
                    else:
                        answer_start += len(text) - len(text.lstrip())
                        answer['answer_start'] = answer_start
                        text = text.lstrip()
                        answer['text'] = text
                else:
                    # the cleanup shifted the offset; relocate the answer
                    wrong_loc_count += 1
                    text = text.lstrip()
                    answer['text'] = text
                    starts = list(find_all(context, text))
                    if len(starts) == 1:
                        answer_start = starts[0]
                    elif len(starts) > 1:
                        # prefer the occurrence closest to the stored offset
                        new_answer_start = min(starts, key=lambda s: abs(s - answer_start))
                        loc_diffs.append(abs(new_answer_start - answer_start))
                        answer_start = new_answer_start
                    else:
                        raise Exception("answer '{}' not found in context".format(text))
                    answer['answer_start'] = answer_start

                answer_stop = answer_start + len(text)
                answer['answer_stop'] = answer_stop
                assert para['context'][answer_start:answer_stop] == answer['text'], "{} {}".format(
                    para['context'][answer_start:answer_stop], answer['text'])

print(wrong_loc_count, loc_diffs)

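# Second pass: annotate paragraphs and questions with CoreNLP output and map
# each answer's character span onto (sentence index, word index) spans.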
mismatch_count = 0
dep_fail_count = 0
no_answer_count = 0

size = sum(len(article['paragraphs']) for article in data['data'])
pbar = tqdm(range(size))

for ai, article in enumerate(data['data']):
    for pi, para in enumerate(article['paragraphs']):
        context = para['context']
        sents = h.split_doc(context)
        words = h.split_sent(context)
        # character offset of each sentence within the context
        sent_starts = []
        ref_idx = 0
        for sent in sents:
            new_idx = context.find(sent, ref_idx)
            sent_starts.append(new_idx)
            ref_idx = new_idx + len(sent)
        para['sents'] = sents
        para['words'] = words
        para['sent_starts'] = sent_starts
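
        # Parse each context sentence once; the answer mapping below reuses
        # these parses.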
        consts = list(map(h.get_const, sents))
        para['consts'] = consts
        deps = list(map(h.get_dep, sents))
        para['deps'] = deps

        for qa in para['qas']:
            question = qa['question']
            question_const = h.get_const(question)
            qa['const'] = question_const
            question_dep = h.get_dep(question)
            qa['dep'] = question_dep
            qa['words'] = h.split_sent(question)

            for answer in qa['answers']:
                answer_start = answer['answer_start']
                text = answer['text']
                answer_stop = answer_start + len(text)
                # answer_words = h.split_sent(text)
                word_idxs = []
                answer_words = []
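                # Walk the dependency nodes of every sentence; any token whose
                # character span overlaps [answer_start, answer_stop) is part
                # of the answer.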
                for sent_idx, (sent, sent_start, dep) in enumerate(zip(sents, sent_starts, deps)):
                    if dep is None:
                        print("dep parse failed at {} {} {}".format(ai, pi, sent_idx))
                        dep_fail_count += 1
                        continue
                    # dep is a (nodes, edges) pair; node[0] is the token text
                    # and node[3] its character offset within the sentence
                    nodes, edges = dep
                    words = [node[0] for node in nodes]

                    for word_idx, (word, _, _, start, _) in enumerate(nodes):
                        global_start = sent_start + start
                        global_stop = global_start + len(word)
                        if answer_start <= global_start < answer_stop or answer_start < global_stop <= answer_stop:
                            word_idxs.append((sent_idx, word_idx))
                            answer_words.append(word)
                if len(word_idxs) > 0:
                    answer['answer_word_start'] = word_idxs[0]
                    answer['answer_word_stop'] = word_idxs[-1][0], word_idxs[-1][1] + 1
                    if not text.startswith(answer_words[0]):
                        print("'{}' '{}'".format(text, ' '.join(answer_words)))
                        mismatch_count += 1
                else:
                    answer['answer_word_start'] = None
                    answer['answer_word_stop'] = None
                    no_answer_count += 1
        pbar.update(1)
pbar.close()

print(mismatch_count, dep_fail_count, no_answer_count)

print("saving...")
json.dump(data, open(out_path, 'w'))