Update type annotations of gradient clipping library.
PiperOrigin-RevId: 513640655
This commit is contained in:
parent
7436930c64
commit
cbf34f2b04
3 changed files with 15 additions and 15 deletions
|
@ -21,13 +21,13 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
||||||
`compute_gradient_norms()` function).
|
`compute_gradient_norms()` function).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Union, Iterable, Text
|
from typing import Dict, Iterable, Text, Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
|
||||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
def get_registry_generator_fn(
|
def get_registry_generator_fn(
|
||||||
|
|
|
@ -13,16 +13,16 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||||
|
|
||||||
from typing import Any, Union, Iterable, Text, Callable, Tuple, Optional
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
|
||||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
GeneratorFunction = Optional[Callable[[Any, Tuple, dict], Tuple[Any, Any]]]
|
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
|
||||||
|
|
||||||
|
|
||||||
def has_internal_compute_graph(input_object: Any):
|
def has_internal_compute_graph(input_object: Any):
|
||||||
|
@ -38,7 +38,7 @@ def has_internal_compute_graph(input_object: Any):
|
||||||
|
|
||||||
def _get_internal_layers(
|
def _get_internal_layers(
|
||||||
input_layer: tf.keras.layers.Layer,
|
input_layer: tf.keras.layers.Layer,
|
||||||
) -> list[tf.keras.layers.Layer]:
|
) -> List[tf.keras.layers.Layer]:
|
||||||
"""Returns a list of layers that are nested within a given layer."""
|
"""Returns a list of layers that are nested within a given layer."""
|
||||||
internal_layers = []
|
internal_layers = []
|
||||||
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||||
|
@ -53,7 +53,7 @@ def model_forward_pass(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
inputs: InputTensor,
|
inputs: InputTensor,
|
||||||
generator_fn: GeneratorFunction = None,
|
generator_fn: GeneratorFunction = None,
|
||||||
) -> Tuple[tf.Tensor, list[Any]]:
|
) -> Tuple[tf.Tensor, List[Any]]:
|
||||||
"""Does a forward pass of a model and returns useful intermediates.
|
"""Does a forward pass of a model and returns useful intermediates.
|
||||||
|
|
||||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||||
|
@ -158,10 +158,10 @@ def all_trainable_layers_are_registered(
|
||||||
def add_aggregate_noise(
|
def add_aggregate_noise(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
x_batch: InputTensor,
|
x_batch: InputTensor,
|
||||||
clipped_grads: list[tf.Tensor],
|
clipped_grads: List[tf.Tensor],
|
||||||
l2_norm_clip: float,
|
l2_norm_clip: float,
|
||||||
noise_multiplier: float,
|
noise_multiplier: float,
|
||||||
) -> list[tf.Tensor]:
|
) -> List[tf.Tensor]:
|
||||||
"""Adds noise to a collection of clipped gradients.
|
"""Adds noise to a collection of clipped gradients.
|
||||||
|
|
||||||
The magnitude of the noise depends on the aggregation strategy of the
|
The magnitude of the noise depends on the aggregation strategy of the
|
||||||
|
|
|
@ -40,7 +40,7 @@ where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
|
||||||
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, Type, Any, Union, Iterable, Text
|
from typing import Any, Callable, Dict, Iterable, Text, Tuple, Type, Union
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,13 +49,13 @@ import tensorflow as tf
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
SquareNormFunction = Callable[[Any], tf.Tensor]
|
SquareNormFunction = Callable[[Any], tf.Tensor]
|
||||||
|
|
||||||
RegistryFunctionOutput = tuple[Any, tf.Tensor, SquareNormFunction]
|
RegistryFunctionOutput = Tuple[Any, tf.Tensor, SquareNormFunction]
|
||||||
|
|
||||||
RegistryFunction = Callable[
|
RegistryFunction = Callable[
|
||||||
[Any, tuple[Any], tf.GradientTape], RegistryFunctionOutput
|
[Any, Tuple[Any], tf.GradientTape], RegistryFunctionOutput
|
||||||
]
|
]
|
||||||
|
|
||||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]]
|
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
@ -93,7 +93,7 @@ class LayerRegistry:
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def dense_layer_computation(
|
def dense_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Dense,
|
layer_instance: tf.keras.layers.Dense,
|
||||||
inputs: tuple[InputTensor],
|
inputs: Tuple[InputTensor],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
) -> RegistryFunctionOutput:
|
) -> RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Dense`.
|
"""Registry function for `tf.keras.layers.Dense`.
|
||||||
|
@ -153,7 +153,7 @@ def dense_layer_computation(
|
||||||
|
|
||||||
def embedding_layer_computation(
|
def embedding_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Embedding,
|
layer_instance: tf.keras.layers.Embedding,
|
||||||
inputs: tuple[InputTensor],
|
inputs: Tuple[InputTensor],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
) -> RegistryFunctionOutput:
|
) -> RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Embedding`.
|
"""Registry function for `tf.keras.layers.Embedding`.
|
||||||
|
|
Loading…
Reference in a new issue