316 lines
13 KiB
Python
316 lines
13 KiB
Python
import json
|
|
import os
|
|
import random
|
|
import itertools
|
|
import math
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
|
|
from my.tensorflow import grouper
|
|
from my.utils import index
|
|
|
|
|
|
class Data(object):
|
|
def get_size(self):
|
|
raise NotImplementedError()
|
|
|
|
def get_by_idxs(self, idxs):
|
|
"""
|
|
Efficient way to obtain a batch of items from filesystem
|
|
:param idxs:
|
|
:return dict: {'X': [,], 'Y', }
|
|
"""
|
|
data = defaultdict(list)
|
|
for idx in idxs:
|
|
each_data = self.get_one(idx)
|
|
for key, val in each_data.items():
|
|
data[key].append(val)
|
|
return data
|
|
|
|
def get_one(self, idx):
|
|
raise NotImplementedError()
|
|
|
|
def get_empty(self):
|
|
raise NotImplementedError()
|
|
|
|
def __add__(self, other):
|
|
raise NotImplementedError()
|
|
|
|
|
|
class DataSet(object):
|
|
def __init__(self, data, data_type, shared=None, valid_idxs=None):
|
|
self.data = data # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
|
|
self.data_type = data_type
|
|
self.shared = shared
|
|
total_num_examples = self.get_data_size()
|
|
self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
|
|
self.num_examples = len(self.valid_idxs)
|
|
|
|
def _sort_key(self, idx):
|
|
rx = self.data['*x'][idx]
|
|
x = self.shared['x'][rx[0]][rx[1]]
|
|
return max(map(len, x))
|
|
|
|
def get_data_size(self):
|
|
if isinstance(self.data, dict):
|
|
return len(next(iter(self.data.values())))
|
|
elif isinstance(self.data, Data):
|
|
return self.data.get_size()
|
|
raise Exception()
|
|
|
|
def get_by_idxs(self, idxs):
|
|
if isinstance(self.data, dict):
|
|
out = defaultdict(list)
|
|
for key, val in self.data.items():
|
|
out[key].extend(val[idx] for idx in idxs)
|
|
return out
|
|
elif isinstance(self.data, Data):
|
|
return self.data.get_by_idxs(idxs)
|
|
raise Exception()
|
|
|
|
def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
|
|
"""
|
|
|
|
:param batch_size:
|
|
:param num_batches:
|
|
:param shuffle:
|
|
:param cluster: cluster examples by their lengths; this might give performance boost (i.e. faster training).
|
|
:return:
|
|
"""
|
|
num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
|
|
if num_batches is None:
|
|
num_batches = num_batches_per_epoch
|
|
num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))
|
|
|
|
if shuffle:
|
|
random_idxs = random.sample(self.valid_idxs, len(self.valid_idxs))
|
|
if cluster:
|
|
sorted_idxs = sorted(random_idxs, key=self._sort_key)
|
|
sorted_grouped = lambda: list(grouper(sorted_idxs, batch_size))
|
|
grouped = lambda: random.sample(sorted_grouped(), num_batches_per_epoch)
|
|
else:
|
|
random_grouped = lambda: list(grouper(random_idxs, batch_size))
|
|
grouped = random_grouped
|
|
else:
|
|
raw_grouped = lambda: list(grouper(self.valid_idxs, batch_size))
|
|
grouped = raw_grouped
|
|
|
|
batch_idx_tuples = itertools.chain.from_iterable(grouped() for _ in range(num_epochs))
|
|
for _ in range(num_batches):
|
|
batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
|
|
batch_data = self.get_by_idxs(batch_idxs)
|
|
shared_batch_data = {}
|
|
for key, val in batch_data.items():
|
|
if key.startswith('*'):
|
|
assert self.shared is not None
|
|
shared_key = key[1:]
|
|
shared_batch_data[shared_key] = [index(self.shared[shared_key], each) for each in val]
|
|
batch_data.update(shared_batch_data)
|
|
|
|
batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
|
|
yield batch_idxs, batch_ds
|
|
|
|
def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
|
|
batch_size_per_step = batch_size * num_batches_per_step
|
|
batches = self.get_batches(batch_size_per_step, num_batches=num_steps, shuffle=shuffle, cluster=cluster)
|
|
multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=num_batches_per_step),
|
|
data_set.divide(num_batches_per_step))) for idxs, data_set in batches)
|
|
return multi_batches
|
|
|
|
def get_empty(self):
|
|
if isinstance(self.data, dict):
|
|
data = {key: [] for key in self.data}
|
|
elif isinstance(self.data, Data):
|
|
data = self.data.get_empty()
|
|
else:
|
|
raise Exception()
|
|
return DataSet(data, self.data_type, shared=self.shared)
|
|
|
|
def __add__(self, other):
|
|
if isinstance(self.data, dict):
|
|
data = {key: val + other.data[key] for key, val in self.data.items()}
|
|
elif isinstance(self.data, Data):
|
|
data = self.data + other.data
|
|
else:
|
|
raise Exception()
|
|
|
|
valid_idxs = list(self.valid_idxs) + [valid_idx + self.num_examples for valid_idx in other.valid_idxs]
|
|
return DataSet(data, self.data_type, shared=self.shared, valid_idxs=valid_idxs)
|
|
|
|
def divide(self, integer):
|
|
batch_size = int(math.ceil(self.num_examples / integer))
|
|
idxs_gen = grouper(self.valid_idxs, batch_size, shorten=True, num_groups=integer)
|
|
data_gen = (self.get_by_idxs(idxs) for idxs in idxs_gen)
|
|
ds_tuple = tuple(DataSet(data, self.data_type, shared=self.shared) for data in data_gen)
|
|
return ds_tuple
|
|
|
|
|
|
def load_metadata(config, data_type):
|
|
metadata_path = os.path.join(config.data_dir, "metadata_{}.json".format(data_type))
|
|
with open(metadata_path, 'r') as fh:
|
|
metadata = json.load(fh)
|
|
for key, val in metadata.items():
|
|
config.__setattr__(key, val)
|
|
return metadata
|
|
|
|
|
|
def read_data(config, data_type, ref, data_filter=None):
|
|
data_path = os.path.join(config.data_dir, "data_{}.json".format(data_type))
|
|
shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
|
|
with open(data_path, 'r') as fh:
|
|
data = json.load(fh)
|
|
with open(shared_path, 'r') as fh:
|
|
shared = json.load(fh)
|
|
|
|
num_examples = len(next(iter(data.values())))
|
|
if data_filter is None:
|
|
valid_idxs = range(num_examples)
|
|
else:
|
|
mask = []
|
|
keys = data.keys()
|
|
values = data.values()
|
|
for vals in zip(*values):
|
|
each = {key: val for key, val in zip(keys, vals)}
|
|
mask.append(data_filter(each, shared))
|
|
valid_idxs = [idx for idx in range(len(mask)) if mask[idx]]
|
|
|
|
print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))
|
|
|
|
shared_path = config.shared_path or os.path.join(config.out_dir, "shared.json")
|
|
if not ref:
|
|
word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
|
|
word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
|
|
char_counter = shared['char_counter']
|
|
if config.finetune:
|
|
shared['word2idx'] = {word: idx + 2 for idx, word in
|
|
enumerate(word for word, count in word_counter.items()
|
|
if count > config.word_count_th or (config.known_if_glove and word in word2vec_dict))}
|
|
else:
|
|
assert config.known_if_glove
|
|
assert config.use_glove_for_unk
|
|
shared['word2idx'] = {word: idx + 2 for idx, word in
|
|
enumerate(word for word, count in word_counter.items()
|
|
if count > config.word_count_th and word not in word2vec_dict)}
|
|
shared['char2idx'] = {char: idx + 2 for idx, char in
|
|
enumerate(char for char, count in char_counter.items()
|
|
if count > config.char_count_th)}
|
|
NULL = "-NULL-"
|
|
UNK = "-UNK-"
|
|
shared['word2idx'][NULL] = 0
|
|
shared['word2idx'][UNK] = 1
|
|
shared['char2idx'][NULL] = 0
|
|
shared['char2idx'][UNK] = 1
|
|
json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx']}, open(shared_path, 'w'))
|
|
else:
|
|
new_shared = json.load(open(shared_path, 'r'))
|
|
for key, val in new_shared.items():
|
|
shared[key] = val
|
|
|
|
if config.use_glove_for_unk:
|
|
# create new word2idx and word2vec
|
|
word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
|
|
new_word2idx_dict = {word: idx for idx, word in enumerate(word for word in word2vec_dict.keys() if word not in shared['word2idx'])}
|
|
shared['new_word2idx'] = new_word2idx_dict
|
|
offset = len(shared['word2idx'])
|
|
word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
|
|
new_word2idx_dict = shared['new_word2idx']
|
|
idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
|
|
# print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
|
|
new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
|
|
shared['new_emb_mat'] = new_emb_mat
|
|
|
|
data_set = DataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
|
|
return data_set
|
|
|
|
|
|
def get_squad_data_filter(config):
|
|
def data_filter(data_point, shared):
|
|
assert shared is not None
|
|
rx, rcx, q, cq, y = (data_point[key] for key in ('*x', '*cx', 'q', 'cq', 'y'))
|
|
x, cx = shared['x'], shared['cx']
|
|
if len(q) > config.ques_size_th:
|
|
return False
|
|
|
|
# x filter
|
|
xi = x[rx[0]][rx[1]]
|
|
if config.squash:
|
|
for start, stop in y:
|
|
stop_offset = sum(map(len, xi[:stop[0]]))
|
|
if stop_offset + stop[1] > config.para_size_th:
|
|
return False
|
|
return True
|
|
|
|
if config.single:
|
|
for start, stop in y:
|
|
if start[0] != stop[0]:
|
|
return False
|
|
|
|
if config.data_filter == 'max':
|
|
for start, stop in y:
|
|
if stop[0] >= config.num_sents_th:
|
|
return False
|
|
if start[0] != stop[0]:
|
|
return False
|
|
if stop[1] >= config.sent_size_th:
|
|
return False
|
|
elif config.data_filter == 'valid':
|
|
if len(xi) > config.num_sents_th:
|
|
return False
|
|
if any(len(xij) > config.sent_size_th for xij in xi):
|
|
return False
|
|
elif config.data_filter == 'semi':
|
|
"""
|
|
Only answer sentence needs to be valid.
|
|
"""
|
|
for start, stop in y:
|
|
if stop[0] >= config.num_sents_th:
|
|
return False
|
|
if start[0] != start[0]:
|
|
return False
|
|
if len(xi[start[0]]) > config.sent_size_th:
|
|
return False
|
|
else:
|
|
raise Exception()
|
|
|
|
return True
|
|
return data_filter
|
|
|
|
|
|
def update_config(config, data_sets):
|
|
config.max_num_sents = 0
|
|
config.max_sent_size = 0
|
|
config.max_ques_size = 0
|
|
config.max_word_size = 0
|
|
config.max_para_size = 0
|
|
for data_set in data_sets:
|
|
data = data_set.data
|
|
shared = data_set.shared
|
|
for idx in data_set.valid_idxs:
|
|
rx = data['*x'][idx]
|
|
q = data['q'][idx]
|
|
sents = shared['x'][rx[0]][rx[1]]
|
|
config.max_para_size = max(config.max_para_size, sum(map(len, sents)))
|
|
config.max_num_sents = max(config.max_num_sents, len(sents))
|
|
config.max_sent_size = max(config.max_sent_size, max(map(len, sents)))
|
|
config.max_word_size = max(config.max_word_size, max(len(word) for sent in sents for word in sent))
|
|
if len(q) > 0:
|
|
config.max_ques_size = max(config.max_ques_size, len(q))
|
|
config.max_word_size = max(config.max_word_size, max(len(word) for word in q))
|
|
|
|
if config.mode == 'train':
|
|
config.max_num_sents = min(config.max_num_sents, config.num_sents_th)
|
|
config.max_sent_size = min(config.max_sent_size, config.sent_size_th)
|
|
config.max_para_size = min(config.max_para_size, config.para_size_th)
|
|
|
|
config.max_word_size = min(config.max_word_size, config.word_size_th)
|
|
|
|
config.char_vocab_size = len(data_sets[0].shared['char2idx'])
|
|
config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
|
|
config.word_vocab_size = len(data_sets[0].shared['word2idx'])
|
|
|
|
if config.single:
|
|
config.max_num_sents = 1
|
|
if config.squash:
|
|
config.max_sent_size = config.max_para_size
|
|
config.max_num_sents = 1
|