diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index ffa666f..e50781d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -6,6 +6,7 @@ py_library( name = "gradient_clipping_utils", srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", + deps = [":layer_registry"], ) py_library( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 6a37ae3..00e12a2 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,11 +21,18 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ +from typing import Union, Iterable, Text + import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr + +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] -def get_registry_generator_fn(tape, layer_registry): +def get_registry_generator_fn( + tape: tf.GradientTape, layer_registry: lr.LayerRegistry +): """Creates the generator function for `compute_gradient_norms()`.""" if layer_registry is None: # Needed for backwards compatibility. @@ -53,7 +60,12 @@ def get_registry_generator_fn(tape, layer_registry): return registry_generator_fn -def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): +def compute_gradient_norms( + input_model: tf.keras.Model, + x_batch: InputTensor, + y_batch: tf.Tensor, + layer_registry: lr.LayerRegistry, +): """Computes the per-example loss gradient norms for given data. Applies a variant of the approach given in @@ -62,7 +74,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): Args: input_model: The `tf.keras.Model` from which to obtain the layers from. The loss of the model *must* be a scalar loss. - x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. y_batch: A `tf.Tensor` representing a batch of output labels. The first axis must be the batch dimension. The number of examples should match the @@ -106,7 +118,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry): return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) -def compute_clip_weights(l2_norm_clip, gradient_norms): +def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): """Computes the per-example loss/clip weights for clipping. When the sum of the per-example losses is replaced a weighted sum, where @@ -132,7 +144,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms): def compute_pred_and_clipped_gradients( - input_model, x_batch, y_batch, l2_norm_clip, layer_registry + input_model: tf.keras.Model, + x_batch: InputTensor, + y_batch: tf.Tensor, + l2_norm_clip: float, + layer_registry: lr.LayerRegistry, ): """Computes the per-example predictions and per-example clipped loss gradient. @@ -147,7 +163,7 @@ def compute_pred_and_clipped_gradients( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. - x_batch: A `tf.Tensor` representing a batch of inputs to the model. The + x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. y_batch: A `tf.Tensor` representing a batch of output labels. The first axis must be the batch dimension. The number of examples should match the diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 183d890..1933e21 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -13,6 +13,8 @@ # limitations under the License. import itertools +from typing import Callable, Any, List, Union + from absl.testing import parameterized import tensorflow as tf @@ -20,23 +22,35 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +# ============================================================================== +# Type aliases +# ============================================================================== +LayerGenerator = Callable[[int, int], tf.keras.layers.Layer] + +ModelGenerator = Callable[ + [LayerGenerator, Union[int, List[int]], int], tf.keras.Model +] + + # ============================================================================== # Helper functions and classes. # ============================================================================== class DoubleDense(tf.keras.layers.Layer): """Generates two dense layers nested together.""" - def __init__(self, units): + def __init__(self, units: int): super().__init__() self.dense1 = tf.keras.layers.Dense(units) self.dense2 = tf.keras.layers.Dense(1) - def call(self, inputs): + def call(self, inputs: Any): x = self.dense1(inputs) return self.dense2(x) -def double_dense_layer_computation(layer_instance, inputs, tape): +def double_dense_layer_computation( + layer_instance: tf.keras.layers.Layer, inputs: Any, tape: tf.GradientTape +): """Layer registry function for the custom `DoubleDense` layer class.""" vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( layer_instance.dense1, inputs, tape @@ -53,7 +67,9 @@ def double_dense_layer_computation(layer_instance, inputs, tape): return [vars1, vars2], outputs, sqr_norm_fn -def compute_true_gradient_norms(input_model, x_batch, y_batch): +def compute_true_gradient_norms( + input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor +): """Computes the real gradient norms for an input `(model, x, y)`.""" loss_config = input_model.loss.get_config() loss_config['reduction'] = tf.keras.losses.Reduction.NONE @@ -73,14 +89,14 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch): def get_computed_and_true_norms( - model_generator, - layer_generator, - input_dims, - output_dim, - is_eager, - x_input, - rng_seed=777, - registry=None, + model_generator: ModelGenerator, + layer_generator: LayerGenerator, + input_dims: Union[int, List[int]], + output_dim: int, + is_eager: bool, + x_input: tf.Tensor, + rng_seed: int = 777, + registry: layer_registry.LayerRegistry = None, ): """Obtains the true and computed gradient norms for a model and batch input. @@ -238,7 +254,7 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim): # ============================================================================== # Factory functions. # ============================================================================== -def get_nd_test_tensors(n): +def get_nd_test_tensors(n: int): """Returns a list of candidate tests for a given dimension n.""" return [ tf.zeros((n,), dtype=tf.float64), @@ -246,7 +262,7 @@ def get_nd_test_tensors(n): ] -def get_nd_test_batches(n): +def get_nd_test_batches(n: int): """Returns a list of candidate input batches of dimension n.""" result = [] tensors = get_nd_test_tensors(n) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 6dd0d49..f6c22f0 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,11 +13,19 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" +from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional + from absl import logging import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr -def has_internal_compute_graph(input_object): +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] + +GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]] + + +def has_internal_compute_graph(input_object: Any): """Checks if input is a TF model and has a TF internal compute graph.""" return ( isinstance(input_object, tf.keras.Model) @@ -28,7 +36,9 @@ def has_internal_compute_graph(input_object): ) -def _get_internal_layers(input_layer): +def _get_internal_layers( + input_layer: tf.keras.layers.Layer, +) -> list[tf.keras.layers.Layer]: """Returns a list of layers that are nested within a given layer.""" internal_layers = [] if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'): @@ -39,7 +49,11 @@ def _get_internal_layers(input_layer): return internal_layers -def model_forward_pass(input_model, inputs, generator_fn=None): +def model_forward_pass( + input_model: tf.keras.Model, + inputs: InputTensor, + generator_fn: GeneratorFunction = None, +) -> Tuple[tf.Tensor, list[Any]]: """Does a forward pass of a model and returns useful intermediates. NOTE: the graph traversal algorithm is an adaptation of the logic in the @@ -118,7 +132,9 @@ def model_forward_pass(input_model, inputs, generator_fn=None): return node_layer_outputs, generator_outputs_list -def all_trainable_layers_are_registered(input_model, layer_registry): +def all_trainable_layers_are_registered( + input_model: tf.keras.Model, layer_registry: lr.LayerRegistry +) -> bool: """Check if an input model's trainable layers are all registered. Args: @@ -140,18 +156,21 @@ def all_trainable_layers_are_registered(input_model, layer_registry): def add_aggregate_noise( - input_model, x_batch, clipped_grads, l2_norm_clip, noise_multiplier -): + input_model: tf.keras.Model, + x_batch: InputTensor, + clipped_grads: list[tf.Tensor], + l2_norm_clip: float, + noise_multiplier: float, +) -> list[tf.Tensor]: """Adds noise to a collection of clipped gradients. The magnitude of the noise depends on the aggregation strategy of the input model's loss function. Args: - input_model: The Keras model to obtain the layers from. - x_batch: A collection of Tensors to be fed into the input layer of the - model. - clipped_grads: A list of tensors representing the clipped gradients. + input_model: The `tf.keras.Model` to obtain the layers from. + x_batch: An `InputTensor` to be fed into the input layer of the model. + clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. l2_norm_clip: Clipping norm (max L2 norm of each gradient). noise_multiplier: Ratio of the standard deviation to the clipping norm. @@ -187,7 +206,9 @@ def add_aggregate_noise( return tf.nest.map_structure(add_noise, clipped_grads) -def generate_model_outputs_using_core_keras_layers(input_model): +def generate_model_outputs_using_core_keras_layers( + input_model: tf.keras.Model, +) -> tf.Tensor: """Returns the model outputs generated by only core Keras layers.""" cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects()) cust_hash_set = set([hash(v) for v in cust_obj_dict.values()]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index c8279ba..60b0d6b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -40,9 +40,24 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`. Details of this decomposition can be found in https://arxiv.org/abs/1510.01799 """ +from typing import Callable, Type, Any, Union, Iterable, Text import tensorflow as tf +# ============================================================================== +# Type aliases +# ============================================================================== +SquareNormFunction = Callable[[Any], tf.Tensor] + +RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction] + +RegistryFunction = Callable[ + [Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput +] + +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] + + # ============================================================================== # Main class # ============================================================================== @@ -54,15 +69,19 @@ class LayerRegistry: self._layer_class_dict = {} self._registry = {} - def is_elem(self, layer_instance): + def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool: """Checks if a layer instance's class is in the registry.""" return hash(layer_instance.__class__) in self._registry - def lookup(self, layer_instance): + def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction: """Returns the layer registry function for a given layer instance.""" return self._registry[hash(layer_instance.__class__)] - def insert(self, layer_class, layer_registry_function): + def insert( + self, + layer_class: Type[tf.keras.layers.Layer], + layer_registry_function: RegistryFunction, + ): """Inserts a layer registry function into the internal dictionaries.""" layer_key = hash(layer_class) self._layer_class_dict[layer_key] = layer_class @@ -72,7 +91,11 @@ class LayerRegistry: # ============================================================================== # Supported Keras layers # ============================================================================== -def dense_layer_computation(layer_instance, inputs, tape): +def dense_layer_computation( + layer_instance: tf.keras.layers.Dense, + inputs: tuple[InputTensor], + tape: tf.GradientTape, +) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Dense`. The logic for this computation is based on the following paper: @@ -83,8 +106,9 @@ def dense_layer_computation(layer_instance, inputs, tape): Args: layer_instance: A `tf.keras.layers.Dense` instance. - inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., - `layer_instance(inputs)` returns a valid output. + inputs: A tuple containing a single `InputTensor` which can be passed into + the layer instance, i.e., `layer_instance(*inputs)` returns a valid + output. tape: A `tf.GradientTape` instance that will be used to watch the output `base_vars`. @@ -100,6 +124,8 @@ def dense_layer_computation(layer_instance, inputs, tape): trainable variables in `layer_instance`. These squared norms should be a 1D `tf.Tensor` of length `batch_size`. """ + if len(inputs) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") orig_activation = layer_instance.activation layer_instance.activation = None base_vars = layer_instance(*inputs) @@ -125,7 +151,11 @@ def dense_layer_computation(layer_instance, inputs, tape): return base_vars, outputs, sqr_norm_fn -def embedding_layer_computation(layer_instance, inputs, tape): +def embedding_layer_computation( + layer_instance: tf.keras.layers.Embedding, + inputs: tuple[InputTensor], + tape: tf.GradientTape, +) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Embedding`. The logic of this computation is based on the `tf.keras.layers.Dense` @@ -136,8 +166,9 @@ def embedding_layer_computation(layer_instance, inputs, tape): Args: layer_instance: A `tf.keras.layers.Embedding` instance. - inputs: A `tf.Tensor` which can be passed into the layer instance, i.e., - `layer_instance(inputs)` returns a valid output. + inputs: A tuple containing a single `InputTensor` which can be passed into + the layer instance, i.e., `layer_instance(*inputs)` returns a valid + output. tape: A `tf.GradientTape` instance that will be used to watch the output `base_vars`. @@ -153,10 +184,12 @@ def embedding_layer_computation(layer_instance, inputs, tape): trainable variables in `layer_instance`. These squared norms should be a 1D `tf.Tensor` of length `batch_size`. """ + if len(inputs) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") if hasattr(layer_instance, "sparse"): # for backwards compatibility if layer_instance.sparse: raise NotImplementedError("Sparse output tensors are not supported.") - if isinstance(inputs, tf.SparseTensor): + if isinstance(inputs[0], tf.SparseTensor): raise NotImplementedError("Sparse input tensors are not supported.") # Disable experimental features. @@ -225,7 +258,7 @@ def embedding_layer_computation(layer_instance, inputs, tape): # ============================================================================== # Main factory methods # ============================================================================== -def make_default_layer_registry(): +def make_default_layer_registry() -> LayerRegistry: registry = LayerRegistry() registry.insert(tf.keras.layers.Dense, dense_layer_computation) registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)