From 28639ba0a8680b508778cc259817250f50d07f69 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Fri, 3 May 2019 16:29:40 -0700 Subject: [PATCH] Allow tensor buffers to automatically resize as needed. PiperOrigin-RevId: 246594454 --- privacy/analysis/tensor_buffer.py | 96 +++++++++++++------ ...er_test.py => tensor_buffer_test_eager.py} | 31 ++++-- privacy/analysis/tensor_buffer_test_graph.py | 72 ++++++++++++++ 3 files changed, 161 insertions(+), 38 deletions(-) rename privacy/analysis/{tensor_buffer_test.py => tensor_buffer_test_eager.py} (66%) create mode 100644 privacy/analysis/tensor_buffer_test_graph.py diff --git a/privacy/analysis/tensor_buffer.py b/privacy/analysis/tensor_buffer.py index 1b2341d..d5965ea 100644 --- a/privacy/analysis/tensor_buffer.py +++ b/privacy/analysis/tensor_buffer.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A lightweight fixed-sized buffer for maintaining lists. -""" +"""A lightweight buffer for maintaining tensors.""" from __future__ import absolute_import from __future__ import division @@ -22,7 +21,7 @@ import tensorflow as tf class TensorBuffer(object): - """A lightweight fixed-sized buffer for maintaining lists. + """A lightweight buffer for maintaining lists. The TensorBuffer accumulates tensors of the given shape into a tensor (whose rank is one more than that of the given shape) via calls to `append`. The @@ -30,12 +29,12 @@ class TensorBuffer(object): `values`. """ - def __init__(self, max_size, shape, dtype=tf.int32, name=None): + def __init__(self, capacity, shape, dtype=tf.int32, name=None): """Initializes the TensorBuffer. Args: - max_size: The maximum size. Attempts to append more than this many rows - will fail with an exception. + capacity: Initial capacity. Buffer will double in capacity each time it is + filled to capacity. shape: The shape (as tuple or list) of the tensors to accumulate. dtype: The type of the tensors. name: A string name for the variable_scope used. @@ -45,19 +44,24 @@ class TensorBuffer(object): """ shape = list(shape) self._rank = len(shape) + self._name = name + self._dtype = dtype if not self._rank: raise ValueError('Shape cannot be scalar.') - shape = [max_size] + shape + shape = [capacity] + shape - with tf.variable_scope(name): + with tf.variable_scope(self._name): + # We need to use a placeholder as the initial value to allow resizing. self._buffer = tf.Variable( - initial_value=tf.zeros(shape, dtype), + initial_value=tf.placeholder_with_default( + tf.zeros(shape, dtype), shape=None), trainable=False, - name='buffer') - self._size = tf.Variable( - initial_value=0, - trainable=False, - name='size') + name='buffer', + use_resource=True) + self._current_size = tf.Variable( + initial_value=0, trainable=False, name='current_size') + self._capacity = tf.Variable( + initial_value=capacity, trainable=False, name='capacity') def append(self, value): """Appends a new tensor to the end of the buffer. @@ -69,23 +73,59 @@ class TensorBuffer(object): Returns: An op appending the new tensor to the end of the buffer. """ - with tf.control_dependencies([ - tf.assert_less( - self._size, - tf.shape(self._buffer)[0], - message='Appending past end of TensorBuffer.'), - tf.assert_equal( - tf.shape(value), - tf.shape(self._buffer)[1:], - message='Appending value of inconsistent shape.')]): - with tf.control_dependencies( - [tf.assign(self._buffer[self._size, :], value)]): - return tf.assign_add(self._size, 1) + + def _double_capacity(): + """Doubles the capacity of the current tensor buffer.""" + padding = tf.zeros_like(self._buffer, self._buffer.dtype) + new_buffer = tf.concat([self._buffer, padding], axis=0) + if tf.executing_eagerly(): + with tf.variable_scope(self._name, reuse=True): + self._buffer = tf.get_variable( + name='buffer', + dtype=self._dtype, + initializer=new_buffer, + trainable=False) + return self._buffer, tf.assign(self._capacity, + tf.multiply(self._capacity, 2)) + else: + return tf.assign( + self._buffer, new_buffer, + validate_shape=False), tf.assign(self._capacity, + tf.multiply(self._capacity, 2)) + + update_buffer, update_capacity = tf.cond( + tf.equal(self._current_size, self._capacity), + _double_capacity, lambda: (self._buffer, self._capacity)) + + with tf.control_dependencies([update_buffer, update_capacity]): + with tf.control_dependencies([ + tf.assert_less( + self._current_size, + self._capacity, + message='Appending past end of TensorBuffer.'), + tf.assert_equal( + tf.shape(value), + tf.shape(self._buffer)[1:], + message='Appending value of inconsistent shape.') + ]): + with tf.control_dependencies( + [tf.assign(self._buffer[self._current_size, :], value)]): + return tf.assign_add(self._current_size, 1) @property def values(self): """Returns the accumulated tensor.""" begin_value = tf.zeros([self._rank + 1], dtype=tf.int32) - value_size = tf.concat( - [[self._size], tf.constant(-1, tf.int32, [self._rank])], 0) + value_size = tf.concat([[self._current_size], + tf.constant(-1, tf.int32, [self._rank])], 0) return tf.slice(self._buffer, begin_value, value_size) + + @property + def current_size(self): + """Returns the current number of tensors in the buffer.""" + return self._current_size + + @property + def capacity(self): + """Returns the current capacity of the buffer.""" + return self._capacity diff --git a/privacy/analysis/tensor_buffer_test.py b/privacy/analysis/tensor_buffer_test_eager.py similarity index 66% rename from privacy/analysis/tensor_buffer_test.py rename to privacy/analysis/tensor_buffer_test_eager.py index 31acb5f..c5a4900 100644 --- a/privacy/analysis/tensor_buffer_test.py +++ b/privacy/analysis/tensor_buffer_test_eager.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for tensor_buffer.""" +"""Tests for tensor_buffer in eager mode.""" from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ tf.enable_eager_execution() class TensorBufferTest(tf.test.TestCase): + """Tests for TensorBuffer in eager mode.""" def test_basic(self): size, shape = 2, [2, 3] @@ -53,20 +54,30 @@ class TensorBufferTest(tf.test.TestCase): 'Appending value of inconsistent shape.'): my_buffer.append(tf.ones(shape=[3, 4], dtype=tf.int32)) - def test_fail_on_overflow(self): + def test_resize(self): size, shape = 2, [2, 3] my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') - # First two should succeed. - my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) - my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) + # Append three buffers. Third one should succeed after resizing. + value1 = [[1, 2, 3], [4, 5, 6]] + my_buffer.append(value1) + self.assertAllEqual(my_buffer.values.numpy(), [value1]) + self.assertAllEqual(my_buffer.current_size.numpy(), 1) + self.assertAllEqual(my_buffer.capacity.numpy(), 2) - # Third one should fail. - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - 'Appending past end of TensorBuffer.'): - my_buffer.append(tf.ones(shape=shape, dtype=tf.int32)) + value2 = [[4, 5, 6], [7, 8, 9]] + my_buffer.append(value2) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2]) + self.assertAllEqual(my_buffer.current_size.numpy(), 2) + self.assertAllEqual(my_buffer.capacity.numpy(), 2) + + value3 = [[7, 8, 9], [10, 11, 12]] + my_buffer.append(value3) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2, value3]) + self.assertAllEqual(my_buffer.current_size.numpy(), 3) + # Capacity should have doubled. + self.assertAllEqual(my_buffer.capacity.numpy(), 4) if __name__ == '__main__': diff --git a/privacy/analysis/tensor_buffer_test_graph.py b/privacy/analysis/tensor_buffer_test_graph.py new file mode 100644 index 0000000..b68fe53 --- /dev/null +++ b/privacy/analysis/tensor_buffer_test_graph.py @@ -0,0 +1,72 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tensor_buffer in graph mode.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from privacy.analysis import tensor_buffer + + +class TensorBufferTest(tf.test.TestCase): + """Tests for TensorBuffer in graph mode.""" + + def test_noresize(self): + """Test buffer does not resize if capacity is not exceeded.""" + with self.cached_session() as sess: + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + value1 = [[1, 2, 3], [4, 5, 6]] + with tf.control_dependencies([my_buffer.append(value1)]): + value2 = [[7, 8, 9], [10, 11, 12]] + with tf.control_dependencies([my_buffer.append(value2)]): + values = my_buffer.values + current_size = my_buffer.current_size + capacity = my_buffer.capacity + self.evaluate(tf.global_variables_initializer()) + + v, cs, cap = sess.run([values, current_size, capacity]) + self.assertAllEqual(v, [value1, value2]) + self.assertEqual(cs, 2) + self.assertEqual(cap, 2) + + def test_resize(self): + """Test buffer resizes if capacity is exceeded.""" + with self.cached_session() as sess: + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + value1 = [[1, 2, 3], [4, 5, 6]] + with tf.control_dependencies([my_buffer.append(value1)]): + value2 = [[7, 8, 9], [10, 11, 12]] + with tf.control_dependencies([my_buffer.append(value2)]): + value3 = [[13, 14, 15], [16, 17, 18]] + with tf.control_dependencies([my_buffer.append(value3)]): + values = my_buffer.values + current_size = my_buffer.current_size + capacity = my_buffer.capacity + self.evaluate(tf.global_variables_initializer()) + + v, cs, cap = sess.run([values, current_size, capacity]) + self.assertAllEqual(v, [value1, value2, value3]) + self.assertEqual(cs, 3) + self.assertEqual(cap, 4) + + +if __name__ == '__main__': + tf.test.main()