Update type annotations of gradient clipping library.
PiperOrigin-RevId: 513640655
parent 7436930c64
commit cbf34f2b04
3 changed files with 15 additions and 15 deletions
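In short, every hunk below swaps a built-in generic (`dict[...]`, `list[...]`, `tuple[...]`) for its `typing` counterpart (`Dict`, `List`, `Tuple`) and sorts the `typing` imports alphabetically. A minimal sketch of the pattern, using the `InputTensor` alias that appears in the diff; the compatibility rationale is an inference on my part, the commit message does not state one:

from typing import Dict, Iterable, Text, Union

import tensorflow as tf

# Built-in generics such as `dict[Text, tf.Tensor]` are only subscriptable on
# Python 3.9+ (PEP 585); on 3.8 that annotation raises
# `TypeError: 'type' object is not subscriptable` at import time.
# The `typing` aliases below evaluate fine on older interpreters as well.
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]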
@@ -21,13 +21,13 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
 
-from typing import Union, Iterable, Text
+from typing import Dict, Iterable, Text, Union
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
 
 
 def get_registry_generator_fn(
@@ -13,16 +13,16 @@
 # limitations under the License.
 """Utility functions that help in the computation of per-example gradient norms."""
 
-from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 from absl import logging
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
 
-GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
+GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
 
 
 def has_internal_compute_graph(input_object: Any):
@@ -38,7 +38,7 @@ def has_internal_compute_graph(input_object: Any):
 
 def _get_internal_layers(
     input_layer: tf.keras.layers.Layer,
-) -> list[tf.keras.layers.Layer]:
+) -> List[tf.keras.layers.Layer]:
   """Returns a list of layers that are nested within a given layer."""
   internal_layers = []
   if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
@@ -53,7 +53,7 @@ def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: InputTensor,
     generator_fn: GeneratorFunction = None,
-) -> Tuple[tf.Tensor, list[Any]]:
+) -> Tuple[tf.Tensor, List[Any]]:
   """Does a forward pass of a model and returns useful intermediates.
 
   NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -158,10 +158,10 @@ def all_trainable_layers_are_registered(
 def add_aggregate_noise(
     input_model: tf.keras.Model,
     x_batch: InputTensor,
-    clipped_grads: list[tf.Tensor],
+    clipped_grads: List[tf.Tensor],
     l2_norm_clip: float,
     noise_multiplier: float,
-) -> list[tf.Tensor]:
+) -> List[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
   The magnitude of the noise depends on the aggregation strategy of the
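For orientation, the `GeneratorFunction` alias touched above describes an optional callback of shape `Callable[[Any, Tuple, Dict], Tuple[Any, Any]]`. Below is a hypothetical callback matching that shape, assuming (as the surrounding code suggests, though the diff does not spell it out) that the three arguments are a layer, its call args, and its call kwargs:

from typing import Any, Dict, Tuple


def passthrough_generator_fn(
    layer_instance: Any, args: Tuple, kwargs: Dict
) -> Tuple[Any, Any]:
  """Hypothetical GeneratorFunction: runs the layer, records nothing extra."""
  outputs = layer_instance(*args, **kwargs)
  return outputs, None  # (layer outputs, per-layer side output)

A callback of this shape would be passed as the `generator_fn` argument of `model_forward_pass`, whose return annotation changes from `Tuple[tf.Tensor, list[Any]]` to `Tuple[tf.Tensor, List[Any]]` in the same file.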
@@ -40,7 +40,7 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
 Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
 """
 
-from typing import Callable, Type, Any, Union, Iterable, Text
+from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union
 import tensorflow as tf
 
 
@@ -49,13 +49,13 @@ import tensorflow as tf
 # ==============================================================================
 SquareNormFunction = Callable[[Any], tf.Tensor]
 
-RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
+RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
 
 RegistryFunction = Callable[
-    [Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
+    [Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
 ]
 
-InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
+InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
 
 
 # ==============================================================================
@@ -93,7 +93,7 @@ class LayerRegistry:
 # ==============================================================================
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    inputs: tuple[InputTensor],
+    inputs: Tuple[InputTensor],
     tape: tf.GradientTape,
 ) -> RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.
@@ -153,7 +153,7 @@ def dense_layer_computation(
 
 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    inputs: tuple[InputTensor],
+    inputs: Tuple[InputTensor],
     tape: tf.GradientTape,
 ) -> RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.
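The aliases in the last file compose as `SquareNormFunction = Callable[[Any], tf.Tensor]` and `RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]`. The sketch below shows a custom registry function that type-checks against the updated aliases; the toy layer it handles and the norm computation are invented for illustration and are not part of this commit:

from typing import Any, Callable, Tuple

import tensorflow as tf

SquareNormFunction = Callable[[Any], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]


def toy_layer_computation(
    layer_instance: tf.keras.layers.Layer,
    inputs: Tuple[Any],
    tape: tf.GradientTape,
) -> RegistryFunctionOutput:
  """Hypothetical registry function matching the updated annotations."""
  base_vars = layer_instance(*inputs)  # intermediate tensor to differentiate through
  tape.watch(base_vars)
  outputs = base_vars  # this toy layer applies no further transformation

  def sqr_norm_fn(grads: Any) -> tf.Tensor:
    # Per-example squared L2 norm of the incoming gradients (toy computation).
    flat = tf.reshape(grads, [tf.shape(grads)[0], -1])
    return tf.reduce_sum(tf.square(flat), axis=1)

  return base_vars, outputs, sqr_norm_fn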