From cbf34f2b04d9d9761bba30469c819709174e43ff Mon Sep 17 00:00:00 2001 From: Walid Krichene Date: Thu, 2 Mar 2023 14:28:44 -0800 Subject: [PATCH] Update type annotations of gradient clipping library. PiperOrigin-RevId: 513640655 --- .../privacy/fast_gradient_clipping/clip_grads.py | 4 ++-- .../gradient_clipping_utils.py | 14 +++++++------- .../fast_gradient_clipping/layer_registry.py | 12 ++++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 00e12a2..4af6695 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,13 +21,13 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). """ -from typing import Union, Iterable, Text +from typing import Dict, Iterable, Text, Union import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] def get_registry_generator_fn( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index f6c22f0..428dc0f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,16 +13,16 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" -from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union from absl import logging import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] -GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]] +GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]] def has_internal_compute_graph(input_object: Any): @@ -38,7 +38,7 @@ def has_internal_compute_graph(input_object: Any): def _get_internal_layers( input_layer: tf.keras.layers.Layer, -) -> list[tf.keras.layers.Layer]: +) -> List[tf.keras.layers.Layer]: """Returns a list of layers that are nested within a given layer.""" internal_layers = [] if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'): @@ -53,7 +53,7 @@ def model_forward_pass( input_model: tf.keras.Model, inputs: InputTensor, generator_fn: GeneratorFunction = None, -) -> Tuple[tf.Tensor, list[Any]]: +) -> Tuple[tf.Tensor, List[Any]]: """Does a forward pass of a model and returns useful intermediates. NOTE: the graph traversal algorithm is an adaptation of the logic in the @@ -158,10 +158,10 @@ def all_trainable_layers_are_registered( def add_aggregate_noise( input_model: tf.keras.Model, x_batch: InputTensor, - clipped_grads: list[tf.Tensor], + clipped_grads: List[tf.Tensor], l2_norm_clip: float, noise_multiplier: float, -) -> list[tf.Tensor]: +) -> List[tf.Tensor]: """Adds noise to a collection of clipped gradients. The magnitude of the noise depends on the aggregation strategy of the diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index 60b0d6b..556fcc4 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -40,7 +40,7 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`. Details of this decomposition can be found in https://arxiv.org/abs/1510.01799 """ -from typing import Callable, Type, Any, Union, Iterable, Text +from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union import tensorflow as tf @@ -49,13 +49,13 @@ import tensorflow as tf # ============================================================================== SquareNormFunction = Callable[[Any], tf.Tensor] -RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction] +RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction] RegistryFunction = Callable[ - [Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput + [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput ] -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]] +InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] # ============================================================================== @@ -93,7 +93,7 @@ class LayerRegistry: # ============================================================================== def dense_layer_computation( layer_instance: tf.keras.layers.Dense, - inputs: tuple[InputTensor], + inputs: Tuple[InputTensor], tape: tf.GradientTape, ) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Dense`. @@ -153,7 +153,7 @@ def dense_layer_computation( def embedding_layer_computation( layer_instance: tf.keras.layers.Embedding, - inputs: tuple[InputTensor], + inputs: Tuple[InputTensor], tape: tf.GradientTape, ) -> RegistryFunctionOutput: """Registry function for `tf.keras.layers.Embedding`.