From 6460c3feb803c285cd9fb8464d558e25f4af4361 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Mon, 21 Dec 2020 17:06:32 -0800 Subject: [PATCH] Vectorized version of DP Keras optimizers. PiperOrigin-RevId: 348551659 --- .../optimizers/dp_optimizer_keras_test.py | 75 ++++++- .../dp_optimizer_keras_vectorized.py | 188 ++++++++++++++++++ 2 files changed, 256 insertions(+), 7 deletions(-) create mode 100644 tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py index 77b7b6a..317e6d3 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py @@ -22,6 +22,7 @@ import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras +from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_vectorized class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): @@ -37,10 +38,19 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, [-2.5, -2.5], [-0.5]), - ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, - [-2.5, -2.5], [-0.5]), + ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5 + ], [-0.5]), ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, [-2.5, -2.5], [-0.5]), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdamVectorized 2', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), ) def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0, expected_grad1): @@ -70,6 +80,15 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): ], [-0.5]), ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, [-2.5, -2.5], [-0.5]), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdamVectorized 2', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), ) def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0, expected_grad1): @@ -94,7 +113,10 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) @parameterized.named_parameters( - ('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),) + ('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer), + ('DPGradientDescentVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), + ) def testClippingNorm(self, cls): var0 = tf.Variable([0.0, 0.0]) data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) @@ -115,6 +137,12 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): 4.0, 1), ('DPGradientDescent 4 1 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0, 1.0, 4), + ('DPGradientDescentVectorized 2 4 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, + 1), + ('DPGradientDescentVectorized 4 1 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0, + 4), ) def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, num_microbatches): @@ -138,7 +166,14 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer), ('DPAdagrad', dp_optimizer_keras.DPKerasAdagradOptimizer), - ('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer)) + ('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer), + ('DPGradientDescentVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), + ('DPAdagradVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer), + ('DPAdamVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), + ) def testAssertOnNoCallOfComputeGradients(self, cls): """Tests that assertion fails when DP gradients are not computed.""" opt = cls( @@ -202,7 +237,14 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), ('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), - ('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),) + ('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), + ('DPGradientDescentVectorized 2', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2), + ('DPGradientDescentVectorized 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4), + ) def testBaseline(self, cls, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" @@ -233,7 +275,10 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): # Parameters for testing: optimizer, num_microbatches. @parameterized.named_parameters( - ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),) + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), + ) def testClippingNorm(self, cls, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" @@ -279,6 +324,15 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): 2.0, 4), ('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0, 6.0, 8), + ('DPGradientDescentVectorized 2 4 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, + 1), + ('DPGradientDescentVectorized 3 2 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0, + 4), + ('DPGradientDescentVectorized 8 6 8', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0, + 8), ) def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, num_microbatches): @@ -312,7 +366,14 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer), ('DPAdagrad', dp_optimizer_keras.DPKerasAdagradOptimizer), - ('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer)) + ('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer), + ('DPGradientDescentVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), + ('DPAdagradVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer), + ('DPAdamVectorized', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), + ) def testAssertOnNoCallOfGetGradients(self, cls): """Tests that assertion fails when DP gradients are not computed.""" opt = cls( diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py new file mode 100644 index 0000000..9be6b31 --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py @@ -0,0 +1,188 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentially private version of Keras optimizer v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.dp_query import gaussian_query + + +def clip_gradients_vmap(g, l2_norm_clip): + """Clips gradients in a way that is compatible with vectorized_map.""" + grads_flat = tf.nest.flatten(g) + squared_l2_norms = [ + tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat + ] + global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) + div = tf.maximum(global_norm / l2_norm_clip, 1.) + clipped_flat = [g / div for g in grads_flat] + clipped_grads = tf.nest.pack_sequence_as(g, clipped_flat) + return clipped_grads + + +def make_vectorized_keras_optimizer_class(cls): + """Constructs a DP Keras optimizer class from an existing one.""" + + class DPOptimizerClass(cls): + """Differentially private subclass of given class cls. + + The class tf.keras.optimizers.Optimizer has two methods to compute + gradients, `_compute_gradients` and `get_gradients`. The first works + with eager execution, while the second runs in graph mode and is used + by canned estimators. + + Internally, DPOptimizerClass stores hyperparameters both individually + and encapsulated in a `GaussianSumQuery` object for these two use cases. + However, this should be invisible to users of this class. + """ + + def __init__( + self, + l2_norm_clip, + noise_multiplier, + num_microbatches=None, + *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args + **kwargs): + """Initialize the DPOptimizerClass. + + Args: + l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients) + noise_multiplier: Ratio of the standard deviation to the clipping norm + num_microbatches: The number of microbatches into which each minibatch + is split. + """ + super(DPOptimizerClass, self).__init__(*args, **kwargs) + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + self._num_microbatches = num_microbatches + self._dp_sum_query = gaussian_query.GaussianSumQuery( + l2_norm_clip, l2_norm_clip * noise_multiplier) + self._global_state = None + self._was_dp_gradients_called = False + + def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): + """DP version of superclass method.""" + + self._was_dp_gradients_called = True + # Compute loss. + if not callable(loss) and tape is None: + raise ValueError('`tape` is required when a `Tensor` loss is passed.') + tape = tape if tape is not None else tf.GradientTape() + + if callable(loss): + with tape: + if not callable(var_list): + tape.watch(var_list) + + if callable(loss): + loss = loss() + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + + if callable(var_list): + var_list = var_list() + else: + with tape: + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [self._num_microbatches, -1]), axis=1) + + var_list = tf.nest.flatten(var_list) + + # Compute the per-microbatch losses using helpful jacobian method. + with tf.keras.backend.name_scope(self._name + '/gradients'): + jacobian = tape.jacobian(microbatch_losses, var_list) + + clipped_gradients = tf.vectorized_map( + lambda g: clip_gradients_vmap(g, self._l2_norm_clip), jacobian) + + def reduce_noise_normalize_batch(g): + # Sum gradients over all microbatches. + summed_gradient = tf.reduce_sum(g, axis=0) + + # Add noise to summed gradients. + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_gradient), stddev=noise_stddev) + noised_gradient = tf.add(summed_gradient, noise) + + # Normalize by number of microbatches and return. + return tf.truediv(noised_gradient, self._num_microbatches) + + final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, + clipped_gradients) + + return list(zip(final_gradients, var_list)) + + def get_gradients(self, loss, params): + """DP version of superclass method.""" + + self._was_dp_gradients_called = True + if self._global_state is None: + self._global_state = self._dp_sum_query.initial_global_state() + + batch_size = tf.shape(input=loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size + + microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) + + def process_microbatch(microbatch_loss): + """Compute clipped grads for one microbatch.""" + mean_loss = tf.reduce_mean(input_tensor=microbatch_loss) + grads = super(DPOptimizerClass, self).get_gradients(mean_loss, params) + grads_list = [ + g if g is not None else tf.zeros_like(v) + for (g, v) in zip(list(grads), params) + ] + clipped_grads = clip_gradients_vmap(grads_list, self._l2_norm_clip) + return clipped_grads + + clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses) + + def reduce_noise_normalize_batch(stacked_grads): + summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_grads), stddev=noise_stddev) + noised_grads = summed_grads + noise + return noised_grads / tf.cast(self._num_microbatches, tf.float32) + + final_grads = tf.nest.map_structure(reduce_noise_normalize_batch, + clipped_grads) + return final_grads + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + assert self._was_dp_gradients_called, ( + 'Neither _compute_gradients() or get_gradients() on the ' + 'differentially private optimizer was called. This means the ' + 'training is not differentially private. It may be the case that ' + 'you need to upgrade to TF 2.4 or higher to use this particular ' + 'optimizer.') + return super(DPOptimizerClass, + self).apply_gradients(grads_and_vars, global_step, name) + + return DPOptimizerClass + + +VectorizedDPKerasAdagradOptimizer = make_vectorized_keras_optimizer_class( + tf.keras.optimizers.Adagrad) +VectorizedDPKerasAdamOptimizer = make_vectorized_keras_optimizer_class( + tf.keras.optimizers.Adam) +VectorizedDPKerasSGDOptimizer = make_vectorized_keras_optimizer_class( + tf.keras.optimizers.SGD)