Improve the readability of the fast gradient clipping library.

PiperOrigin-RevId: 566961891
A. Unique TensorFlower 2023-09-20 07:47:12 -07:00
parent 97aaf302eb
commit c037070e50
18 changed files with 84 additions and 111 deletions
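Almost every hunk below makes the same two readability moves: container annotations move from `typing` (`List`, `Dict`, `Tuple`, `Text`) to the abstract types in `collections.abc` (`Sequence`, `Mapping`, `Set`) plus plain `str` and the builtin generics (`tuple[...]`), and the TPU test variants stop re-deriving a local `using_tpu` flag from the strategy. A minimal sketch of the annotation pattern; the function and names here are hypothetical, not part of the library:

from collections.abc import Mapping, Sequence
from typing import Optional

import tensorflow as tf


# Before: def scale(norms: List[tf.Tensor], weights: Dict[str, float]) -> Tuple[tf.Tensor, ...]
# After: abstract argument types accept any sequence/mapping (list, tuple, dict, ...),
# and the builtin `tuple[...]` is used for the concrete return value.
def scale(
    norms: Sequence[tf.Tensor],
    weights: Mapping[str, float],
    clip: Optional[float] = None,
) -> tuple[tf.Tensor, ...]:
  scaled = []
  for i, norm in enumerate(norms):
    factor = weights.get(str(i), 1.0)
    if clip is not None:
      factor = min(factor, clip)
    scaled.append(norm * factor)
  return tuple(scaled)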


@ -21,7 +21,8 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import List, Optional, Tuple
from collections.abc import Sequence
from typing import Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
@ -74,7 +75,7 @@ def compute_gradient_norms(
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
trainable_vars: Optional[Sequence[tf.Variable]] = None,
):
"""Computes the per-example loss gradient norms for given data.
@ -219,7 +220,7 @@ def compute_clipped_gradients_and_outputs(
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main


@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Text, Tuple
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from absl.testing import parameterized
import tensorflow as tf
@ -22,9 +23,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
# Helper functions and classes.
# ==============================================================================
class DoubleDense(tf.keras.layers.Layer):
"""Generates two dense layers nested together."""
@ -40,8 +38,8 @@ class DoubleDense(tf.keras.layers.Layer):
def double_dense_layer_computation(
layer_instance: tf.keras.layers.Layer,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[int],
) -> type_aliases.RegistryFunctionOutput:
@ -61,9 +59,6 @@ def double_dense_layer_computation(
return [vars1, vars2], outputs, sqr_norm_fn
# ==============================================================================
# Main tests.
# ==============================================================================
class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(


@ -13,7 +13,8 @@
# limitations under the License.
"""A collection of common utility functions for unit testing."""
from typing import Callable, List, Optional, Tuple
from collections.abc import Callable, MutableSequence, Sequence
from typing import Optional
import numpy as np
import tensorflow as tf
@ -108,8 +109,8 @@ def compute_true_gradient_norms(
def get_model_from_generator(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
is_eager: bool,
) -> tf.keras.Model:
"""Creates a simple model from input specifications."""
@ -171,8 +172,8 @@ def get_computed_and_true_norms_from_model(
def get_computed_and_true_norms(
model_generator: type_aliases.ModelGenerator,
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
num_microbatches: Optional[int],
is_eager: bool,
@ -181,7 +182,7 @@ def get_computed_and_true_norms(
rng_seed: int = 777,
registry: layer_registry.LayerRegistry = None,
partial: bool = False,
) -> Tuple[tf.Tensor, tf.Tensor]:
) -> tuple[tf.Tensor, tf.Tensor]:
"""Obtains the true and computed gradient norms for a model and batch input.
Helpful testing wrapper function used to avoid code duplication.
@ -245,8 +246,8 @@ def reshape_and_sum(tensor: tf.Tensor) -> tf.Tensor:
# ==============================================================================
def make_one_layer_functional_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
) -> tf.keras.Model:
"""Creates a 1-layer sequential model."""
inputs = tf.keras.Input(shape=input_dims)
@ -258,8 +259,8 @@ def make_one_layer_functional_model(
def make_two_layer_functional_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
) -> tf.keras.Model:
"""Creates a 2-layer sequential model."""
inputs = tf.keras.Input(shape=input_dims)
@ -272,8 +273,8 @@ def make_two_layer_functional_model(
def make_two_tower_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
) -> tf.keras.Model:
"""Creates a 2-layer 2-input functional model."""
inputs1 = tf.keras.Input(shape=input_dims)
@ -290,8 +291,8 @@ def make_two_tower_model(
def make_bow_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
) -> tf.keras.Model:
"""Creates a simple embedding bow model."""
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@ -315,8 +316,8 @@ def make_bow_model(
def make_dense_bow_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: Sequence[int],
output_dims: Sequence[int],
) -> tf.keras.Model:
"""Creates an embedding bow model with a `Dense` layer."""
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@ -341,8 +342,8 @@ def make_dense_bow_model(
def make_weighted_bow_model(
layer_generator: type_aliases.LayerGenerator,
input_dims: List[int],
output_dims: List[int],
input_dims: MutableSequence[int],
output_dims: MutableSequence[int],
) -> tf.keras.Model:
"""Creates a weighted embedding bow model."""
# NOTE: This model only accepts dense input tensors.
@ -353,10 +354,9 @@ def make_weighted_bow_model(
emb_layer = layer_generator(input_dims, output_dims)
if len(output_dims) != 1:
raise ValueError('Expected `output_dims` to be of size 1.')
output_dim = output_dims[0]
feature_embs = emb_layer(inputs)
# Use deterministic weights to avoid seeding issues on TPUs.
feature_shape = input_dims + [output_dim]
feature_shape = input_dims + output_dims
feature_weights = tf.expand_dims(
tf.reshape(
tf.range(np.product(feature_shape), dtype=tf.float32),
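In `make_weighted_bow_model` above, the intermediate `output_dim` variable is dropped and the weight shape is built by concatenating the two dimension lists directly, which is presumably why these two parameters are typed `MutableSequence[int]` rather than `Sequence[int]`. A standalone sketch of the deterministic-weight construction, with made-up dimensions and `np.prod` in place of the older `np.product` spelling:

import numpy as np
import tensorflow as tf

input_dims = [4]   # hypothetical sequence length
output_dims = [3]  # hypothetical embedding width

# List concatenation yields the full weight shape, e.g. [4, 3].
feature_shape = input_dims + output_dims
# Deterministic weights (a reshaped `tf.range`) avoid RNG seeding issues on TPUs.
feature_weights = tf.reshape(
    tf.range(np.prod(feature_shape), dtype=tf.float32),
    feature_shape,
)
print(feature_weights.shape)  # (4, 3)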


@ -13,7 +13,8 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""
from typing import Any, List, Optional, Set, Tuple
from collections.abc import Sequence, Set
from typing import Any, Optional
from absl import logging
import tensorflow as tf
@ -36,7 +37,7 @@ def model_forward_pass(
input_model: tf.keras.Model,
inputs: type_aliases.PackedTensors,
generator_fn: type_aliases.GeneratorFunction = None,
) -> Tuple[type_aliases.PackedTensors, List[Any]]:
) -> tuple[type_aliases.PackedTensors, Sequence[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
@ -149,7 +150,7 @@ def add_aggregate_noise(
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
) -> List[tf.Tensor]:
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the


@ -16,7 +16,6 @@ from typing import Any
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@ -90,7 +89,7 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
num_dims = 3
num_inputs = 1 if input_packing_type is None else 2
num_outputs = 1 if output_packing_type is None else 2
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
temp_sum = tf.stack(sample_inputs, axis=0)
sample_outputs = [
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)


@ -13,7 +13,8 @@
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Dense`."""
from typing import Any, Mapping, Tuple, Union
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
input_args: Tuple[Any, ...],
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Union[tf.Tensor, None] = None,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.


@ -19,11 +19,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_dense_layer_generators():
def sigmoid_dense_layer(units):
return tf.keras.layers.Dense(units, activation='sigmoid')
@ -50,14 +46,12 @@ def get_dense_layer_registries():
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
@parameterized.product(
model_name=list(get_dense_model_generators().keys()),
@ -112,8 +106,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
)
# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
if self.using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion.
@ -127,8 +120,8 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# which is a reasonable level of error for computing gradient norms.
# Other trials also give an absolute (resp. relative) error of around
# 0.05 (resp. 0.0015).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2
rtol = 1e-2 if self.using_tpu else 1e-3
atol = 1e-1 if self.using_tpu else 1e-2
for x_batch, weight_batch in zip(x_batches, weight_batches):
batch_size = x_batch.shape[0]
@ -139,7 +132,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
test_op, args=(x_batch, weight_batch)
)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
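The `self.using_tpu` flag introduced in `setUp` drives everything TPU-specific in these tests: `tf.function`-compiling the op, loosening tolerances, and unwrapping per-replica results after `strategy.run`. A self-contained sketch of that run-and-unwrap pattern outside the test harness; the toy op and tolerances are placeholders:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # A TPUStrategy in the TPU test variants.
using_tpu = isinstance(strategy, tf.distribute.TPUStrategy)


def test_op(x_batch):
  # Placeholder computation standing in for computed and true gradient norms.
  return tf.norm(x_batch, axis=1), tf.norm(x_batch, axis=1)


if using_tpu:
  # TPUs can only run `tf.function`-decorated functions.
  test_op = tf.function(test_op, jit_compile=True, autograph=False)

x_batch = tf.reshape(tf.range(6.0), [2, 3])
computed, expected = strategy.run(test_op, args=(x_batch,))

if using_tpu:
  # On TPUs, `strategy.run` returns PerReplica values; take one replica's copy.
  computed, expected = computed.values[0], expected.values[0]

# TPUs use lower precision than CPUs, so the tests above relax their criteria.
rtol = 1e-2 if using_tpu else 1e-3
tf.debugging.assert_near(computed, expected, rtol=rtol)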


@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(dense_test.GradNormTest):
def setUp(self):
super().setUp()
super(dense_test.GradNormTest, self).setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':
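The TPU variant keeps all of `GradNormTest`'s test methods and only swaps the strategy: the two-argument `super()` call skips the parent's `setUp` (which would pin the default strategy) and goes straight to `tf.test.TestCase.setUp`, after which the TPU strategy and `using_tpu = True` are installed. A stripped-down sketch of that inheritance trick with stand-in classes, assuming no TPU hardware:

import tensorflow as tf


class GradNormTest(tf.test.TestCase):
  """Stand-in for the default-strategy (CPU/GPU) test class."""

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()
    self.using_tpu = False

  def test_something(self):
    # The real tests branch on `self.using_tpu` for compilation and tolerances.
    self.assertIsNotNone(self.strategy)


class GradNormTpuTest(GradNormTest):
  """Inherits every test_* method but runs under a TPU strategy."""

  def setUp(self):
    # Two-argument super() skips GradNormTest.setUp (which would install the
    # default strategy) and calls tf.test.TestCase.setUp directly.
    super(GradNormTest, self).setUp()
    # The real suite builds a TPUStrategy here via a helper such as
    # `create_tpu_strategy()`; that needs TPU hardware, so this sketch reuses
    # the default strategy as a placeholder.
    self.strategy = tf.distribute.get_strategy()
    self.using_tpu = True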


@ -13,7 +13,8 @@
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Embedding`."""
from typing import Any, Mapping, Tuple, Union
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
input_args: Tuple[Any, ...],
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Union[tf.Tensor, None] = None,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.


@ -20,9 +20,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_embedding_model_generators():
return {
'bow1': common_test_utils.make_bow_model,
@ -59,14 +56,12 @@ def get_embedding_layer_registries():
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@ -101,13 +96,12 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# The following are invalid test combinations and, hence, are skipped.
batch_size = embed_indices.shape[0]
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if (
(num_microbatches is not None and batch_size % num_microbatches != 0)
or (model_name == 'weighted_bow1' and is_ragged)
or (
# Current clipping ops do not have corresponding TPU kernels.
using_tpu
self.using_tpu
and is_ragged
)
):
@ -139,7 +133,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
)
# TPUs can only run `tf.function`-decorated functions.
if using_tpu:
if self.using_tpu:
test_op = tf.function(test_op, autograph=False)
# Set up the device ops and run the test.
@ -147,7 +141,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
test_op, args=(embed_indices,)
)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]


@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(embedding_test.GradNormTest):
def setUp(self):
super().setUp()
super(embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':


@ -13,21 +13,19 @@
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
from typing import Any, Mapping, Tuple, Union
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
# ==============================================================================
# Supported Keras layers
# ==============================================================================
def layer_normalization_computation(
layer_instance: tf.keras.layers.LayerNormalization,
input_args: Tuple[Any, ...],
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Union[tf.Tensor, None] = None,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.LayerNormalization`.


@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_layer_norm_layer_generators():
return {
'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
@ -73,14 +70,12 @@ def get_layer_norm_registries():
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
@parameterized.product(
model_name=list(get_layer_norm_model_generators().keys()),
@ -130,14 +125,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
)
# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
if self.using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)
# TPUs use lower precision than CPUs, so we relax our criterion (see
# `dense_test.py` for additional discussions).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2
rtol = 1e-2 if self.using_tpu else 1e-3
atol = 1e-1 if self.using_tpu else 1e-2
# Each batched input is a reshape of a `tf.range()` call.
batch_size = 2
@ -148,7 +142,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]


@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(layer_normalization_test.GradNormTest):
def setUp(self):
super().setUp()
super(layer_normalization_test.GradNormTest, self).setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':


@ -13,7 +13,8 @@
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
from typing import Any, Dict, Optional, Tuple
from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@ -21,8 +22,8 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
def nlp_on_device_embedding_layer_computation(
layer_instance: tf.keras.layers.Layer,
input_args: Tuple[Any, ...],
input_kwargs: Dict[str, Any],
input_args: Sequence[Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:


@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
# ==============================================================================
# Helper functions.
# ==============================================================================
def get_nlp_on_device_embedding_model_generators():
return {
'bow1': common_test_utils.make_bow_model,
@ -49,30 +46,26 @@ def get_nlp_on_device_embedding_layer_registries():
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
dbl_registry.insert(
tfm.nlp.layers.OnDeviceEmbedding,
nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
)
return {
'embed_and_dense': dbl_registry,
}
# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()
self.using_tpu = False
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
# supports them for embeddings.
@parameterized.product(
input_data=get_nlp_on_device_embedding_inputs(),
scale_factor=[None, 0.5, 1.0],
model_name=list(
get_nlp_on_device_embedding_model_generators().keys()
),
model_name=list(get_nlp_on_device_embedding_model_generators().keys()),
output_dim=[2],
layer_registry_name=list(
get_nlp_on_device_embedding_layer_registries().keys()
@ -97,7 +90,6 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
# The following are invalid test combinations and, hence, are skipped.
batch_size = embed_indices.shape[0]
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if num_microbatches is not None and batch_size % num_microbatches != 0:
return
@ -106,9 +98,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
def embed_layer_generator(_, output_dims):
return tfm.nlp.layers.OnDeviceEmbedding(
10,
*output_dims,
scale_factor=scale_factor
10, *output_dims, scale_factor=scale_factor
)
model = common_test_utils.get_model_from_generator(
@ -137,7 +127,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
)
# TPUs can only run `tf.function`-decorated functions.
if using_tpu:
if self.using_tpu:
test_op = tf.function(test_op, autograph=False)
# Set up the device ops and run the test.
@ -145,7 +135,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
test_op, args=(embed_indices,)
)
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
if self.using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]


@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
def setUp(self):
super().setUp()
super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
self.using_tpu = True
if __name__ == '__main__':


@ -13,12 +13,13 @@
# limitations under the License.
"""A collection of type aliases used throughout the clipping library."""
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, Optional, Union
import tensorflow as tf
# Tensorflow aliases.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
InputTensors = PackedTensors
@ -31,12 +32,12 @@ LossFn = Callable[..., tf.Tensor]
# Layer Registry aliases.
SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
RegistryFunctionOutput = tuple[Any, OutputTensors, SquareNormFunction]
RegistryFunction = Callable[
[
Any,
Tuple[Any, ...],
tuple[Any, ...],
Mapping[str, Any],
tf.GradientTape,
Union[tf.Tensor, None],
@ -45,11 +46,11 @@ RegistryFunction = Callable[
]
# Clipping aliases.
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
GeneratorFunction = Optional[Callable[[Any, tuple, Mapping], tuple[Any, Any]]]
# Testing aliases.
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
ModelGenerator = Callable[
[LayerGenerator, List[int], List[int]], tf.keras.Model
[LayerGenerator, Sequence[int], Sequence[int]], tf.keras.Model
]
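At use sites the modernized aliases read exactly as before; only the imports change. A small usage sketch with a hypothetical helper that accepts each of the `PackedTensors` forms:

from collections.abc import Iterable, Mapping
from typing import Union

import tensorflow as tf

# Mirrors the alias above: a single tensor, an iterable of tensors, or a
# string-keyed mapping of tensors.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]


def count_tensors(packed: PackedTensors) -> int:
  """Hypothetical helper: counts the tensors inside a packed structure."""
  if isinstance(packed, tf.Tensor):
    return 1
  if isinstance(packed, Mapping):
    return len(packed)
  return sum(1 for _ in packed)


print(count_tensors(tf.zeros([2])))                  # 1
print(count_tensors([tf.zeros([2]), tf.ones([2])]))  # 2
print(count_tensors({'a': tf.zeros([2])}))           # 1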