47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
import tensorflow as tf
|
|
from tensorflow.python.ops.rnn_cell import RNNCell
|
|
|
|
from my.tensorflow.nn import linear
|
|
|
|
|
|
class SHCell(RNNCell):
    """Super-Highway Cell.

    A gated recurrent cell. On each step the incoming tensor is split into
    two equal halves ``h`` and ``u`` along axis 1, a sigmoid gate ``a`` is
    computed from ``h``, ``u`` and the previous ``state``, and the state is
    updated as::

        new_state = a * state + (1 - a) * h

    Args:
        input_size: Used as both ``state_size`` and ``output_size``.
        logit_func: How the gate logits are computed. One of
            ``'mul_linear'``, ``'linear'``, ``'tri_linear'`` (default) or
            ``'double'``. Any other value raises ``ValueError`` when the cell
            is called (not at construction time).
        scalar: If True, the gate produced by the single-projection logit
            functions has size 1 instead of ``input_size``.
    """

    # Accepted values of ``logit_func``; used for the error message only.
    _LOGIT_FUNCS = ('mul_linear', 'linear', 'tri_linear', 'double')

    def __init__(self, input_size, logit_func='tri_linear', scalar=False):
        self._state_size = input_size
        self._output_size = input_size
        self._logit_func = logit_func
        self._scalar = scalar

    @property
    def state_size(self):
        """Size of the recurrent state (equal to ``input_size``)."""
        return self._state_size

    @property
    def output_size(self):
        """Size of the per-step output (equal to ``input_size``)."""
        return self._output_size

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
            inputs: Tensor split into two equal halves ``h`` and ``u`` along
                axis 1. Presumably shaped ``(batch, 2 * input_size)`` so that
                ``h`` lines up with ``state`` -- TODO confirm with callers.
            state: Previous state tensor.
            scope: Optional variable scope name; defaults to ``"SHCell"``.

        Returns:
            ``(outputs, new_state)``. ``outputs`` is the *incoming* ``state``,
            and ``new_state = a * state + (1 - a) * h``.

        Raises:
            ValueError: If ``logit_func`` is not one of the accepted values.
        """
        with tf.variable_scope(scope or "SHCell"):
            a_size = 1 if self._scalar else self._state_size
            h, u = tf.split(axis=1, num_or_size_splits=2, value=inputs)

            if self._logit_func == 'double':
                # NOTE(review): assuming linear's 2nd argument is its output
                # size, the hidden layer here is ``a_size`` wide and the gate
                # is ``state_size`` wide. With ``scalar=True`` the gate is
                # therefore not scalar, so the two sizes may be swapped.
                # Confirm before changing, because it alters variable shapes.
                # NOTE(review): both ``linear`` calls run in the same variable
                # scope. Whether they get distinct variables depends on how
                # ``linear`` names them -- confirm.
                hidden = tf.tanh(linear([h, u, state], a_size, True))
                a = tf.nn.sigmoid(linear(hidden, self._state_size, True))
            else:
                # Single-projection variants differ only in which features
                # are fed to ``linear``.
                if self._logit_func == 'mul_linear':
                    args = [h * u, state * u]
                elif self._logit_func == 'linear':
                    args = [h, u, state]
                elif self._logit_func == 'tri_linear':
                    args = [h, u, state, h * u, state * u]
                else:
                    raise ValueError(
                        "Unknown logit_func {!r}; expected one of {}".format(
                            self._logit_func, ", ".join(self._LOGIT_FUNCS)))
                a = tf.nn.sigmoid(linear(args, a_size, True))

            # Highway-style mix: the gate keeps part of the old state and
            # takes the rest from ``h``.
            new_state = a * state + (1 - a) * h
            # NOTE(review): the emitted output is the previous state, not
            # ``new_state``; kept as in the original, presumably intentional.
            outputs = state
            return outputs, new_state
|
|
|