forked from 626_privacy/tensorflow_privacy
Improve the readability of the fast gradient clipping library.
PiperOrigin-RevId: 566961891
This commit is contained in:
parent
97aaf302eb
commit
c037070e50
18 changed files with 84 additions and 111 deletions
|
@ -21,7 +21,8 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
|
||||||
`compute_gradient_norms()` function).
|
`compute_gradient_norms()` function).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||||
|
@ -74,7 +75,7 @@ def compute_gradient_norms(
|
||||||
weight_batch: Optional[tf.Tensor] = None,
|
weight_batch: Optional[tf.Tensor] = None,
|
||||||
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
|
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||||
trainable_vars: Optional[List[tf.Variable]] = None,
|
trainable_vars: Optional[Sequence[tf.Variable]] = None,
|
||||||
):
|
):
|
||||||
"""Computes the per-example loss gradient norms for given data.
|
"""Computes the per-example loss gradient norms for given data.
|
||||||
|
|
||||||
|
@ -219,7 +220,7 @@ def compute_clipped_gradients_and_outputs(
|
||||||
weight_batch: Optional[tf.Tensor] = None,
|
weight_batch: Optional[tf.Tensor] = None,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||||
clipping_loss: Optional[type_aliases.LossFn] = None,
|
clipping_loss: Optional[type_aliases.LossFn] = None,
|
||||||
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
|
) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
|
||||||
"""Computes the per-example clipped loss gradient and other useful outputs.
|
"""Computes the per-example clipped loss gradient and other useful outputs.
|
||||||
|
|
||||||
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
|
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
|
||||||
|
|
|
@ -12,7 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Text, Tuple
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -22,9 +23,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper functions and classes.
|
|
||||||
# ==============================================================================
|
|
||||||
class DoubleDense(tf.keras.layers.Layer):
|
class DoubleDense(tf.keras.layers.Layer):
|
||||||
"""Generates two dense layers nested together."""
|
"""Generates two dense layers nested together."""
|
||||||
|
|
||||||
|
@ -40,8 +38,8 @@ class DoubleDense(tf.keras.layers.Layer):
|
||||||
|
|
||||||
def double_dense_layer_computation(
|
def double_dense_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Layer,
|
layer_instance: tf.keras.layers.Layer,
|
||||||
input_args: Tuple[Any, ...],
|
input_args: Sequence[Any],
|
||||||
input_kwargs: Dict[Text, Any],
|
input_kwargs: Mapping[str, Any],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
num_microbatches: Optional[int],
|
num_microbatches: Optional[int],
|
||||||
) -> type_aliases.RegistryFunctionOutput:
|
) -> type_aliases.RegistryFunctionOutput:
|
||||||
|
@ -61,9 +59,6 @@ def double_dense_layer_computation(
|
||||||
return [vars1, vars2], outputs, sqr_norm_fn
|
return [vars1, vars2], outputs, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main tests.
|
|
||||||
# ==============================================================================
|
|
||||||
class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):
|
class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""A collection of common utility functions for unit testing."""
|
"""A collection of common utility functions for unit testing."""
|
||||||
|
|
||||||
from typing import Callable, List, Optional, Tuple
|
from collections.abc import Callable, MutableSequence, Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -108,8 +109,8 @@ def compute_true_gradient_norms(
|
||||||
def get_model_from_generator(
|
def get_model_from_generator(
|
||||||
model_generator: type_aliases.ModelGenerator,
|
model_generator: type_aliases.ModelGenerator,
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
is_eager: bool,
|
is_eager: bool,
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a simple model from input specifications."""
|
"""Creates a simple model from input specifications."""
|
||||||
|
@ -171,8 +172,8 @@ def get_computed_and_true_norms_from_model(
|
||||||
def get_computed_and_true_norms(
|
def get_computed_and_true_norms(
|
||||||
model_generator: type_aliases.ModelGenerator,
|
model_generator: type_aliases.ModelGenerator,
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
|
||||||
num_microbatches: Optional[int],
|
num_microbatches: Optional[int],
|
||||||
is_eager: bool,
|
is_eager: bool,
|
||||||
|
@ -181,7 +182,7 @@ def get_computed_and_true_norms(
|
||||||
rng_seed: int = 777,
|
rng_seed: int = 777,
|
||||||
registry: layer_registry.LayerRegistry = None,
|
registry: layer_registry.LayerRegistry = None,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor]:
|
) -> tuple[tf.Tensor, tf.Tensor]:
|
||||||
"""Obtains the true and computed gradient norms for a model and batch input.
|
"""Obtains the true and computed gradient norms for a model and batch input.
|
||||||
|
|
||||||
Helpful testing wrapper function used to avoid code duplication.
|
Helpful testing wrapper function used to avoid code duplication.
|
||||||
|
@ -245,8 +246,8 @@ def reshape_and_sum(tensor: tf.Tensor) -> tf.Tensor:
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
def make_one_layer_functional_model(
|
def make_one_layer_functional_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a 1-layer sequential model."""
|
"""Creates a 1-layer sequential model."""
|
||||||
inputs = tf.keras.Input(shape=input_dims)
|
inputs = tf.keras.Input(shape=input_dims)
|
||||||
|
@ -258,8 +259,8 @@ def make_one_layer_functional_model(
|
||||||
|
|
||||||
def make_two_layer_functional_model(
|
def make_two_layer_functional_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a 2-layer sequential model."""
|
"""Creates a 2-layer sequential model."""
|
||||||
inputs = tf.keras.Input(shape=input_dims)
|
inputs = tf.keras.Input(shape=input_dims)
|
||||||
|
@ -272,8 +273,8 @@ def make_two_layer_functional_model(
|
||||||
|
|
||||||
def make_two_tower_model(
|
def make_two_tower_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a 2-layer 2-input functional model."""
|
"""Creates a 2-layer 2-input functional model."""
|
||||||
inputs1 = tf.keras.Input(shape=input_dims)
|
inputs1 = tf.keras.Input(shape=input_dims)
|
||||||
|
@ -290,8 +291,8 @@ def make_two_tower_model(
|
||||||
|
|
||||||
def make_bow_model(
|
def make_bow_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a simple embedding bow model."""
|
"""Creates a simple embedding bow model."""
|
||||||
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
|
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
|
||||||
|
@ -315,8 +316,8 @@ def make_bow_model(
|
||||||
|
|
||||||
def make_dense_bow_model(
|
def make_dense_bow_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: Sequence[int],
|
||||||
output_dims: List[int],
|
output_dims: Sequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates an embedding bow model with a `Dense` layer."""
|
"""Creates an embedding bow model with a `Dense` layer."""
|
||||||
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
|
inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
|
||||||
|
@ -341,8 +342,8 @@ def make_dense_bow_model(
|
||||||
|
|
||||||
def make_weighted_bow_model(
|
def make_weighted_bow_model(
|
||||||
layer_generator: type_aliases.LayerGenerator,
|
layer_generator: type_aliases.LayerGenerator,
|
||||||
input_dims: List[int],
|
input_dims: MutableSequence[int],
|
||||||
output_dims: List[int],
|
output_dims: MutableSequence[int],
|
||||||
) -> tf.keras.Model:
|
) -> tf.keras.Model:
|
||||||
"""Creates a weighted embedding bow model."""
|
"""Creates a weighted embedding bow model."""
|
||||||
# NOTE: This model only accepts dense input tensors.
|
# NOTE: This model only accepts dense input tensors.
|
||||||
|
@ -353,10 +354,9 @@ def make_weighted_bow_model(
|
||||||
emb_layer = layer_generator(input_dims, output_dims)
|
emb_layer = layer_generator(input_dims, output_dims)
|
||||||
if len(output_dims) != 1:
|
if len(output_dims) != 1:
|
||||||
raise ValueError('Expected `output_dims` to be of size 1.')
|
raise ValueError('Expected `output_dims` to be of size 1.')
|
||||||
output_dim = output_dims[0]
|
|
||||||
feature_embs = emb_layer(inputs)
|
feature_embs = emb_layer(inputs)
|
||||||
# Use deterministic weights to avoid seeding issues on TPUs.
|
# Use deterministic weights to avoid seeding issues on TPUs.
|
||||||
feature_shape = input_dims + [output_dim]
|
feature_shape = input_dims + output_dims
|
||||||
feature_weights = tf.expand_dims(
|
feature_weights = tf.expand_dims(
|
||||||
tf.reshape(
|
tf.reshape(
|
||||||
tf.range(np.product(feature_shape), dtype=tf.float32),
|
tf.range(np.product(feature_shape), dtype=tf.float32),
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||||
|
|
||||||
from typing import Any, List, Optional, Set, Tuple
|
from collections.abc import Sequence, Set
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -36,7 +37,7 @@ def model_forward_pass(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
inputs: type_aliases.PackedTensors,
|
inputs: type_aliases.PackedTensors,
|
||||||
generator_fn: type_aliases.GeneratorFunction = None,
|
generator_fn: type_aliases.GeneratorFunction = None,
|
||||||
) -> Tuple[type_aliases.PackedTensors, List[Any]]:
|
) -> tuple[type_aliases.PackedTensors, Sequence[Any]]:
|
||||||
"""Does a forward pass of a model and returns useful intermediates.
|
"""Does a forward pass of a model and returns useful intermediates.
|
||||||
|
|
||||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||||
|
@ -149,7 +150,7 @@ def add_aggregate_noise(
|
||||||
batch_size: tf.Tensor,
|
batch_size: tf.Tensor,
|
||||||
l2_norm_clip: float,
|
l2_norm_clip: float,
|
||||||
noise_multiplier: float,
|
noise_multiplier: float,
|
||||||
) -> List[tf.Tensor]:
|
) -> Sequence[tf.Tensor]:
|
||||||
"""Adds noise to a collection of clipped gradients.
|
"""Adds noise to a collection of clipped gradients.
|
||||||
|
|
||||||
The magnitude of the noise depends on the aggregation strategy of the
|
The magnitude of the noise depends on the aggregation strategy of the
|
||||||
|
|
|
@ -16,7 +16,6 @@ from typing import Any
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,7 +89,7 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
num_dims = 3
|
num_dims = 3
|
||||||
num_inputs = 1 if input_packing_type is None else 2
|
num_inputs = 1 if input_packing_type is None else 2
|
||||||
num_outputs = 1 if output_packing_type is None else 2
|
num_outputs = 1 if output_packing_type is None else 2
|
||||||
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
|
sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
|
||||||
temp_sum = tf.stack(sample_inputs, axis=0)
|
temp_sum = tf.stack(sample_inputs, axis=0)
|
||||||
sample_outputs = [
|
sample_outputs = [
|
||||||
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
|
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast clipping function for `tf.keras.layers.Dense`."""
|
"""Fast clipping function for `tf.keras.layers.Dense`."""
|
||||||
|
|
||||||
from typing import Any, Mapping, Tuple, Union
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
|
||||||
def dense_layer_computation(
|
def dense_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Dense,
|
layer_instance: tf.keras.layers.Dense,
|
||||||
input_args: Tuple[Any, ...],
|
input_args: Sequence[Any],
|
||||||
input_kwargs: Mapping[str, Any],
|
input_kwargs: Mapping[str, Any],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
num_microbatches: Union[tf.Tensor, None] = None,
|
num_microbatches: Optional[tf.Tensor] = None,
|
||||||
) -> type_aliases.RegistryFunctionOutput:
|
) -> type_aliases.RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Dense`.
|
"""Registry function for `tf.keras.layers.Dense`.
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper functions.
|
|
||||||
# ==============================================================================
|
|
||||||
def get_dense_layer_generators():
|
def get_dense_layer_generators():
|
||||||
|
|
||||||
def sigmoid_dense_layer(units):
|
def sigmoid_dense_layer(units):
|
||||||
return tf.keras.layers.Dense(units, activation='sigmoid')
|
return tf.keras.layers.Dense(units, activation='sigmoid')
|
||||||
|
|
||||||
|
@ -50,14 +46,12 @@ def get_dense_layer_registries():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main tests.
|
|
||||||
# ==============================================================================
|
|
||||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.strategy = tf.distribute.get_strategy()
|
self.strategy = tf.distribute.get_strategy()
|
||||||
|
self.using_tpu = False
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
model_name=list(get_dense_model_generators().keys()),
|
model_name=list(get_dense_model_generators().keys()),
|
||||||
|
@ -112,8 +106,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TPUs can only run `tf.function`-decorated functions.
|
# TPUs can only run `tf.function`-decorated functions.
|
||||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
if self.using_tpu:
|
||||||
if using_tpu:
|
|
||||||
test_op = tf.function(test_op, jit_compile=True, autograph=False)
|
test_op = tf.function(test_op, jit_compile=True, autograph=False)
|
||||||
|
|
||||||
# TPUs use lower precision than CPUs, so we relax our criterion.
|
# TPUs use lower precision than CPUs, so we relax our criterion.
|
||||||
|
@ -127,8 +120,8 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# which is a reasonable level of error for computing gradient norms.
|
# which is a reasonable level of error for computing gradient norms.
|
||||||
# Other trials also give an absolute (resp. relative) error of around
|
# Other trials also give an absolute (resp. relative) error of around
|
||||||
# 0.05 (resp. 0.0015).
|
# 0.05 (resp. 0.0015).
|
||||||
rtol = 1e-2 if using_tpu else 1e-3
|
rtol = 1e-2 if self.using_tpu else 1e-3
|
||||||
atol = 1e-1 if using_tpu else 1e-2
|
atol = 1e-1 if self.using_tpu else 1e-2
|
||||||
|
|
||||||
for x_batch, weight_batch in zip(x_batches, weight_batches):
|
for x_batch, weight_batch in zip(x_batches, weight_batches):
|
||||||
batch_size = x_batch.shape[0]
|
batch_size = x_batch.shape[0]
|
||||||
|
@ -139,7 +132,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
test_op, args=(x_batch, weight_batch)
|
test_op, args=(x_batch, weight_batch)
|
||||||
)
|
)
|
||||||
# TPUs return replica contexts, which must be unwrapped.
|
# TPUs return replica contexts, which must be unwrapped.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||||
computed_norms = computed_norms.values[0]
|
computed_norms = computed_norms.values[0]
|
||||||
|
|
|
@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(dense_test.GradNormTest):
|
class GradNormTpuTest(dense_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super(dense_test.GradNormTest, self).setUp()
|
||||||
self.strategy = ctu.create_tpu_strategy()
|
self.strategy = ctu.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
self.using_tpu = True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast clipping function for `tf.keras.layers.Embedding`."""
|
"""Fast clipping function for `tf.keras.layers.Embedding`."""
|
||||||
|
|
||||||
from typing import Any, Mapping, Tuple, Union
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
|
||||||
|
@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
|
|
||||||
def embedding_layer_computation(
|
def embedding_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Embedding,
|
layer_instance: tf.keras.layers.Embedding,
|
||||||
input_args: Tuple[Any, ...],
|
input_args: Sequence[Any],
|
||||||
input_kwargs: Mapping[str, Any],
|
input_kwargs: Mapping[str, Any],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
num_microbatches: Union[tf.Tensor, None] = None,
|
num_microbatches: Optional[tf.Tensor] = None,
|
||||||
) -> type_aliases.RegistryFunctionOutput:
|
) -> type_aliases.RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.Embedding`.
|
"""Registry function for `tf.keras.layers.Embedding`.
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper functions.
|
|
||||||
# ==============================================================================
|
|
||||||
def get_embedding_model_generators():
|
def get_embedding_model_generators():
|
||||||
return {
|
return {
|
||||||
'bow1': common_test_utils.make_bow_model,
|
'bow1': common_test_utils.make_bow_model,
|
||||||
|
@ -59,14 +56,12 @@ def get_embedding_layer_registries():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main tests.
|
|
||||||
# ==============================================================================
|
|
||||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.strategy = tf.distribute.get_strategy()
|
self.strategy = tf.distribute.get_strategy()
|
||||||
|
self.using_tpu = False
|
||||||
|
|
||||||
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
||||||
# supports them for embeddings.
|
# supports them for embeddings.
|
||||||
|
@ -101,13 +96,12 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# The following are invalid test combinations and, hence, are skipped.
|
# The following are invalid test combinations and, hence, are skipped.
|
||||||
batch_size = embed_indices.shape[0]
|
batch_size = embed_indices.shape[0]
|
||||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
|
||||||
if (
|
if (
|
||||||
(num_microbatches is not None and batch_size % num_microbatches != 0)
|
(num_microbatches is not None and batch_size % num_microbatches != 0)
|
||||||
or (model_name == 'weighted_bow1' and is_ragged)
|
or (model_name == 'weighted_bow1' and is_ragged)
|
||||||
or (
|
or (
|
||||||
# Current clipping ops do not have corresponding TPU kernels.
|
# Current clipping ops do not have corresponding TPU kernels.
|
||||||
using_tpu
|
self.using_tpu
|
||||||
and is_ragged
|
and is_ragged
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
@ -139,7 +133,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TPUs can only run `tf.function`-decorated functions.
|
# TPUs can only run `tf.function`-decorated functions.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
test_op = tf.function(test_op, autograph=False)
|
test_op = tf.function(test_op, autograph=False)
|
||||||
|
|
||||||
# Set up the device ops and run the test.
|
# Set up the device ops and run the test.
|
||||||
|
@ -147,7 +141,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
test_op, args=(embed_indices,)
|
test_op, args=(embed_indices,)
|
||||||
)
|
)
|
||||||
# TPUs return replica contexts, which must be unwrapped.
|
# TPUs return replica contexts, which must be unwrapped.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||||
computed_norms = computed_norms.values[0]
|
computed_norms = computed_norms.values[0]
|
||||||
|
|
|
@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(embedding_test.GradNormTest):
|
class GradNormTpuTest(embedding_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super(embedding_test.GradNormTest, self).setUp()
|
||||||
self.strategy = common_test_utils.create_tpu_strategy()
|
self.strategy = common_test_utils.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
self.using_tpu = True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -13,21 +13,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
|
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""
|
||||||
|
|
||||||
from typing import Any, Mapping, Tuple, Union
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Supported Keras layers
|
|
||||||
# ==============================================================================
|
|
||||||
def layer_normalization_computation(
|
def layer_normalization_computation(
|
||||||
layer_instance: tf.keras.layers.LayerNormalization,
|
layer_instance: tf.keras.layers.LayerNormalization,
|
||||||
input_args: Tuple[Any, ...],
|
input_args: Sequence[Any],
|
||||||
input_kwargs: Mapping[str, Any],
|
input_kwargs: Mapping[str, Any],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
num_microbatches: Union[tf.Tensor, None] = None,
|
num_microbatches: Optional[tf.Tensor] = None,
|
||||||
) -> type_aliases.RegistryFunctionOutput:
|
) -> type_aliases.RegistryFunctionOutput:
|
||||||
"""Registry function for `tf.keras.layers.LayerNormalization`.
|
"""Registry function for `tf.keras.layers.LayerNormalization`.
|
||||||
|
|
||||||
|
|
|
@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper functions.
|
|
||||||
# ==============================================================================
|
|
||||||
def get_layer_norm_layer_generators():
|
def get_layer_norm_layer_generators():
|
||||||
return {
|
return {
|
||||||
'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
|
'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
|
||||||
|
@ -73,14 +70,12 @@ def get_layer_norm_registries():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main tests.
|
|
||||||
# ==============================================================================
|
|
||||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.strategy = tf.distribute.get_strategy()
|
self.strategy = tf.distribute.get_strategy()
|
||||||
|
self.using_tpu = False
|
||||||
|
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
model_name=list(get_layer_norm_model_generators().keys()),
|
model_name=list(get_layer_norm_model_generators().keys()),
|
||||||
|
@ -130,14 +125,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TPUs can only run `tf.function`-decorated functions.
|
# TPUs can only run `tf.function`-decorated functions.
|
||||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
if self.using_tpu:
|
||||||
if using_tpu:
|
|
||||||
test_op = tf.function(test_op, jit_compile=True, autograph=False)
|
test_op = tf.function(test_op, jit_compile=True, autograph=False)
|
||||||
|
|
||||||
# TPUs use lower precision than CPUs, so we relax our criterion (see
|
# TPUs use lower precision than CPUs, so we relax our criterion (see
|
||||||
# `dense_test.py` for additional discussions).
|
# `dense_test.py` for additional discussions).
|
||||||
rtol = 1e-2 if using_tpu else 1e-3
|
rtol = 1e-2 if self.using_tpu else 1e-3
|
||||||
atol = 1e-1 if using_tpu else 1e-2
|
atol = 1e-1 if self.using_tpu else 1e-2
|
||||||
|
|
||||||
# Each batched input is a reshape of a `tf.range()` call.
|
# Each batched input is a reshape of a `tf.range()` call.
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
@ -148,7 +142,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# Set up the device ops and run the test.
|
# Set up the device ops and run the test.
|
||||||
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
|
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
|
||||||
# TPUs return replica contexts, which must be unwrapped.
|
# TPUs return replica contexts, which must be unwrapped.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||||
computed_norms = computed_norms.values[0]
|
computed_norms = computed_norms.values[0]
|
||||||
|
|
|
@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(layer_normalization_test.GradNormTest):
|
class GradNormTpuTest(layer_normalization_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super(layer_normalization_test.GradNormTest, self).setUp()
|
||||||
self.strategy = ctu.create_tpu_strategy()
|
self.strategy = ctu.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
self.using_tpu = True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
|
"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
|
||||||
|
@ -21,8 +22,8 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
|
|
||||||
def nlp_on_device_embedding_layer_computation(
|
def nlp_on_device_embedding_layer_computation(
|
||||||
layer_instance: tf.keras.layers.Layer,
|
layer_instance: tf.keras.layers.Layer,
|
||||||
input_args: Tuple[Any, ...],
|
input_args: Sequence[Any],
|
||||||
input_kwargs: Dict[str, Any],
|
input_kwargs: Mapping[str, Any],
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
num_microbatches: Optional[tf.Tensor] = None,
|
num_microbatches: Optional[tf.Tensor] = None,
|
||||||
) -> type_aliases.RegistryFunctionOutput:
|
) -> type_aliases.RegistryFunctionOutput:
|
||||||
|
|
|
@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
|
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper functions.
|
|
||||||
# ==============================================================================
|
|
||||||
def get_nlp_on_device_embedding_model_generators():
|
def get_nlp_on_device_embedding_model_generators():
|
||||||
return {
|
return {
|
||||||
'bow1': common_test_utils.make_bow_model,
|
'bow1': common_test_utils.make_bow_model,
|
||||||
|
@ -49,30 +46,26 @@ def get_nlp_on_device_embedding_layer_registries():
|
||||||
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
|
dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
|
||||||
dbl_registry.insert(
|
dbl_registry.insert(
|
||||||
tfm.nlp.layers.OnDeviceEmbedding,
|
tfm.nlp.layers.OnDeviceEmbedding,
|
||||||
nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
|
nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
'embed_and_dense': dbl_registry,
|
'embed_and_dense': dbl_registry,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main tests.
|
|
||||||
# ==============================================================================
|
|
||||||
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.strategy = tf.distribute.get_strategy()
|
self.strategy = tf.distribute.get_strategy()
|
||||||
|
self.using_tpu = False
|
||||||
|
|
||||||
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
# TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
|
||||||
# supports them for embeddings.
|
# supports them for embeddings.
|
||||||
@parameterized.product(
|
@parameterized.product(
|
||||||
input_data=get_nlp_on_device_embedding_inputs(),
|
input_data=get_nlp_on_device_embedding_inputs(),
|
||||||
scale_factor=[None, 0.5, 1.0],
|
scale_factor=[None, 0.5, 1.0],
|
||||||
model_name=list(
|
model_name=list(get_nlp_on_device_embedding_model_generators().keys()),
|
||||||
get_nlp_on_device_embedding_model_generators().keys()
|
|
||||||
),
|
|
||||||
output_dim=[2],
|
output_dim=[2],
|
||||||
layer_registry_name=list(
|
layer_registry_name=list(
|
||||||
get_nlp_on_device_embedding_layer_registries().keys()
|
get_nlp_on_device_embedding_layer_registries().keys()
|
||||||
|
@ -97,7 +90,6 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
# The following are invalid test combinations and, hence, are skipped.
|
# The following are invalid test combinations and, hence, are skipped.
|
||||||
batch_size = embed_indices.shape[0]
|
batch_size = embed_indices.shape[0]
|
||||||
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
|
|
||||||
if num_microbatches is not None and batch_size % num_microbatches != 0:
|
if num_microbatches is not None and batch_size % num_microbatches != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -106,9 +98,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def embed_layer_generator(_, output_dims):
|
def embed_layer_generator(_, output_dims):
|
||||||
return tfm.nlp.layers.OnDeviceEmbedding(
|
return tfm.nlp.layers.OnDeviceEmbedding(
|
||||||
10,
|
10, *output_dims, scale_factor=scale_factor
|
||||||
*output_dims,
|
|
||||||
scale_factor=scale_factor
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = common_test_utils.get_model_from_generator(
|
model = common_test_utils.get_model_from_generator(
|
||||||
|
@ -137,7 +127,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TPUs can only run `tf.function`-decorated functions.
|
# TPUs can only run `tf.function`-decorated functions.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
test_op = tf.function(test_op, autograph=False)
|
test_op = tf.function(test_op, autograph=False)
|
||||||
|
|
||||||
# Set up the device ops and run the test.
|
# Set up the device ops and run the test.
|
||||||
|
@ -145,7 +135,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
test_op, args=(embed_indices,)
|
test_op, args=(embed_indices,)
|
||||||
)
|
)
|
||||||
# TPUs return replica contexts, which must be unwrapped.
|
# TPUs return replica contexts, which must be unwrapped.
|
||||||
if using_tpu:
|
if self.using_tpu:
|
||||||
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
common_test_utils.assert_replica_values_are_close(self, computed_norms)
|
||||||
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
common_test_utils.assert_replica_values_are_close(self, true_norms)
|
||||||
computed_norms = computed_norms.values[0]
|
computed_norms = computed_norms.values[0]
|
||||||
|
|
|
@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
|
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
|
||||||
self.strategy = common_test_utils.create_tpu_strategy()
|
self.strategy = common_test_utils.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
self.using_tpu = True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""A collection of type aliases used throughout the clipping library."""
|
"""A collection of type aliases used throughout the clipping library."""
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||||
|
from typing import Any, Optional, Union
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
# Tensorflow aliases.
|
# Tensorflow aliases.
|
||||||
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
|
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
|
||||||
|
|
||||||
InputTensors = PackedTensors
|
InputTensors = PackedTensors
|
||||||
|
|
||||||
|
@ -31,12 +32,12 @@ LossFn = Callable[..., tf.Tensor]
|
||||||
# Layer Registry aliases.
|
# Layer Registry aliases.
|
||||||
SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
|
SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
|
||||||
|
|
||||||
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
|
RegistryFunctionOutput = tuple[Any, OutputTensors, SquareNormFunction]
|
||||||
|
|
||||||
RegistryFunction = Callable[
|
RegistryFunction = Callable[
|
||||||
[
|
[
|
||||||
Any,
|
Any,
|
||||||
Tuple[Any, ...],
|
tuple[Any, ...],
|
||||||
Mapping[str, Any],
|
Mapping[str, Any],
|
||||||
tf.GradientTape,
|
tf.GradientTape,
|
||||||
Union[tf.Tensor, None],
|
Union[tf.Tensor, None],
|
||||||
|
@ -45,11 +46,11 @@ RegistryFunction = Callable[
|
||||||
]
|
]
|
||||||
|
|
||||||
# Clipping aliases.
|
# Clipping aliases.
|
||||||
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
|
GeneratorFunction = Optional[Callable[[Any, tuple, Mapping], tuple[Any, Any]]]
|
||||||
|
|
||||||
# Testing aliases.
|
# Testing aliases.
|
||||||
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
|
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
|
||||||
|
|
||||||
ModelGenerator = Callable[
|
ModelGenerator = Callable[
|
||||||
[LayerGenerator, List[int], List[int]], tf.keras.Model
|
[LayerGenerator, Sequence[int], Sequence[int]], tf.keras.Model
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue