diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index b2c9dd3..6a37ae3 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -42,12 +42,9 @@ def get_registry_generator_fn(tape, layer_registry): % layer_instance.__class__.__name__ ) registry_fn = layer_registry.lookup(layer_instance) - (layer_vars, transform, layer_sqr_norm_fn) = registry_fn( - layer_instance, args + (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( + layer_instance, args, tape ) - if tape is not None: - tape.watch(layer_vars) - layer_outputs = transform(layer_vars) if transform else layer_vars return layer_outputs, (layer_vars, layer_sqr_norm_fn) else: # Non-trainable layer. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 6514711..183d890 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -21,8 +21,38 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry # ============================================================================== -# Helper functions. +# Helper functions and classes. 
# ============================================================================== +class DoubleDense(tf.keras.layers.Layer): + """Generates two dense layers nested together.""" + + def __init__(self, units): + super().__init__() + self.dense1 = tf.keras.layers.Dense(units) + self.dense2 = tf.keras.layers.Dense(1) + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + +def double_dense_layer_computation(layer_instance, inputs, tape): + """Layer registry function for the custom `DoubleDense` layer class.""" + vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( + layer_instance.dense1, inputs, tape + ) + vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation( + layer_instance.dense2, (outputs,), tape + ) + + def sqr_norm_fn(base_vars): + norms1 = sqr_norm_fn1(base_vars[0]) + norms2 = sqr_norm_fn2(base_vars[1]) + return norms1 + norms2 + + return [vars1, vars2], outputs, sqr_norm_fn + + def compute_true_gradient_norms(input_model, x_batch, y_batch): """Computes the real gradient norms for an input `(model, x, y)`.""" loss_config = input_model.loss.get_config() @@ -50,6 +80,7 @@ def get_computed_and_true_norms( is_eager, x_input, rng_seed=777, + registry=None, ): """Obtains the true and computed gradient norms for a model and batch input. @@ -69,6 +100,7 @@ def get_computed_and_true_norms( is_eager: A `bool` that is `True` if the model should be run eagerly. x_input: `tf.Tensor` inputs to be tested. rng_seed: An `int` used to initialize model weights. + registry: A `layer_registry.LayerRegistry` instance. Returns: A `tuple` `(computed_norm, true_norms)`. 
The first element contains the @@ -87,7 +119,6 @@ def get_computed_and_true_norms( ) y_pred = model(x_input) y_batch = tf.ones_like(y_pred) - registry = layer_registry.make_default_layer_registry() tf.keras.utils.set_random_seed(rng_seed) computed_norms = clip_grads.compute_gradient_norms( model, x_input, y_batch, layer_registry=registry @@ -285,6 +316,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): model_generator = get_dense_model_generators()[model_name] layer_generator = get_dense_layer_generators()[layer_name] x_batches = get_nd_test_batches(input_dim) + default_registry = layer_registry.make_default_layer_registry() for x_batch in x_batches: if model_name == 'tower1': x_input = [x_batch, x_batch] @@ -297,6 +329,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): output_dim, is_eager, x_input, + registry=default_registry, ) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) @@ -335,6 +368,7 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): and model_name == 'weighted_bow1' ) or (model_name != 'weighted_bow1') if valid_test_input: + default_registry = layer_registry.make_default_layer_registry() model_generator = get_embedding_model_generators()[model_name] (computed_norms, true_norms) = get_computed_and_true_norms( model_generator=model_generator, @@ -343,6 +377,33 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): output_dim=output_dim, is_eager=is_eager, x_input=x_batch, + registry=default_registry, + ) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + + +class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + input_dim=[1, 2], + output_dim=[1, 2], + is_eager=[True, False], + ) + def test_gradient_norms_on_various_models( + self, input_dim, output_dim, is_eager + ): + registry = layer_registry.make_default_layer_registry() + registry.insert(DoubleDense, 
double_dense_layer_computation) + x_batches = get_nd_test_batches(input_dim) + for x_batch in x_batches: + (computed_norms, true_norms) = get_computed_and_true_norms( + model_generator=make_two_layer_sequential_model, + layer_generator=lambda a, b: DoubleDense(b), + input_dims=input_dim, + output_dim=output_dim, + is_eager=is_eager, + x_input=x_batch, + registry=registry, ) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index 12fa53f..c8279ba 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -72,7 +72,7 @@ class LayerRegistry: # ============================================================================== # Supported Keras layers # ============================================================================== -def dense_layer_computation(layer_instance, inputs): +def dense_layer_computation(layer_instance, inputs, tape): """Registry function for `tf.keras.layers.Dense`. The logic for this computation is based on the following paper: @@ -85,23 +85,27 @@ def dense_layer_computation(layer_instance, inputs): layer_instance: A `tf.keras.layers.Dense` instance. inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., `layer_instance(inputs)` returns a valid output. + tape: A `tf.GradientTape` instance that will be used to watch the output + `base_vars`. Returns: - A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the + A `tuple` `(base_vars, outputs, sqr_norm_fn)`. 
`base_vars` is the intermediate Tensor used in the chain-rule / "fast" clipping trick, - `transform` is a function that maps `base_vars` to the layer outputs, and - `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that - represents the output of the call `tape.gradient(summed_loss, base_vars)` - where `tape` is a `tf.GradientTape` instance that records the dense - layer computation and `summed_loss` is the sum of the per-example losses - of the underlying model. This function then returns the per-example squared - L2 gradient norms of the trainable variables in `layer_instance`. These - squared norms should be a 1D `tf.Tensor` of length `batch_size`. + `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is + a function that takes one input, a `tf.Tensor` that represents the output + of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a + `tf.GradientTape` instance that records the dense layer computation and + `summed_loss` is the sum of the per-example losses of the underlying model. + This function then returns the per-example squared L2 gradient norms of the + trainable variables in `layer_instance`. These squared norms should be a 1D + `tf.Tensor` of length `batch_size`. 
""" orig_activation = layer_instance.activation layer_instance.activation = None base_vars = layer_instance(*inputs) + tape.watch(base_vars) layer_instance.activation = orig_activation + outputs = orig_activation(base_vars) if orig_activation else base_vars def sqr_norm_fn(base_vars_grads): sqr_inputs = tf.square(*inputs) inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs)) @@ -118,10 +122,10 @@ def dense_layer_computation(layer_instance, inputs): ) return input_sqr_norms * base_vars_sqr_norms - return base_vars, layer_instance.activation, sqr_norm_fn + return base_vars, outputs, sqr_norm_fn -def embedding_layer_computation(layer_instance, inputs): +def embedding_layer_computation(layer_instance, inputs, tape): """Registry function for `tf.keras.layers.Embedding`. The logic of this computation is based on the `tf.keras.layers.Dense` @@ -134,17 +138,20 @@ def embedding_layer_computation(layer_instance, inputs): layer_instance: A `tf.keras.layers.Embedding` instance. inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., `layer_instance(inputs)` returns a valid output. + tape: A `tf.GradientTape` instance that will be used to watch the output + `base_vars`. Returns: - A `tuple` `(base_vars, transform, sqr_norm_fn)`, `base_vars` is the + A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the intermediate Tensor used in the chain-rule / "fast" clipping trick, - `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that - represents the output of the call `tape.gradient(summed_loss, base_vars)` - where `tape` is a `tf.GradientTape` instance that records the dense - layer computation and `summed_loss` is the sum of the per-example losses - of the underlying model. This function then returns the per-example squared - L2 gradient norms of the trainable variables in `layer_instance`. These - squared norms should be a 1D `tf.Tensor` of length `batch_size`. 
+ `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is + a function that takes one input, a `tf.Tensor` that represents the output + of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a + `tf.GradientTape` instance that records the embedding layer computation and + `summed_loss` is the sum of the per-example losses of the underlying model. + This function then returns the per-example squared L2 gradient norms of the + trainable variables in `layer_instance`. These squared norms should be a 1D + `tf.Tensor` of length `batch_size`. """ if hasattr(layer_instance, "sparse"): # for backwards compatibility if layer_instance.sparse: @@ -161,9 +168,8 @@ def embedding_layer_computation(layer_instance, inputs): ) input_ids = tf.cast(*inputs, tf.int32) base_vars = layer_instance.trainable_variables[0] - - def lookup_inputs(embeddings): - return tf.nn.embedding_lookup(embeddings, input_ids) + tape.watch(base_vars) + outputs = tf.nn.embedding_lookup(base_vars, input_ids) def sqr_norm_fn(base_vars_grads): # Get a 1D tensor of the row indices. @@ -213,7 +219,7 @@ def embedding_layer_computation(layer_instance, inputs): num_segments=nrows, ) # fill in empty inputs - return base_vars, lookup_inputs, sqr_norm_fn + return base_vars, outputs, sqr_norm_fn # ==============================================================================