import json
import os
import random
import itertools
import math

import nltk

from my.nltk_utils import load_compressed_tree
from my.utils import index


class DataSet(object):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        total_num_examples = len(next(iter(data.values())))
        self.data = data  # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
        self.data_type = data_type
        self.shared = shared
        # Indices of the examples to use; defaults to all examples when no
        # filter has narrowed the set (see read_data / get_squad_data_filter).
        self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
        self.num_examples = len(self.valid_idxs)

    def get_batches(self, batch_size, num_batches=None, shuffle=False):
        num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
        if num_batches is None:
            num_batches = num_batches_per_epoch
        num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))

        # Stream of example indices covering enough epochs for the requested
        # number of batches; each epoch is independently reshuffled if asked.
        idxs = itertools.chain.from_iterable(random.sample(self.valid_idxs, len(self.valid_idxs))
                                             if shuffle else self.valid_idxs
                                             for _ in range(num_epochs))
        for _ in range(num_batches):
            batch_idxs = tuple(itertools.islice(idxs, batch_size))
            batch_data = {}
            for key, val in self.data.items():
                if key.startswith('*'):
                    # '*'-prefixed keys hold index paths into the shared store;
                    # dereference them so the batch carries the actual values.
                    assert self.shared is not None
                    shared_key = key[1:]
                    batch_data[shared_key] = [index(self.shared[shared_key], val[idx]) for idx in batch_idxs]
                else:
                    batch_data[key] = list(map(val.__getitem__, batch_idxs))

            batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
            yield batch_idxs, batch_ds


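# Usage sketch (illustrative, not part of the original module): a DataSet built
# by read_data below is typically consumed as
#
#   for batch_idxs, batch_ds in data_set.get_batches(batch_size, shuffle=True):
#       ...  # batch_ds.data has its '*'-keys already dereferenced against `shared`
#

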
class SquadDataSet(DataSet):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        super(SquadDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs)


def load_metadata(config, data_type):
    metadata_path = os.path.join(config.data_dir, "metadata_{}.json".format(data_type))
    with open(metadata_path, 'r') as fh:
        metadata = json.load(fh)
        for key, val in metadata.items():
            setattr(config, key, val)
        return metadata


def read_data(config, data_type, ref, data_filter=None):
    data_path = os.path.join(config.data_dir, "data_{}.json".format(data_type))
    shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
    with open(data_path, 'r') as fh:
        data = json.load(fh)
    with open(shared_path, 'r') as fh:
        shared = json.load(fh)

    num_examples = len(next(iter(data.values())))
    if data_filter is None:
        valid_idxs = range(num_examples)
    else:
        # Keep only the examples that pass the filter (e.g. get_squad_data_filter).
        mask = []
        keys = data.keys()
        values = data.values()
        for vals in zip(*values):
            each = {key: val for key, val in zip(keys, vals)}
            mask.append(data_filter(each, shared))
        valid_idxs = [idx for idx in range(len(mask)) if mask[idx]]

    print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))

    shared_path = os.path.join(config.out_dir, "shared.json")
    if not ref:
        # Build word/char/POS vocabularies from the counters; indices start at 2
        # so that 0 and 1 stay reserved for NULL and UNK below.
        word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
        char_counter = shared['char_counter']
        pos_counter = shared['pos_counter']
        shared['word2idx'] = {word: idx + 2 for idx, word in
                              enumerate(word for word, count in word_counter.items()
                                        if count > config.word_count_th)}
        shared['char2idx'] = {char: idx + 2 for idx, char in
                              enumerate(char for char, count in char_counter.items()
                                        if count > config.char_count_th)}
        shared['pos2idx'] = {pos: idx + 2 for idx, pos in enumerate(pos_counter.keys())}
        NULL = "-NULL-"
        UNK = "-UNK-"
        shared['word2idx'][NULL] = 0
        shared['word2idx'][UNK] = 1
        shared['char2idx'][NULL] = 0
        shared['char2idx'][UNK] = 1
        shared['pos2idx'][NULL] = 0
        shared['pos2idx'][UNK] = 1
        with open(shared_path, 'w') as fh:
            json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx'],
                       'pos2idx': shared['pos2idx']}, fh)
    else:
        # Reuse the vocabularies saved by a previous (non-ref) run.
        with open(shared_path, 'r') as fh:
            new_shared = json.load(fh)
        for key, val in new_shared.items():
            shared[key] = val

    data_set = DataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
    return data_set


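# Usage sketch (illustrative; the exact `config` fields and call order are
# assumptions, not taken from this file): a typical loading sequence might be
#
#   data_filter = get_squad_data_filter(config)
#   train_data = read_data(config, 'train', ref=False, data_filter=data_filter)
#   dev_data = read_data(config, 'dev', ref=True, data_filter=data_filter)
#   update_config(config, [train_data, dev_data])
#

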
def get_squad_data_filter(config):
    def data_filter(data_point, shared):
        assert shared is not None
        rx, rcx, q, cq, y = (data_point[key] for key in ('*x', '*cx', 'q', 'cq', 'y'))
        x, cx, stx = shared['x'], shared['cx'], shared['stx']
        # Drop examples whose question, passage, sentences, or parse trees
        # exceed the configured size thresholds.
        if len(q) > config.ques_size_th:
            return False
        xi = x[rx[0]][rx[1]]
        if len(xi) > config.num_sents_th:
            return False
        if any(len(xij) > config.sent_size_th for xij in xi):
            return False
        stxi = stx[rx[0]][rx[1]]
        if any(nltk.tree.Tree.fromstring(s).height() > config.tree_height_th for s in stxi):
            return False
        return True
    return data_filter


def update_config(config, data_sets):
    # Derive the maximum sizes across all kept examples in the given data sets.
    config.max_num_sents = 0
    config.max_sent_size = 0
    config.max_ques_size = 0
    config.max_word_size = 0
    config.max_tree_height = 0
    for data_set in data_sets:
        data = data_set.data
        shared = data_set.shared
        for idx in data_set.valid_idxs:
            rx = data['*x'][idx]
            q = data['q'][idx]
            sents = shared['x'][rx[0]][rx[1]]
            trees = map(nltk.tree.Tree.fromstring, shared['stx'][rx[0]][rx[1]])
            config.max_tree_height = max(config.max_tree_height, max(tree.height() for tree in trees))
            config.max_num_sents = max(config.max_num_sents, len(sents))
            config.max_sent_size = max(config.max_sent_size, max(map(len, sents)))
            config.max_word_size = max(config.max_word_size, max(len(word) for sent in sents for word in sent))
            if len(q) > 0:
                config.max_ques_size = max(config.max_ques_size, len(q))
                config.max_word_size = max(config.max_word_size, max(len(word) for word in q))

    config.max_word_size = min(config.max_word_size, config.word_size_th)

    config.char_vocab_size = len(data_sets[0].shared['char2idx'])
    config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
    config.word_vocab_size = len(data_sets[0].shared['word2idx'])
    config.pos_vocab_size = len(data_sets[0].shared['pos2idx'])
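

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original module): build a
    # toy DataSet with plain, non-shared keys and draw shuffled batches. Real data
    # goes through read_data / get_squad_data_filter / update_config above.
    toy = DataSet({'q': [['w1'], ['w2'], ['w3']], 'y': [0, 1, 2]}, 'toy')
    for batch_idxs, batch in toy.get_batches(batch_size=2, num_batches=2, shuffle=True):
        print(batch_idxs, batch.data)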