dawn-bench-models/tensorflow/SQuAD/squad/utils.py

147 lines
4.4 KiB
Python
Raw Normal View History

2017-08-17 12:43:17 -06:00
import re
import numpy as np
def get_2d_spans(text, tokenss):
spanss = []
cur_idx = 0
for tokens in tokenss:
spans = []
for token in tokens:
if text.find(token, cur_idx) < 0:
print(tokens)
print("{} {} {}".format(token, cur_idx, text))
raise Exception()
cur_idx = text.find(token, cur_idx)
spans.append((cur_idx, cur_idx + len(token)))
cur_idx += len(token)
spanss.append(spans)
return spanss
def get_word_span(context, wordss, start, stop):
spanss = get_2d_spans(context, wordss)
idxs = []
for sent_idx, spans in enumerate(spanss):
for word_idx, span in enumerate(spans):
if not (stop <= span[0] or start >= span[1]):
idxs.append((sent_idx, word_idx))
assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
return idxs[0], (idxs[-1][0], idxs[-1][1] + 1)
def get_phrase(context, wordss, span):
"""
Obtain phrase as substring of context given start and stop indices in word level
:param context:
:param wordss:
:param start: [sent_idx, word_idx]
:param stop: [sent_idx, word_idx]
:return:
"""
start, stop = span
flat_start = get_flat_idx(wordss, start)
flat_stop = get_flat_idx(wordss, stop)
words = sum(wordss, [])
char_idx = 0
char_start, char_stop = None, None
for word_idx, word in enumerate(words):
char_idx = context.find(word, char_idx)
assert char_idx >= 0
if word_idx == flat_start:
char_start = char_idx
char_idx += len(word)
if word_idx == flat_stop - 1:
char_stop = char_idx
assert char_start is not None
assert char_stop is not None
return context[char_start:char_stop]
def get_flat_idx(wordss, idx):
return sum(len(words) for words in wordss[:idx[0]]) + idx[1]
def get_word_idx(context, wordss, idx):
spanss = get_2d_spans(context, wordss)
return spanss[idx[0]][idx[1]][0]
def process_tokens(temp_tokens):
tokens = []
for token in temp_tokens:
flag = False
l = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
# \u2013 is en-dash. Used for number to nubmer
# l = ("-", "\u2212", "\u2014", "\u2013")
# l = ("\u2013",)
tokens.extend(re.split("([{}])".format("".join(l)), token))
return tokens
def get_best_span(ypi, yp2i):
max_val = 0
best_word_span = (0, 1)
best_sent_idx = 0
for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
argmax_j1 = 0
for j in range(len(ypif)):
val1 = ypif[argmax_j1]
if val1 < ypif[j]:
val1 = ypif[j]
argmax_j1 = j
val2 = yp2if[j]
if val1 * val2 > max_val:
best_word_span = (argmax_j1, j)
best_sent_idx = f
max_val = val1 * val2
return ((best_sent_idx, best_word_span[0]), (best_sent_idx, best_word_span[1] + 1)), float(max_val)
def get_best_span_wy(wypi, th):
chunk_spans = []
scores = []
chunk_start = None
score = 0
l = 0
th = min(th, np.max(wypi))
for f, wypif in enumerate(wypi):
for j, wypifj in enumerate(wypif):
if wypifj >= th:
if chunk_start is None:
chunk_start = f, j
score += wypifj
l += 1
else:
if chunk_start is not None:
chunk_stop = f, j
chunk_spans.append((chunk_start, chunk_stop))
scores.append(score/l)
score = 0
l = 0
chunk_start = None
if chunk_start is not None:
chunk_stop = f, j+1
chunk_spans.append((chunk_start, chunk_stop))
scores.append(score/l)
score = 0
l = 0
chunk_start = None
return max(zip(chunk_spans, scores), key=lambda pair: pair[1])
def get_span_score_pairs(ypi, yp2i):
span_score_pairs = []
for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
for j in range(len(ypif)):
for k in range(j, len(yp2if)):
span = ((f, j), (f, k+1))
score = ypif[j] * yp2if[k]
span_score_pairs.append((span, score))
return span_score_pairs