Update type annotations of gradient clipping library.

PiperOrigin-RevId: 513640655
This commit is contained in:
Walid Krichene 2023-03-02 14:28:44 -08:00 committed by A. Unique TensorFlower
parent 7436930c64
commit cbf34f2b04
3 changed files with 15 additions and 15 deletions

View file

@ -21,13 +21,13 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Union, Iterable, Text
from typing import Dict, Iterable, Text, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
def get_registry_generator_fn(

View file

@ -13,16 +13,16 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""
from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
def has_internal_compute_graph(input_object: Any):
@ -38,7 +38,7 @@ def has_internal_compute_graph(input_object: Any):
def _get_internal_layers(
input_layer: tf.keras.layers.Layer,
) -> list[tf.keras.layers.Layer]:
) -> List[tf.keras.layers.Layer]:
"""Returns a list of layers that are nested within a given layer."""
internal_layers = []
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
@ -53,7 +53,7 @@ def model_forward_pass(
input_model: tf.keras.Model,
inputs: InputTensor,
generator_fn: GeneratorFunction = None,
) -> Tuple[tf.Tensor, list[Any]]:
) -> Tuple[tf.Tensor, List[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
@ -158,10 +158,10 @@ def all_trainable_layers_are_registered(
def add_aggregate_noise(
input_model: tf.keras.Model,
x_batch: InputTensor,
clipped_grads: list[tf.Tensor],
clipped_grads: List[tf.Tensor],
l2_norm_clip: float,
noise_multiplier: float,
) -> list[tf.Tensor]:
) -> List[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the

View file

@ -40,7 +40,7 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
"""
from typing import Callable, Type, Any, Union, Iterable, Text
from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union
import tensorflow as tf
@ -49,13 +49,13 @@ import tensorflow as tf
# ==============================================================================
SquareNormFunction = Callable[[Any], tf.Tensor]
RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
RegistryFunction = Callable[
[Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
[Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
# ==============================================================================
@ -93,7 +93,7 @@ class LayerRegistry:
# ==============================================================================
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
inputs: tuple[InputTensor],
inputs: Tuple[InputTensor],
tape: tf.GradientTape,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.
@ -153,7 +153,7 @@ def dense_layer_computation(
def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
inputs: tuple[InputTensor],
inputs: Tuple[InputTensor],
tape: tf.GradientTape,
) -> RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.