import re

import numpy as np


def get_2d_spans(text, tokenss):
    """Map each token in a list of tokenized sentences to its (start, stop)
    character span in text. Returns one list of spans per sentence."""
    spanss = []
    cur_idx = 0
    for tokens in tokenss:
        spans = []
        for token in tokens:
            # Locate the token at or after the current character position.
            found_idx = text.find(token, cur_idx)
            if found_idx < 0:
                raise ValueError("token {!r} not found at or after char {} in {!r}".format(token, cur_idx, text))
            cur_idx = found_idx
            spans.append((cur_idx, cur_idx + len(token)))
            cur_idx += len(token)
        spanss.append(spans)
    return spanss
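

# Illustrative usage (example inputs are mine, not from the original module):
#   >>> get_2d_spans("Hello world. Bye.", [["Hello", "world", "."], ["Bye", "."]])
#   [[(0, 5), (6, 11), (11, 12)], [(13, 16), (16, 17)]]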


def get_word_span(context, wordss, start, stop):
    """Convert a character-level span [start, stop) over context into word-level
    indices: a (sent_idx, word_idx) start and an exclusive stop."""
    spanss = get_2d_spans(context, wordss)
    idxs = []
    for sent_idx, spans in enumerate(spanss):
        for word_idx, span in enumerate(spans):
            # Keep every word whose character span overlaps [start, stop).
            if not (stop <= span[0] or start >= span[1]):
                idxs.append((sent_idx, word_idx))

    assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
    return idxs[0], (idxs[-1][0], idxs[-1][1] + 1)
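

# Illustrative usage (made-up inputs): characters [6, 11) cover "world", so the
# word-level span is start (0, 1) with exclusive stop (0, 2).
#   >>> get_word_span("Hello world. Bye.", [["Hello", "world", "."], ["Bye", "."]], 6, 11)
#   ((0, 1), (0, 2))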


def get_phrase(context, wordss, span):
    """
    Obtain the phrase as a substring of context, given a word-level span.
    :param context: raw text
    :param wordss: tokenized sentences (list of lists of words)
    :param span: ((sent_idx, word_idx) start, (sent_idx, word_idx) exclusive stop)
    :return: the matching substring of context
    """
    start, stop = span
    flat_start = get_flat_idx(wordss, start)
    flat_stop = get_flat_idx(wordss, stop)
    words = sum(wordss, [])  # flatten the sentences into one word list
    char_idx = 0
    char_start, char_stop = None, None
    for word_idx, word in enumerate(words):
        char_idx = context.find(word, char_idx)
        assert char_idx >= 0
        if word_idx == flat_start:
            char_start = char_idx
        char_idx += len(word)
        if word_idx == flat_stop - 1:
            char_stop = char_idx
    assert char_start is not None
    assert char_stop is not None
    return context[char_start:char_stop]
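

# Illustrative usage (made-up inputs): recover the text behind word span ((0, 1), (0, 2)).
#   >>> get_phrase("Hello world. Bye.", [["Hello", "world", "."], ["Bye", "."]], ((0, 1), (0, 2)))
#   'world'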


def get_flat_idx(wordss, idx):
    """Convert a (sent_idx, word_idx) pair into a flat word index over all sentences."""
    return sum(len(words) for words in wordss[:idx[0]]) + idx[1]
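

# Illustrative usage (made-up inputs): word (1, 0) comes after the 3 words of sentence 0.
#   >>> get_flat_idx([["Hello", "world", "."], ["Bye", "."]], (1, 0))
#   3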


def get_word_idx(context, wordss, idx):
    """Return the character offset in context at which the word at (sent_idx, word_idx) starts."""
    spanss = get_2d_spans(context, wordss)
    return spanss[idx[0]][idx[1]][0]
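

# Illustrative usage (made-up inputs): "Bye" starts at character 13.
#   >>> get_word_idx("Hello world. Bye.", [["Hello", "world", "."], ["Bye", "."]], (1, 0))
#   13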


def process_tokens(temp_tokens):
    """Split tokens further on dashes, slashes, quotes, etc., keeping each
    delimiter as a token of its own."""
    l = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
    # \u2013 is an en-dash, used for number-to-number ranges.
    # l = ("-", "\u2212", "\u2014", "\u2013")
    # l = ("\u2013",)
    tokens = []
    for token in temp_tokens:
        # The capture group makes re.split keep the delimiters in the output.
        tokens.extend(re.split("([{}])".format("".join(l)), token))
    return tokens
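

# Illustrative usage (made-up input): hyphenated tokens are split and the
# delimiter survives as its own token. Note that re.split with a capture group
# can also emit empty strings when a delimiter sits at a token boundary.
#   >>> process_tokens(["well-known"])
#   ['well', '-', 'known']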


def get_best_span(ypi, yp2i):
    """Return the span maximizing start_prob * stop_prob, where ypi and yp2i hold
    per-sentence start and stop probabilities. Runs in linear time by tracking
    the best start seen so far for each stop position."""
    max_val = 0
    best_word_span = (0, 1)
    best_sent_idx = 0
    for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
        argmax_j1 = 0
        for j in range(len(ypif)):
            # val1 is the best start probability among positions 0..j.
            val1 = ypif[argmax_j1]
            if val1 < ypif[j]:
                val1 = ypif[j]
                argmax_j1 = j

            val2 = yp2if[j]
            if val1 * val2 > max_val:
                best_word_span = (argmax_j1, j)
                best_sent_idx = f
                max_val = val1 * val2
    return ((best_sent_idx, best_word_span[0]), (best_sent_idx, best_word_span[1] + 1)), float(max_val)
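

# Illustrative usage (made-up probabilities): the best product is 0.6 * 0.6,
# pairing the start at word 1 with a stop at word 2 (exclusive stop 3).
#   >>> get_best_span([[0.1, 0.6, 0.3]], [[0.2, 0.2, 0.6]])
#   (((0, 1), (0, 3)), 0.36)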


def get_best_span_wy(wypi, th):
    """Greedily collect contiguous chunks of word probabilities >= th and return
    the chunk with the highest average probability as ((start, stop), score)."""
    chunk_spans = []
    scores = []
    chunk_start = None
    score = 0
    l = 0
    th = min(th, np.max(wypi))  # ensure at least one word clears the threshold
    for f, wypif in enumerate(wypi):
        for j, wypifj in enumerate(wypif):
            if wypifj >= th:
                if chunk_start is None:
                    chunk_start = f, j
                score += wypifj
                l += 1
            else:
                if chunk_start is not None:
                    chunk_stop = f, j
                    chunk_spans.append((chunk_start, chunk_stop))
                    scores.append(score / l)
                    score = 0
                    l = 0
                    chunk_start = None
        # Close any chunk still open at the end of the sentence.
        if chunk_start is not None:
            chunk_stop = f, j + 1
            chunk_spans.append((chunk_start, chunk_stop))
            scores.append(score / l)
            score = 0
            l = 0
            chunk_start = None

    return max(zip(chunk_spans, scores), key=lambda pair: pair[1])
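

# Illustrative usage (made-up probabilities, threshold 0.5): words 1-2 form the
# only above-threshold chunk, scored by their average probability.
#   >>> get_best_span_wy([[0.1, 0.9, 0.8, 0.2]], 0.5)  # -> (((0, 1), (0, 3)), ~0.85)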


def get_span_score_pairs(ypi, yp2i):
    """Exhaustively enumerate all within-sentence spans ((f, j), (f, k + 1)) with
    score ypif[j] * yp2if[k]."""
    span_score_pairs = []
    for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
        for j in range(len(ypif)):
            for k in range(j, len(yp2if)):
                span = ((f, j), (f, k + 1))
                score = ypif[j] * yp2if[k]
                span_score_pairs.append((span, score))
    return span_score_pairs
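

# Illustrative usage (made-up probabilities): every j <= k pair within a sentence.
#   >>> get_span_score_pairs([[0.5, 0.5]], [[0.4, 0.6]])
#   [(((0, 0), (0, 1)), 0.2), (((0, 0), (0, 2)), 0.3), (((0, 1), (0, 2)), 0.3)]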