From 4e1fc252e4c64132ad6fcd838e93f071f38dedd7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Mar 2023 10:34:20 -0800 Subject: [PATCH] Add a `kwargs` argument to the registry API + small changes to docstrings. This is a forward-looking change that is needed to support more complicated layers, such as `tf.keras.layers.MultiHeadAttention`, which can take `kwargs` as part of their `.call()` method and can generate arbitrary outputs. PiperOrigin-RevId: 514775503 --- .../fast_gradient_clipping/clip_grads.py | 2 +- .../fast_gradient_clipping/clip_grads_test.py | 15 ++-- .../gradient_clipping_utils.py | 6 +- .../fast_gradient_clipping/layer_registry.py | 72 ++++++++++--------- 4 files changed, 51 insertions(+), 44 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 32880af..71593ef 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -52,7 +52,7 @@ def get_registry_generator_fn( ) registry_fn = layer_registry.lookup(layer_instance) (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( - layer_instance, args, tape, num_microbatches + layer_instance, args, kwargs, tape, num_microbatches ) return layer_outputs, (layer_vars, layer_sqr_norm_fn) else: diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 5275b2a..2d1dad8 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -13,7 +13,7 @@ # limitations under the License. import itertools -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union from absl.testing import parameterized import tensorflow as tf @@ -50,16 +50,17 @@ class DoubleDense(tf.keras.layers.Layer): def double_dense_layer_computation( layer_instance: tf.keras.layers.Layer, - inputs: Any, + input_args: Tuple[Any, ...], + input_kwargs: Dict[Text, Any], tape: tf.GradientTape, num_microbatches: Optional[int], -): +) -> layer_registry.RegistryFunctionOutput: """Layer registry function for the custom `DoubleDense` layer class.""" vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( - layer_instance.dense1, inputs, tape, num_microbatches + layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches ) vars2, outputs, sqr_norm_fn2 = layer_registry.dense_layer_computation( - layer_instance.dense2, (outputs,), tape, num_microbatches + layer_instance.dense2, (outputs,), {}, tape, num_microbatches ) def sqr_norm_fn(base_vars): @@ -75,7 +76,7 @@ def compute_true_gradient_norms( x_batch: tf.Tensor, y_batch: tf.Tensor, num_microbatches: Optional[int], -): +) -> layer_registry.OutputTensor: """Computes the real gradient norms for an input `(model, x, y)`.""" loss_config = input_model.loss.get_config() loss_config['reduction'] = tf.keras.losses.Reduction.NONE @@ -113,7 +114,7 @@ def get_computed_and_true_norms( x_input: tf.Tensor, rng_seed: int = 777, registry: layer_registry.LayerRegistry = None, -): +) -> Tuple[tf.Tensor, tf.Tensor]: """Obtains the true and computed gradient norms for a model and batch input. Helpful testing wrapper function used to avoid code duplication. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index ec9d996..896fc3c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -120,7 +120,11 @@ def model_forward_pass( layer, args, kwargs ) generator_outputs_list.append(layer_generator_outputs) - args = (node_layer_outputs,) + args = ( + node_layer_outputs + if isinstance(node_layer_outputs, tuple) + else (node_layer_outputs,) + ) kwargs = {} # Update the current dictionary of inputs for the next node. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index eaa188d..69f499d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -56,17 +56,21 @@ import tensorflow as tf # ============================================================================== # Type aliases # ============================================================================== -SquareNormFunction = Callable[[Any], tf.Tensor] +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] -RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction] +OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]] + +BatchSize = Union[int, tf.Tensor] + +SquareNormFunction = Callable[[OutputTensor], tf.Tensor] + +RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction] RegistryFunction = Callable[ - [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput + [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape], + RegistryFunctionOutput, ] -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] -BatchSize = Union[int, tf.Tensor] - # ============================================================================== # Main class @@ -134,7 +138,8 @@ def add_microbatch_axis( # ============================================================================== def dense_layer_computation( layer_instance: tf.keras.layers.Dense, - inputs: Tuple[InputTensor], + input_args: Tuple[Any, ...], + input_kwargs: Dict[Text, Any], tape: tf.GradientTape, num_microbatches: Optional[tf.Tensor] = None, ) -> RegistryFunctionOutput: @@ -148,9 +153,12 @@ def dense_layer_computation( Args: layer_instance: A `tf.keras.layers.Dense` instance. - inputs: A tuple containing a single `InputTensor` which can be passed into - the layer instance, i.e., `layer_instance(*inputs)` returns a valid - output. + input_args: A `tuple` containing the first part of `layer_instance` input. + Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return + a valid output. + input_kwargs: A `tuple` containing the second part of `layer_instance` + input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should + return a valid output. tape: A `tf.GradientTape` instance that will be used to watch the output `base_vars`. num_microbatches: An optional numeric value or scalar `tf.Tensor` for @@ -169,11 +177,14 @@ def dense_layer_computation( trainable variables in `layer_instance`. These squared norms should be a 1D `tf.Tensor` of length `batch_size`. """ - if len(inputs) != 1: + if input_kwargs: + raise ValueError("Dense layer calls should not receive kwargs.") + del input_kwargs # Unused in dense layer calls. + if len(input_args) != 1: raise ValueError("Only layer inputs of length 1 are permitted.") orig_activation = layer_instance.activation layer_instance.activation = None - base_vars = layer_instance(*inputs) + base_vars = layer_instance(*input_args) tape.watch(base_vars) layer_instance.activation = orig_activation outputs = orig_activation(base_vars) if orig_activation else base_vars @@ -188,7 +199,7 @@ def dense_layer_computation( # Special handling for better efficiency return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x))) - inputs_gram = _compute_gramian(*inputs) + inputs_gram = _compute_gramian(*input_args) base_vars_grads_gram = _compute_gramian(base_vars_grads) if layer_instance.use_bias: # Adding a bias term is equivalent to a layer with no bias term and which @@ -206,7 +217,8 @@ def dense_layer_computation( def embedding_layer_computation( layer_instance: tf.keras.layers.Embedding, - inputs: Tuple[InputTensor], + input_args: Tuple[Any, ...], + input_kwargs: Dict[Text, Any], tape: tf.GradientTape, num_microbatches: Optional[tf.Tensor] = None, ) -> RegistryFunctionOutput: @@ -220,33 +232,23 @@ def embedding_layer_computation( Args: layer_instance: A `tf.keras.layers.Embedding` instance. - inputs: A tuple containing a single `InputTensor` which can be passed into - the layer instance, i.e., `layer_instance(*inputs)` returns a valid - output. - tape: A `tf.GradientTape` instance that will be used to watch the output - `base_vars`. - num_microbatches: An optional numeric value or scalar `tf.Tensor` for - indicating whether and how the losses are grouped into microbatches. If - not None, num_microbatches must divide the batch size. + input_args: See `dense_layer_computation()`. + input_kwargs: See `dense_layer_computation()`. + tape: See `dense_layer_computation()`. + num_microbatches: See `dense_layer_computation()`. Returns: - A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the - intermediate Tensor used in the chain-rule / "fast" clipping trick, - `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is - a function that takes one input, a `tf.Tensor` that represents the output - of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a - `tf.GradientTape` instance that records the dense layer computation and - `summed_loss` is the sum of the per-example losses of the underlying model. - This function then returns the per-example squared L2 gradient norms of the - trainable variables in `layer_instance`. These squared norms should be a 1D - `tf.Tensor` of length `batch_size`. + See `dense_layer_computation()`. """ - if len(inputs) != 1: + if input_kwargs: + raise ValueError("Embedding layer calls should not receive kwargs.") + del input_kwargs # Unused in embedding layer calls. + if len(input_args) != 1: raise ValueError("Only layer inputs of length 1 are permitted.") if hasattr(layer_instance, "sparse"): # for backwards compatibility if layer_instance.sparse: raise NotImplementedError("Sparse output tensors are not supported.") - if isinstance(inputs[0], tf.SparseTensor): + if isinstance(input_args[0], tf.SparseTensor): raise NotImplementedError("Sparse input tensors are not supported.") # Disable experimental features. @@ -256,7 +258,7 @@ def embedding_layer_computation( "The experimental embedding feature" "'_use_one_hot_matmul' is not supported." ) - input_ids = tf.cast(*inputs, tf.int32) + input_ids = tf.cast(*input_args, tf.int32) base_vars = layer_instance.trainable_variables[0] tape.watch(base_vars) outputs = tf.nn.embedding_lookup(base_vars, input_ids)