import nltk
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import BasicLSTMCell

from my.nltk_utils import tree2matrix, find_max_f1_subtree, load_compressed_tree, set_span
from tree.read_data import DataSet
from my.tensorflow import exp_mask, get_initializer
from my.tensorflow.nn import linear
from my.tensorflow.rnn import bidirectional_dynamic_rnn, dynamic_rnn
from my.tensorflow.rnn_cell import SwitchableDropoutWrapper, NoOpCell, TreeRNNCell


class Model(object):
    def __init__(self, config):
        self.config = config
        self.global_step = tf.get_variable('global_step', shape=[], dtype='int32',
                                           initializer=tf.constant_initializer(0), trainable=False)

        # Define forward inputs here
        N, M, JX, JQ, VW, VC, W, H = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, \
            config.max_word_size, config.max_tree_height
        self.x = tf.placeholder('int32', [None, M, JX], name='x')
        self.cx = tf.placeholder('int32', [None, M, JX, W], name='cx')
        self.q = tf.placeholder('int32', [None, JQ], name='q')
        self.cq = tf.placeholder('int32', [None, JQ, W], name='cq')
        self.tx = tf.placeholder('int32', [None, M, H, JX], name='tx')
        self.tx_edge_mask = tf.placeholder('bool', [None, M, H, JX, JX], name='tx_edge_mask')
        self.y = tf.placeholder('bool', [None, M, H, JX], name='y')
        self.is_train = tf.placeholder('bool', [], name='is_train')

        # Define misc

        # Forward outputs / loss inputs
        self.logits = None
        self.yp = None
        self.var_list = None

        # Loss outputs
        self.loss = None

        self._build_forward()
        self._build_loss()

        self.ema_op = self._get_ema_op()
        self.summary = tf.summary.merge_all()

    def _build_forward(self):
        config = self.config
        N, M, JX, JQ, VW, VC, d, dc, W = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
            config.char_emb_size, config.max_word_size
        H = config.max_tree_height

        x_mask = self.x > 0
        q_mask = self.q > 0
        tx_mask = self.tx > 0  # [N, M, H, JX]

        with tf.variable_scope("char_emb"):
            char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float')
            Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx)  # [N, M, JX, W, dc]
            Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq)  # [N, JQ, W, dc]
            filter = tf.get_variable("filter", shape=[1, config.char_filter_height, dc, d], dtype='float')
            bias = tf.get_variable("bias", shape=[d], dtype='float')
            strides = [1, 1, 1, 1]
            Acx = tf.reshape(Acx, [-1, JX, W, dc])
            Acq = tf.reshape(Acq, [-1, JQ, W, dc])
            xxc = tf.nn.conv2d(Acx, filter, strides, "VALID") + bias  # [N*M, JX, W/filter_stride, d]
            qqc = tf.nn.conv2d(Acq, filter, strides, "VALID") + bias  # [N, JQ, W/filter_stride, d]
            xxc = tf.reshape(tf.reduce_max(tf.nn.relu(xxc), 2), [-1, M, JX, d])
            qqc = tf.reshape(tf.reduce_max(tf.nn.relu(qqc), 2), [-1, JQ, d])

        with tf.variable_scope("word_emb"):
            if config.mode == 'train':
                word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, config.word_emb_size],
                                               initializer=get_initializer(config.emb_mat))
            else:
                word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, config.word_emb_size], dtype='float')
            Ax = tf.nn.embedding_lookup(word_emb_mat, self.x)  # [N, M, JX, d]
            Aq = tf.nn.embedding_lookup(word_emb_mat, self.q)  # [N, JQ, d]
            # Ax = linear([Ax], d, False, scope='Ax_reshape')
            # Aq = linear([Aq], d, False, scope='Aq_reshape')

        xx = tf.concat(axis=3, values=[xxc, Ax])  # [N, M, JX, 2d]
        qq = tf.concat(axis=2, values=[qqc, Aq])  # [N, JQ, 2d]
        D = d + config.word_emb_size
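        # xx / qq concatenate the char-CNN features with the word embeddings,
        # so the encoder width below is D = d + word_emb_size. The "pos_emb"
        # scope embeds the label of every parse-tree node fed in through self.tx.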
        with tf.variable_scope("pos_emb"):
            pos_emb_mat = tf.get_variable("pos_emb_mat", shape=[config.pos_vocab_size, d], dtype='float')
            Atx = tf.nn.embedding_lookup(pos_emb_mat, self.tx)  # [N, M, H, JX, d]

        cell = BasicLSTMCell(D, state_is_tuple=True)
        cell = SwitchableDropoutWrapper(cell, self.is_train, input_keep_prob=config.input_keep_prob)
        x_len = tf.reduce_sum(tf.cast(x_mask, 'int32'), 2)  # [N, M]
        q_len = tf.reduce_sum(tf.cast(q_mask, 'int32'), 1)  # [N]

        with tf.variable_scope("rnn"):
            (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell, cell, xx, x_len,
                                                        dtype='float', scope='start')  # [N, M, JX, D] each
            tf.get_variable_scope().reuse_variables()
            (fw_us, bw_us), (_, (fw_u, bw_u)) = bidirectional_dynamic_rnn(cell, cell, qq, q_len,
                                                                          dtype='float', scope='start')  # [N, JQ, D], [N, D]
            u = (fw_u + bw_u) / 2.0
            h = (fw_h + bw_h) / 2.0

        with tf.variable_scope("h"):
            no_op_cell = NoOpCell(D)
            tree_rnn_cell = TreeRNNCell(no_op_cell, d, tf.reduce_max)
            initial_state = tf.reshape(h, [N*M*JX, D])  # [N*M*JX, D]
            inputs = tf.concat(axis=4, values=[Atx, tf.cast(self.tx_edge_mask, 'float')])  # [N, M, H, JX, d+JX]
            inputs = tf.reshape(tf.transpose(inputs, [0, 1, 3, 2, 4]), [N*M*JX, H, d + JX])  # [N*M*JX, H, d+JX]
            length = tf.reshape(tf.reduce_sum(tf.cast(tx_mask, 'int32'), 2), [N*M*JX])
            # length = tf.reshape(tf.reduce_sum(tf.cast(tf.transpose(tx_mask, [0, 1, 3, 2]), 'float'), 3), [-1])
            h, _ = dynamic_rnn(tree_rnn_cell, inputs, length, initial_state=initial_state)  # [N*M*JX, H, D]
            h = tf.transpose(tf.reshape(h, [N, M, JX, H, D]), [0, 1, 3, 2, 4])  # [N, M, H, JX, D]

        u = tf.expand_dims(tf.expand_dims(tf.expand_dims(u, 1), 1), 1)  # [N, 1, 1, 1, D]
        dot = linear(h * u, 1, True, squeeze=True, scope='dot')  # [N, M, H, JX]
        # self.logits = tf.reshape(dot, [N, M * H * JX])
        self.logits = tf.reshape(exp_mask(dot, tx_mask), [N, M * H * JX])  # [N, M * H * JX]
        self.yp = tf.reshape(tf.nn.softmax(self.logits), [N, M, H, JX])

    def _build_loss(self):
        config = self.config
        N, M, JX, JQ, VW, VC = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size
        H = config.max_tree_height
        ce_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits, labels=tf.cast(tf.reshape(self.y, [N, M * H * JX]), 'float')))
        tf.add_to_collection('losses', ce_loss)

        self.loss = tf.add_n(tf.get_collection('losses'), name='loss')
        tf.summary.scalar(self.loss.op.name, self.loss)
        tf.add_to_collection('ema/scalar', self.loss)

    def _get_ema_op(self):
        ema = tf.train.ExponentialMovingAverage(self.config.decay)
        ema_op = ema.apply(tf.get_collection("ema/scalar") + tf.get_collection("ema/histogram"))
        for var in tf.get_collection("ema/scalar"):
            ema_var = ema.average(var)
            tf.summary.scalar(ema_var.op.name, ema_var)
        for var in tf.get_collection("ema/histogram"):
            ema_var = ema.average(var)
            tf.summary.histogram(ema_var.op.name, ema_var)
        return ema_op

    def get_loss(self):
        return self.loss

    def get_global_step(self):
        return self.global_step

    def get_var_list(self):
        return self.var_list

    def get_feed_dict(self, batch, is_train, supervised=True):
        assert isinstance(batch, DataSet)
        config = self.config
        N, M, JX, JQ, VW, VC, d, W, H = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
            config.max_word_size, config.max_tree_height
        feed_dict = {}
        x = np.zeros([N, M, JX], dtype='int32')
        cx = np.zeros([N, M, JX, W], dtype='int32')
        q = np.zeros([N, JQ], dtype='int32')
        cq = np.zeros([N, JQ, W], dtype='int32')
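        # tx holds node label ids and tx_edge_mask the corresponding edge mask,
        # both filled from the serialized trees in batch.data['stx'] via tree2matrix below.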
        tx = np.zeros([N, M, H, JX], dtype='int32')
        tx_edge_mask = np.zeros([N, M, H, JX, JX], dtype='bool')

        feed_dict[self.x] = x
        feed_dict[self.cx] = cx
        feed_dict[self.q] = q
        feed_dict[self.cq] = cq
        feed_dict[self.tx] = tx
        feed_dict[self.tx_edge_mask] = tx_edge_mask
        feed_dict[self.is_train] = is_train

        def _get_word(word):
            d = batch.shared['word2idx']
            for each in (word, word.lower(), word.capitalize(), word.upper()):
                if each in d:
                    return d[each]
            return 1

        def _get_char(char):
            d = batch.shared['char2idx']
            if char in d:
                return d[char]
            return 1

        def _get_pos(tree):
            d = batch.shared['pos2idx']
            if tree.label() in d:
                return d[tree.label()]
            return 1

        for i, xi in enumerate(batch.data['x']):
            for j, xij in enumerate(xi):
                for k, xijk in enumerate(xij):
                    x[i, j, k] = _get_word(xijk)

        for i, cxi in enumerate(batch.data['cx']):
            for j, cxij in enumerate(cxi):
                for k, cxijk in enumerate(cxij):
                    for l, cxijkl in enumerate(cxijk):
                        cx[i, j, k, l] = _get_char(cxijkl)
                        if l + 1 == config.max_word_size:
                            break

        for i, qi in enumerate(batch.data['q']):
            for j, qij in enumerate(qi):
                q[i, j] = _get_word(qij)

        for i, cqi in enumerate(batch.data['cq']):
            for j, cqij in enumerate(cqi):
                for k, cqijk in enumerate(cqij):
                    cq[i, j, k] = _get_char(cqijk)
                    if k + 1 == config.max_word_size:
                        break

        for i, txi in enumerate(batch.data['stx']):
            for j, txij in enumerate(txi):
                txij_mat, txij_mask = tree2matrix(nltk.tree.Tree.fromstring(txij), _get_pos,
                                                  row_size=H, col_size=JX)
                tx[i, j, :, :], tx_edge_mask[i, j, :, :, :] = txij_mat, txij_mask

        if supervised:
            y = np.zeros([N, M, H, JX], dtype='bool')
            feed_dict[self.y] = y
            for i, yi in enumerate(batch.data['y']):
                start_idx, stop_idx = yi
                sent_idx = start_idx[0]
                if start_idx[0] == stop_idx[0]:
                    span = [start_idx[1], stop_idx[1]]
                else:
                    # answer crosses a sentence boundary: clip to the end of the start sentence
                    span = [start_idx[1], len(batch.data['x'][i][sent_idx])]
                tree = nltk.tree.Tree.fromstring(batch.data['stx'][i][sent_idx])
                set_span(tree)
                best_subtree = find_max_f1_subtree(tree, span)

                def _get_y(t):
                    return t == best_subtree

                yij, _ = tree2matrix(tree, _get_y, H, JX, dtype='bool')
                y[i, sent_idx, :, :] = yij

        return feed_dict
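
# A minimal construction sketch, not part of the training pipeline: it only
# shows how the graph might be built standalone. The SimpleNamespace config
# and every value in it are hypothetical stand-ins; in the repo the config
# comes from the CLI flags and batches come from tree.read_data.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(
        batch_size=2, max_num_sents=4, max_sent_size=30, max_ques_size=20,
        word_vocab_size=1000, char_vocab_size=100, max_word_size=16,
        max_tree_height=10, hidden_size=32, char_emb_size=8,
        char_filter_height=5, word_emb_size=50, pos_vocab_size=40,
        input_keep_prob=0.8, decay=0.9,
        mode='test',  # any mode other than 'train' skips the emb_mat initializer
    )
    with tf.Graph().as_default():
        model = Model(config)
        # A real feed would come from model.get_feed_dict(batch, is_train)
        # with a tree.read_data.DataSet batch.
        print(model.logits, model.yp)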