diff --git a/privacy/analysis/tensor_buffer.py b/privacy/analysis/tensor_buffer.py index 520f80e..a0cf665 100644 --- a/privacy/analysis/tensor_buffer.py +++ b/privacy/analysis/tensor_buffer.py @@ -59,9 +59,12 @@ class TensorBuffer(object): name='buffer', use_resource=True) self._current_size = tf.Variable( - initial_value=0, trainable=False, name='current_size') + initial_value=0, dtype=tf.int32, trainable=False, name='current_size') self._capacity = tf.Variable( - initial_value=capacity, trainable=False, name='capacity') + initial_value=capacity, + dtype=tf.int32, + trainable=False, + name='capacity') def append(self, value): """Appends a new tensor to the end of the buffer.