diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 71593ef..20fe19a 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ -from typing import Dict, Iterable, Optional, Text, Union +from typing import Any, Callable, Dict, Iterable, Optional, Text, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils @@ -67,6 +67,7 @@ def compute_gradient_norms( x_batch: InputTensor, y_batch: tf.Tensor, layer_registry: lr.LayerRegistry, + per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None, num_microbatches: Optional[lr.BatchSize] = None, ): """Computes the per-example loss gradient norms for given data. @@ -86,6 +87,9 @@ def compute_gradient_norms( compute gradient norms quickly. See `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for more details. + per_example_loss_fn: If not None, used as the function to compute the + vectorized per example loss. Otherwise, we derive it from `input_model`'s + loss function. num_microbatches: An optional number or scalar `tf.Tensor` for the number of microbatches. If not None, indicates that the loss is grouped into num_microbatches (in this case, the batch dimension needs to be a multiple @@ -109,9 +113,10 @@ def compute_gradient_norms( ) ) # Ignore the original loss function's reduction to get per-example loss. - loss_config = input_model.loss.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - per_example_loss_fn = input_model.loss.from_config(loss_config) + if per_example_loss_fn is None: + loss_config = input_model.loss.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + per_example_loss_fn = input_model.loss.from_config(loss_config) losses = per_example_loss_fn(y_batch, model_outputs) if num_microbatches is not None: losses = tf.reduce_mean( @@ -204,7 +209,11 @@ def compute_pred_and_clipped_gradients( gradient of the loss function. """ gradient_norms = compute_gradient_norms( - input_model, x_batch, y_batch, layer_registry, num_microbatches + input_model, + x_batch, + y_batch, + layer_registry, + num_microbatches=num_microbatches, ) loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) with tf.GradientTape() as tape: diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 2d1dad8..ff21c7c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -71,16 +71,25 @@ def double_dense_layer_computation( return [vars1, vars2], outputs, sqr_norm_fn +def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + x = tf.reshape(x, (tf.shape(x)[0], -1)) + y = tf.reshape(y, (tf.shape(y)[0], -1)) + # Define a loss function which is unlikely to be coincidently defined. + return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1) + + def compute_true_gradient_norms( input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor, + per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], num_microbatches: Optional[int], ) -> layer_registry.OutputTensor: """Computes the real gradient norms for an input `(model, x, y)`.""" - loss_config = input_model.loss.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - per_example_loss_fn = input_model.loss.from_config(loss_config) + if per_example_loss_fn is None: + loss_config = input_model.loss.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + per_example_loss_fn = input_model.loss.from_config(loss_config) with tf.GradientTape(persistent=True) as tape: y_pred = input_model(x_batch) loss = per_example_loss_fn(y_batch, y_pred) @@ -109,6 +118,7 @@ def get_computed_and_true_norms( layer_generator: LayerGenerator, input_dims: Union[int, List[int]], output_dim: int, + per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], num_microbatches: Optional[int], is_eager: bool, x_input: tf.Tensor, @@ -130,6 +140,8 @@ def get_computed_and_true_norms( `idim` and returns output tensors of dimension `odim`. input_dims: The input dimension(s) of the test `tf.keras.Model` instance. output_dim: The output dimension of the test `tf.keras.Model` instance. + per_example_loss_fn: If not None, used as vectorized per example loss + function. num_microbatches: The number of microbatches. None or an integer. is_eager: A `bool` that is `True` if the model should be run eagerly. x_input: `tf.Tensor` inputs to be tested. @@ -159,11 +171,12 @@ def get_computed_and_true_norms( x_input, y_batch, layer_registry=registry, + per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, ) tf.keras.utils.set_random_seed(rng_seed) true_norms = compute_true_gradient_norms( - model, x_input, y_batch, num_microbatches + model, x_input, y_batch, per_example_loss_fn, num_microbatches ) return (computed_norms, true_norms) @@ -348,6 +361,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): layer_name=list(get_dense_layer_generators().keys()), input_dim=[4], output_dim=[1, 2], + per_example_loss_fn=[None, test_loss_fn], num_microbatches=[None, 1, 2], is_eager=[True, False], ) @@ -357,6 +371,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): layer_name, input_dim, output_dim, + per_example_loss_fn, num_microbatches, is_eager, ): @@ -379,6 +394,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): layer_generator, input_dim, output_dim, + per_example_loss_fn, num_microbatches, is_eager, x_input, @@ -419,11 +435,18 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): ], model_name=list(get_embedding_model_generators().keys()), output_dim=[2], + per_example_loss_fn=[None, test_loss_fn], num_microbatches=[None, 1, 2], is_eager=[True], ) def test_gradient_norms_on_various_models( - self, x_batch, model_name, output_dim, num_microbatches, is_eager + self, + x_batch, + model_name, + output_dim, + per_example_loss_fn, + num_microbatches, + is_eager, ): if ( num_microbatches is not None @@ -442,6 +465,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): layer_generator=None, input_dims=x_batch.shape[1:], output_dim=output_dim, + per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, is_eager=is_eager, x_input=x_batch, @@ -455,11 +479,17 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( input_dim=[1, 2], output_dim=[1, 2], + per_example_loss_fn=[None, test_loss_fn], num_microbatches=[None, 1, 2], is_eager=[True, False], ) def test_gradient_norms_on_various_models( - self, input_dim, output_dim, num_microbatches, is_eager + self, + input_dim, + output_dim, + per_example_loss_fn, + num_microbatches, + is_eager, ): registry = layer_registry.make_default_layer_registry() registry.insert(DoubleDense, double_dense_layer_computation) @@ -475,6 +505,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): layer_generator=lambda a, b: DoubleDense(b), input_dims=input_dim, output_dim=output_dim, + per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, is_eager=is_eager, x_input=x_batch,