diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
index dd24122..3b88819 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -35,8 +35,10 @@ py_library(
 
 py_test(
     name = "clip_grads_test",
+    size = "large",
     srcs = ["clip_grads_test.py"],
     python_version = "PY3",
+    shard_count = 8,
     srcs_version = "PY3",
     deps = [
         ":clip_grads",
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index 23f1b0c..0c37d7f 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -54,7 +54,11 @@ def get_registry_generator_fn(
         (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
             layer_instance, args, kwargs, tape, num_microbatches
         )
-        return layer_outputs, (layer_vars, layer_sqr_norm_fn)
+        return layer_outputs, (
+            layer_vars,
+            layer_sqr_norm_fn,
+            layer_instance.trainable_weights,
+        )
       else:
         # Non-trainable layer.
         return layer_instance(*args, **kwargs), None
@@ -69,6 +73,7 @@ def compute_gradient_norms(
     layer_registry: lr.LayerRegistry,
     per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
+    trainable_vars: Optional[List[tf.Variable]] = None,
 ):
   """Computes the per-example loss gradient norms for given data.
 
@@ -96,6 +101,10 @@ def compute_gradient_norms(
       of num_microbatches). When there is microbatches, we always assume the
       loss is the mean over a microbatch. And the gradient norm is computed
       for each microbatch.
+    trainable_vars: The list of variables included in computing the gradient
+      norm. When a layer has multiple variables, we include all the variables if
+      any of the variables is in the list. If `trainable_vars` is None, all the
+      variables are included.
 
   Returns:
     A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
@@ -126,8 +135,19 @@ def compute_gradient_norms(
   # Unwrap the generator outputs so that the next loop avoids duplicating
   # backprop ops.
   filtered_outputs = [t for t in generator_outputs_list if t is not None]
-  vars_list = [a for (a, b) in filtered_outputs]
-  sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
+  vars_list = []
+  sqr_norm_fns_list = []
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+  for v, f, weights_list in filtered_outputs:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars for w in weights_list
+    ):
+      # Include only those variables in trainable_vars.
+      vars_list.append(v)
+      sqr_norm_fns_list.append(f)
   # Second loop evaluates the squared L2 norm functions and appends the results.
   grads_list = tape.gradient(
       summed_loss,
@@ -218,6 +238,7 @@ def compute_clipped_gradients_and_outputs(
       y_batch,
       layer_registry,
       num_microbatches=num_microbatches,
+      trainable_vars=input_model.trainable_variables,
   )
   loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
   with tf.GradientTape() as tape:
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
index ff21c7c..a0986fc 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
@@ -84,6 +84,7 @@ def compute_true_gradient_norms(
     y_batch: tf.Tensor,
     per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
     num_microbatches: Optional[int],
+    trainable_vars: Optional[tf.Variable] = None,
 ) -> layer_registry.OutputTensor:
   """Computes the real gradient norms for an input `(model, x, y)`."""
   if per_example_loss_fn is None:
@@ -104,7 +105,8 @@ def compute_true_gradient_norms(
     if isinstance(loss, tf.RaggedTensor):
       loss = loss.to_tensor()
   sqr_norms = []
-  for var in input_model.trainable_variables:
+  trainable_vars = trainable_vars or input_model.trainable_variables
+  for var in trainable_vars:
     jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
     reduction_axes = tf.range(1, len(jacobian.shape))
     sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
@@ -124,6 +126,7 @@ def get_computed_and_true_norms(
     x_input: tf.Tensor,
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
+    partial: bool = False,
 ) -> Tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.
 
@@ -147,6 +150,8 @@ def get_computed_and_true_norms(
     x_input: `tf.Tensor` inputs to be tested.
     rng_seed: An `int` used to initialize model weights.
     registry: A `layer_registry.LayerRegistry` instance.
+    partial: Whether to compute the gradient norm with respect to a partial set
+      of variables. If True, only consider the variables in the first layer.
 
   Returns:
     A `tuple` `(computed_norm, true_norms)`. The first element contains the
@@ -163,6 +168,13 @@ def get_computed_and_true_norms(
       ),
       run_eagerly=is_eager,
   )
+  trainable_vars = None
+  if partial:
+    # Gets the first layer with variables.
+    for l in model.layers:
+      trainable_vars = l.trainable_variables
+      if trainable_vars:
+        break
   y_pred = model(x_input)
   y_batch = tf.ones_like(y_pred)
   tf.keras.utils.set_random_seed(rng_seed)
@@ -173,10 +185,16 @@ def get_computed_and_true_norms(
       layer_registry=registry,
       per_example_loss_fn=per_example_loss_fn,
       num_microbatches=num_microbatches,
+      trainable_vars=trainable_vars,
   )
   tf.keras.utils.set_random_seed(rng_seed)
   true_norms = compute_true_gradient_norms(
-      model, x_input, y_batch, per_example_loss_fn, num_microbatches
+      model,
+      x_input,
+      y_batch,
+      per_example_loss_fn,
+      num_microbatches,
+      trainable_vars=trainable_vars,
   )
   return (computed_norms, true_norms)
 
@@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
       model_name=list(get_dense_model_generators().keys()),
       layer_name=list(get_dense_layer_generators().keys()),
       input_dim=[4],
-      output_dim=[1, 2],
+      output_dim=[2],
       per_example_loss_fn=[None, test_loss_fn],
       num_microbatches=[None, 1, 2],
       is_eager=[True, False],
+      partial=[True, False],
   )
   def test_gradient_norms_on_various_models(
       self,
@@ -374,6 +393,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
       per_example_loss_fn,
       num_microbatches,
       is_eager,
+      partial,
   ):
     model_generator = get_dense_model_generators()[model_name]
     layer_generator = get_dense_layer_generators()[layer_name]
@@ -399,6 +419,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
         is_eager,
         x_input,
         registry=default_registry,
+        partial=partial,
     )
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
@@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
       model_name=list(get_embedding_model_generators().keys()),
       output_dim=[2],
       per_example_loss_fn=[None, test_loss_fn],
-      num_microbatches=[None, 1, 2],
-      is_eager=[True],
+      num_microbatches=[None, 2],
+      is_eager=[True, False],
+      partial=[True, False],
   )
   def test_gradient_norms_on_various_models(
       self,
@@ -447,6 +469,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
       per_example_loss_fn,
       num_microbatches,
       is_eager,
+      partial,
   ):
     if (
         num_microbatches is not None
@@ -470,6 +493,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
         is_eager=is_eager,
         x_input=x_batch,
         registry=default_registry,
+        partial=partial,
     )
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
@@ -477,11 +501,12 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
 class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.product(
-      input_dim=[1, 2],
-      output_dim=[1, 2],
+      input_dim=[3],
+      output_dim=[2],
       per_example_loss_fn=[None, test_loss_fn],
-      num_microbatches=[None, 1, 2],
+      num_microbatches=[None, 2],
       is_eager=[True, False],
+      partial=[True, False],
   )
   def test_gradient_norms_on_various_models(
       self,
@@ -490,6 +515,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
       per_example_loss_fn,
      num_microbatches,
      is_eager,
+      partial,
   ):
     registry = layer_registry.make_default_layer_registry()
     registry.insert(DoubleDense, double_dense_layer_computation)
@@ -510,6 +536,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
         is_eager=is_eager,
         x_input=x_batch,
         registry=registry,
+        partial=partial,
    )
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
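
Reviewer note: the snippet below is a minimal sketch, not part of the diff, showing how the new `trainable_vars` argument could be exercised end to end. It assumes the `clip_grads.compute_gradient_norms(input_model, x_batch, y_batch, layer_registry, ...)` signature and the `layer_registry.make_default_layer_registry()` helper already used in clip_grads_test.py, plus a hypothetical two-layer Keras model.

import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Hypothetical model; any model whose layers are covered by the registry would do.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(1),
])
model.compile(
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
    )
)

x_batch = tf.random.normal([8, 3])
y_batch = tf.ones([8, 1])

# Restrict the per-example norm computation to the first layer's variables;
# with the default `trainable_vars=None`, every trainable variable is included.
norms = clip_grads.compute_gradient_norms(
    model,
    x_batch,
    y_batch,
    layer_registry=lr.make_default_layer_registry(),
    trainable_vars=model.layers[0].trainable_variables,
)
print(norms.shape)  # One gradient norm per example, i.e. shape (8,).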