From 755ed26671f5567ba1519a4e80078eded7a6299b Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Tue, 20 Apr 2021 12:35:00 -0700 Subject: [PATCH] Update keras optimizers (both traditional and vectorized) to handle case of num_microbatches=None. PiperOrigin-RevId: 369497296 --- .../privacy/optimizers/dp_optimizer_keras.py | 16 +++++++++++----- .../optimizers/dp_optimizer_keras_test.py | 8 ++++++++ .../optimizers/dp_optimizer_keras_vectorized.py | 16 +++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py index 2672f38..eac2916 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py @@ -87,15 +87,20 @@ def make_keras_optimizer_class(cls): if not callable(var_list): tape.watch(var_list) - if callable(loss): - loss = loss() - microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + loss = loss() + batch_size = tf.shape(input=loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [self._num_microbatches, -1]), axis=1) if callable(var_list): var_list = var_list() else: with tape: + batch_size = tf.shape(input=loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size microbatch_losses = tf.reduce_mean( tf.reshape(loss, [self._num_microbatches, -1]), axis=1) @@ -122,7 +127,8 @@ def make_keras_optimizer_class(cls): noised_gradient = tf.add(summed_gradient, noise) # Normalize by number of microbatches and return. - return tf.truediv(noised_gradient, self._num_microbatches) + return tf.truediv(noised_gradient, + tf.cast(self._num_microbatches, tf.float32)) final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, clipped_gradients) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py index 317e6d3..71b68c6 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py @@ -51,6 +51,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): ('DPAdagradVectorized 4', dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized None', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None, + [-2.5, -2.5], [-0.5]), ) def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0, expected_grad1): @@ -89,6 +92,9 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): ('DPAdagradVectorized 4', dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized None', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None, + [-2.5, -2.5], [-0.5]), ) def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0, expected_grad1): @@ -244,6 +250,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2), ('DPGradientDescentVectorized 4', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4), + ('DPGradientDescentVectorized None', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None), ) def testBaseline(self, cls, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py index c51572a..e1d3fc6 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py @@ -100,15 +100,20 @@ def make_vectorized_keras_optimizer_class(cls): if not callable(var_list): tape.watch(var_list) - if callable(loss): - loss = loss() - microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + loss = loss() + batch_size = tf.shape(input=loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [self._num_microbatches, -1]), axis=1) if callable(var_list): var_list = var_list() else: with tape: + batch_size = tf.shape(input=loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size microbatch_losses = tf.reduce_mean( tf.reshape(loss, [self._num_microbatches, -1]), axis=1) @@ -132,7 +137,8 @@ def make_vectorized_keras_optimizer_class(cls): noised_gradient = tf.add(summed_gradient, noise) # Normalize by number of microbatches and return. - return tf.truediv(noised_gradient, self._num_microbatches) + return tf.truediv(noised_gradient, + tf.cast(self._num_microbatches, tf.float32)) final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, clipped_gradients)