From 44dfac37708f6b48edbe4348777d0fad2112e2c8 Mon Sep 17 00:00:00 2001 From: William Kong Date: Tue, 16 Apr 2024 11:18:49 -0700 Subject: [PATCH] Implement fast gradient clipping for loss functions that use inputs that are fed into shared weights. PiperOrigin-RevId: 625395017 --- .../privacy/fast_gradient_clipping/BUILD | 1 + .../fast_gradient_clipping/clip_grads.py | 43 +++++--- .../fast_gradient_clipping/clip_grads_test.py | 103 ++++++++++++++++++ 3 files changed, 134 insertions(+), 13 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index abdcb6d..ac3e47d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -87,6 +87,7 @@ py_test( name = "clip_grads_test", srcs = ["clip_grads_test.py"], python_version = "PY3", + shard_count = 8, srcs_version = "PY3", deps = [ ":clip_grads", diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index feda803..9a5a96e 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,6 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ +import collections from collections.abc import Sequence from typing import Optional @@ -56,6 +57,7 @@ def get_registry_generator_fn( layer_instance, args, kwargs, tape, num_microbatches ) return layer_outputs, ( + str(id(layer_instance)), layer_vars, layer_sqr_norm_fn, layer_instance.trainable_weights, @@ -156,32 +158,47 @@ def compute_gradient_norms( # Unwrap the generator outputs so that the next loop avoids duplicating # backprop ops. filtered_outputs = [t for t in generator_outputs_list if t is not None] - vars_list = [] - sqr_norm_fns_list = [] if trainable_vars is not None: # Create a set using `ref()` for fast set membership check. tf.Variable # itself is not hashable. trainable_vars = set([v.ref() for v in trainable_vars]) - for v, f, weights_list in filtered_outputs: + layer_vars = collections.defaultdict(list) + layer_sqr_norm_fns = collections.defaultdict(list) + # The case of shared weights: + # If a layer is called k times, it will appear k times in filtered_outputs, + # with the same id, but potentially with different v and f. The code below + # groups filtered_outputs by layer_id, so we can correctly compute gradient + # norms. The gradient norm of a layer that occurs k times is computed as + # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th + # occurrence. This is an over-estimate of the actual norm. For more details, + # see the explanation in go/dp-sgd-shared-weights. + for layer_id, v, f, weights_list in filtered_outputs: if trainable_vars is None or any( w.ref() in trainable_vars for w in weights_list ): - # Include only those variables in trainable_vars. - vars_list.append(v) - sqr_norm_fns_list.append(f) + layer_vars[layer_id].append(v) + layer_sqr_norm_fns[layer_id].append(f) # Second loop evaluates the squared L2 norm functions and appends the results. - grads_list = tape.gradient( + layer_grad_vars = tape.gradient( summed_loss, - vars_list, + layer_vars, unconnected_gradients=tf.UnconnectedGradients.ZERO, ) - if not grads_list: + if not layer_grad_vars: raise ValueError('The gradient list cannot be empty.') - if len(grads_list) != len(sqr_norm_fns_list): - raise ValueError('There must be as many norms as gradients.') sqr_norm_list = [] - for grads, f in zip(grads_list, sqr_norm_fns_list): - sqr_norm_list.append(f(grads)) + for layer_id in layer_sqr_norm_fns.keys(): + fns = layer_sqr_norm_fns[layer_id] + grads = layer_grad_vars[layer_id] + # Number of duplicates for this layer in `filtered_outputs`. + num_passes = len(fns) + if len(fns) != len(grads): + raise ValueError( + 'There must be as many gradients as squared norm functions.' + ) + # See go/dp-sgd-shared-weights for more details. + for fn, grad in zip(fns, grads): + sqr_norm_list.append(num_passes * fn(grad)) sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index fe95362..a028aff 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -197,5 +197,108 @@ class ComputeClippedGradsAndOutputsTest( self.assertAlmostEqual(computed_norm, true_norm) +class SharedLayerTest(tf.test.TestCase, parameterized.TestCase): + + def _make_shared_model(self, num_inputs, input_dim): + base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)]) + inputs = [] + outputs = [] + for _ in range(num_inputs): + input_tensor = tf.keras.Input(shape=[input_dim]) + inputs.append(input_tensor) + output_tensor = base_model(input_tensor) + outputs.append(output_tensor) + return tf.keras.Model(inputs=inputs, outputs=tf.add_n(outputs)) + + def _get_computed_and_true_norms(self, model, x_batch, y_batch, is_eager): + model.compile( + loss=tf.keras.losses.MeanSquaredError(reduction='none'), + run_eagerly=is_eager, + ) + computed_norms = clip_grads.compute_gradient_norms( + model, layer_registry.make_default_layer_registry(), x_batch, y_batch + ) + with tf.GradientTape() as tape: + y_pred = model(x_batch) + loss_value = model.loss(y_pred, y_batch) + true_grads = tape.jacobian(loss_value, model.trainable_variables) + true_norms = tf.sqrt( + tf.add_n([tf.reduce_sum(tf.square(g), axis=[1, 2]) for g in true_grads]) + ) + return computed_norms, true_norms + + @parameterized.product( + num_inputs=[1, 2, 10], + batch_size=[1, 2], + input_dim=[1, 3], + is_eager=[True, False], + ) + def test_gradient_norms_on_multiple_inputs_are_upper_bounded( + self, num_inputs, batch_size, input_dim, is_eager + ): + model = self._make_shared_model(num_inputs, input_dim) + model.compile( + loss=tf.keras.losses.MeanSquaredError(reduction='none'), + run_eagerly=is_eager, + ) + x_batch = [ + float(k + 1) * tf.ones([batch_size, input_dim], dtype=tf.float64) + for k in range(num_inputs) + ] + y_batch = tf.reshape( + 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] + ) + computed_norms, true_norms = self._get_computed_and_true_norms( + model, x_batch, y_batch, is_eager + ) + self.assertAllLessEqual(true_norms - computed_norms, 1e-3) + + @parameterized.product( + num_repeats=[1, 2, 10], + batch_size=[1, 2], + input_dim=[1, 3], + is_eager=[True, False], + ) + def test_gradient_norms_on_single_repeated_input_are_upper_bounded( + self, num_repeats, batch_size, input_dim, is_eager + ): + base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)]) + inputs = tf.keras.layers.Input([input_dim]) + outputs = tf.add_n([base_model(inputs) for _ in range(num_repeats)]) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + x_batch = tf.ones([batch_size, input_dim], dtype=tf.float64) + y_batch = tf.reshape( + 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] + ) + computed_norms, true_norms = self._get_computed_and_true_norms( + model, x_batch, y_batch, is_eager + ) + self.assertAllLessEqual(true_norms - computed_norms, 1e-3) + + @parameterized.product( + batch_size=[1, 2], + input_dim=[1, 3], + is_eager=[True, False], + ) + def test_gradient_norms_on_input_slices_are_upper_bounded( + self, batch_size, input_dim, is_eager + ): + base_model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)]) + inputs = tf.keras.layers.Input([input_dim, 2]) + outputs = base_model(inputs[:, :, 0]) + base_model(inputs[:, :, 1]) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + x_batch = tf.reshape( + tf.range(batch_size * input_dim * 2, dtype=tf.float64), + [batch_size, input_dim, -1], + ) + y_batch = tf.reshape( + 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] + ) + computed_norms, true_norms = self._get_computed_and_true_norms( + model, x_batch, y_batch, is_eager + ) + self.assertAllLessEqual(true_norms - computed_norms, 1e-3) + + if __name__ == '__main__': tf.test.main()