# dawn-bench-models/tensorflow/SQuAD/read_data.py
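"""Data-loading utilities for SQuAD: a generic batch-yielding DataSet wrapper,
readers for the preprocessed data/shared JSON files, a size-threshold example
filter, and a helper that derives model dimensions from the loaded data."""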
import json
import os
import random
import itertools
import math

import nltk

from my.nltk_utils import load_compressed_tree
from my.utils import index


class DataSet(object):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        total_num_examples = len(next(iter(data.values())))
        self.data = data  # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
        self.data_type = data_type
        self.shared = shared
        self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
        self.num_examples = len(self.valid_idxs)

    def get_batches(self, batch_size, num_batches=None, shuffle=False):
        # Cycle over the valid indices for as many epochs as are needed to
        # serve `num_batches` batches (one full epoch by default).
        num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
        if num_batches is None:
            num_batches = num_batches_per_epoch
        num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))
        idxs = itertools.chain.from_iterable(
            random.sample(self.valid_idxs, len(self.valid_idxs)) if shuffle else self.valid_idxs
            for _ in range(num_epochs))
        for _ in range(num_batches):
            batch_idxs = tuple(itertools.islice(idxs, batch_size))
            batch_data = {}
            for key, val in self.data.items():
                if key.startswith('*'):
                    # Keys prefixed with '*' hold references into `shared`;
                    # dereference them so the batch carries the actual values.
                    assert self.shared is not None
                    shared_key = key[1:]
                    batch_data[shared_key] = [index(self.shared[shared_key], val[idx])
                                              for idx in batch_idxs]
                else:
                    batch_data[key] = list(map(val.__getitem__, batch_idxs))
            batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
            yield batch_idxs, batch_ds
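
# A minimal sketch of how `get_batches` behaves on toy data; the field names
# 'q' and 'y' here are illustrative only, not the real SQuAD schema:
#
#   ds = DataSet({'q': [[1], [2], [3], [4]], 'y': [0, 1, 0, 1]}, 'train')
#   for batch_idxs, batch in ds.get_batches(batch_size=2, shuffle=True):
#       print(batch_idxs, batch.data['q'])  # e.g. (2, 0) [[3], [1]]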


class SquadDataSet(DataSet):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        super(SquadDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs)


def load_metadata(config, data_type):
    metadata_path = os.path.join(config.data_dir, "metadata_{}.json".format(data_type))
    with open(metadata_path, 'r') as fh:
        metadata = json.load(fh)
    # Mirror every metadata field onto the config object so downstream code
    # can read the values as plain attributes.
    for key, val in metadata.items():
        setattr(config, key, val)
    return metadata
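
# Hypothetical usage (the metadata field name is an assumption, not taken from
# this repo's files): if metadata_train.json contains {"num_examples": 87599},
# then after load_metadata(config, 'train'), config.num_examples == 87599.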


def read_data(config, data_type, ref, data_filter=None):
    data_path = os.path.join(config.data_dir, "data_{}.json".format(data_type))
    shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
    with open(data_path, 'r') as fh:
        data = json.load(fh)
    with open(shared_path, 'r') as fh:
        shared = json.load(fh)

    num_examples = len(next(iter(data.values())))
    if data_filter is None:
        valid_idxs = range(num_examples)
    else:
        # Evaluate the filter on each example (a dict of parallel fields) and
        # keep only the indices that pass.
        mask = []
        keys = data.keys()
        values = data.values()
        for vals in zip(*values):
            each = {key: val for key, val in zip(keys, vals)}
            mask.append(data_filter(each, shared))
        valid_idxs = [idx for idx in range(len(mask)) if mask[idx]]

    print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))

    shared_path = os.path.join(config.out_dir, "shared.json")
    if not ref:
        # Build vocabularies from scratch: words and characters at or below
        # their count thresholds are dropped, and ids 0/1 are reserved for the
        # NULL and UNK tokens added below (hence the `idx + 2` offset).
        word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
        char_counter = shared['char_counter']
        pos_counter = shared['pos_counter']
        shared['word2idx'] = {word: idx + 2 for idx, word in
                              enumerate(word for word, count in word_counter.items()
                                        if count > config.word_count_th)}
        shared['char2idx'] = {char: idx + 2 for idx, char in
                              enumerate(char for char, count in char_counter.items()
                                        if count > config.char_count_th)}
        shared['pos2idx'] = {pos: idx + 2 for idx, pos in enumerate(pos_counter.keys())}
        NULL = "-NULL-"
        UNK = "-UNK-"
        shared['word2idx'][NULL] = 0
        shared['word2idx'][UNK] = 1
        shared['char2idx'][NULL] = 0
        shared['char2idx'][UNK] = 1
        shared['pos2idx'][NULL] = 0
        shared['pos2idx'][UNK] = 1
        with open(shared_path, 'w') as fh:
            json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx'],
                       'pos2idx': shared['pos2idx']}, fh)
    else:
        # Reuse the vocabularies saved by a previous (reference) run.
        with open(shared_path, 'r') as fh:
            new_shared = json.load(fh)
        for key, val in new_shared.items():
            shared[key] = val

    data_set = DataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
    return data_set
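
# A minimal sketch of the intended call pattern (attribute names follow how
# this module reads them; `ref=True` reuses the vocab saved by the first call):
#
#   data_filter = get_squad_data_filter(config)
#   train_data = read_data(config, 'train', ref=False, data_filter=data_filter)
#   dev_data = read_data(config, 'dev', ref=True, data_filter=data_filter)
#   update_config(config, [train_data, dev_data])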


def get_squad_data_filter(config):
    def data_filter(data_point, shared):
        assert shared is not None
        rx, rcx, q, cq, y = (data_point[key] for key in ('*x', '*cx', 'q', 'cq', 'y'))
        x, cx, stx = shared['x'], shared['cx'], shared['stx']
        # Drop questions longer than the question-size threshold.
        if len(q) > config.ques_size_th:
            return False
        # Drop paragraphs with too many sentences or any over-long sentence.
        xi = x[rx[0]][rx[1]]
        if len(xi) > config.num_sents_th:
            return False
        if any(len(xij) > config.sent_size_th for xij in xi):
            return False
        # Drop paragraphs whose constituency parse trees are too deep.
        stxi = stx[rx[0]][rx[1]]
        if any(nltk.tree.Tree.fromstring(s).height() > config.tree_height_th for s in stxi):
            return False
        return True
    return data_filter
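
# Illustration of the '*'-prefixed reference scheme used above (toy indices,
# assumed for this example only): data_point['*x'] is a pair such as [0, 2]
# that addresses shared['x'][0][2], the list of tokenized sentences for that
# example's paragraph, so many questions can point at one paragraph without
# copying it.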


def update_config(config, data_sets):
    # Scan every valid example to find the maximum sizes, which the model
    # uses to set its tensor dimensions.
    config.max_num_sents = 0
    config.max_sent_size = 0
    config.max_ques_size = 0
    config.max_word_size = 0
    config.max_tree_height = 0
    for data_set in data_sets:
        data = data_set.data
        shared = data_set.shared
        for idx in data_set.valid_idxs:
            rx = data['*x'][idx]
            q = data['q'][idx]
            sents = shared['x'][rx[0]][rx[1]]
            trees = map(nltk.tree.Tree.fromstring, shared['stx'][rx[0]][rx[1]])
            config.max_tree_height = max(config.max_tree_height, max(tree.height() for tree in trees))
            config.max_num_sents = max(config.max_num_sents, len(sents))
            config.max_sent_size = max(config.max_sent_size, max(map(len, sents)))
            config.max_word_size = max(config.max_word_size, max(len(word) for sent in sents for word in sent))
            if len(q) > 0:
                config.max_ques_size = max(config.max_ques_size, len(q))
                config.max_word_size = max(config.max_word_size, max(len(word) for word in q))
    # Cap the per-word character length and record the vocabulary sizes.
    config.max_word_size = min(config.max_word_size, config.word_size_th)
    config.char_vocab_size = len(data_sets[0].shared['char2idx'])
    config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
    config.word_vocab_size = len(data_sets[0].shared['word2idx'])
    config.pos_vocab_size = len(data_sets[0].shared['pos2idx'])
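

# A self-contained smoke-test sketch. The threshold values below are
# assumptions for illustration (the real repo builds config from flags), and
# the preprocessed data/shared JSON files must already exist under data_dir:
#
#   from argparse import Namespace
#   config = Namespace(data_dir='data/squad', out_dir='out', lower_word=True,
#                      word_count_th=10, char_count_th=50, ques_size_th=30,
#                      num_sents_th=8, sent_size_th=400, tree_height_th=16,
#                      word_size_th=16)
#   train = read_data(config, 'train', ref=False,
#                     data_filter=get_squad_data_filter(config))
#   update_config(config, [train])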