Improve the readability of the fast gradient clipping library.
PiperOrigin-RevId: 566961891
This commit is contained in: parent 97aaf302eb, commit c037070e50
18 changed files with 84 additions and 111 deletions
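Most of the diff below swaps `typing` container aliases (`List`, `Dict`, `Tuple`, `Text`, `Union[..., None]`) for `collections.abc` abstract types, builtin generics, and `Optional`. A minimal sketch of the before/after pattern (illustrative only; the function and parameter names are made up and are not code from this change):

# Before: concrete typing aliases.
#   from typing import Dict, List, Tuple, Union
#   def scale(xs: List[float], opts: Dict[str, float], c: Union[float, None]) -> Tuple[float, ...]:
#     ...
# After: abstract container types for arguments, builtin generics, and Optional.
from collections.abc import Mapping, Sequence
from typing import Optional


def scale(
    xs: Sequence[float],
    opts: Mapping[str, float],
    c: Optional[float] = None,
) -> tuple[float, ...]:
  factor = c if c is not None else opts.get('factor', 1.0)
  return tuple(x * factor for x in xs)
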
@@ -21,7 +21,8 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """

-from typing import List, Optional, Tuple
+from collections.abc import Sequence
+from typing import Optional

 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
@@ -74,7 +75,7 @@ def compute_gradient_norms(
     weight_batch: Optional[tf.Tensor] = None,
     per_example_loss_fn: Optional[type_aliases.LossFn] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
-    trainable_vars: Optional[List[tf.Variable]] = None,
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
 ):
   """Computes the per-example loss gradient norms for given data.

@@ -219,7 +220,7 @@ def compute_clipped_gradients_and_outputs(
     weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     clipping_loss: Optional[type_aliases.LossFn] = None,
-) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
+) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.

   Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Text, Tuple
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional

 from absl.testing import parameterized
 import tensorflow as tf
@@ -22,9 +23,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-# ==============================================================================
-# Helper functions and classes.
-# ==============================================================================
 class DoubleDense(tf.keras.layers.Layer):
   """Generates two dense layers nested together."""

@@ -40,8 +38,8 @@ class DoubleDense(tf.keras.layers.Layer):

 def double_dense_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_args: Sequence[Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[int],
 ) -> type_aliases.RegistryFunctionOutput:
@@ -61,9 +59,6 @@ def double_dense_layer_computation(
   return [vars1, vars2], outputs, sqr_norm_fn


-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):

   @parameterized.product(

@@ -13,7 +13,8 @@
 # limitations under the License.
 """A collection of common utility functions for unit testing."""

-from typing import Callable, List, Optional, Tuple
+from collections.abc import Callable, MutableSequence, Sequence
+from typing import Optional

 import numpy as np
 import tensorflow as tf
@@ -108,8 +109,8 @@ def compute_true_gradient_norms(
 def get_model_from_generator(
     model_generator: type_aliases.ModelGenerator,
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
     is_eager: bool,
 ) -> tf.keras.Model:
   """Creates a simple model from input specifications."""
@@ -171,8 +172,8 @@ def get_computed_and_true_norms_from_model(
 def get_computed_and_true_norms(
     model_generator: type_aliases.ModelGenerator,
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
     per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
     num_microbatches: Optional[int],
     is_eager: bool,
@@ -181,7 +182,7 @@ def get_computed_and_true_norms(
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
     partial: bool = False,
-) -> Tuple[tf.Tensor, tf.Tensor]:
+) -> tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.

   Helpful testing wrapper function used to avoid code duplication.
@@ -245,8 +246,8 @@ def reshape_and_sum(tensor: tf.Tensor) -> tf.Tensor:
 # ==============================================================================
 def make_one_layer_functional_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 1-layer sequential model."""
   inputs = tf.keras.Input(shape=input_dims)
@@ -258,8 +259,8 @@ def make_one_layer_functional_model(

 def make_two_layer_functional_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 2-layer sequential model."""
   inputs = tf.keras.Input(shape=input_dims)
@@ -272,8 +273,8 @@ def make_two_layer_functional_model(

 def make_two_tower_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 2-layer 2-input functional model."""
   inputs1 = tf.keras.Input(shape=input_dims)
@@ -290,8 +291,8 @@ def make_two_tower_model(

 def make_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a simple embedding bow model."""
   inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@@ -315,8 +316,8 @@ def make_bow_model(

 def make_dense_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates an embedding bow model with a `Dense` layer."""
   inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@@ -341,8 +342,8 @@ def make_dense_bow_model(

 def make_weighted_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: MutableSequence[int],
+    output_dims: MutableSequence[int],
 ) -> tf.keras.Model:
   """Creates a weighted embedding bow model."""
   # NOTE: This model only accepts dense input tensors.
@@ -353,10 +354,9 @@ def make_weighted_bow_model(
   emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
   feature_embs = emb_layer(inputs)
   # Use deterministic weights to avoid seeding issues on TPUs.
-  feature_shape = input_dims + [output_dim]
+  feature_shape = input_dims + output_dims
   feature_weights = tf.expand_dims(
       tf.reshape(
           tf.range(np.product(feature_shape), dtype=tf.float32),

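One detail in the `make_weighted_bow_model` hunk above: `feature_shape` is now built by concatenating `input_dims + output_dims` directly, so the intermediate `output_dim` variable goes away; the `MutableSequence[int]` annotations presumably signal that these arguments must be list-like for `+` concatenation to work. A tiny illustration with hypothetical dimension values:

input_dims = [4]    # list-like, so `+` concatenates
output_dims = [2]   # must have exactly one entry, per the ValueError check above
feature_shape = input_dims + output_dims
assert feature_shape == [4, 2]
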
@@ -13,7 +13,8 @@
 # limitations under the License.
 """Utility functions that help in the computation of per-example gradient norms."""

-from typing import Any, List, Optional, Set, Tuple
+from collections.abc import Sequence, Set
+from typing import Any, Optional

 from absl import logging
 import tensorflow as tf
@@ -36,7 +37,7 @@ def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,
     generator_fn: type_aliases.GeneratorFunction = None,
-) -> Tuple[type_aliases.PackedTensors, List[Any]]:
+) -> tuple[type_aliases.PackedTensors, Sequence[Any]]:
   """Does a forward pass of a model and returns useful intermediates.

   NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -149,7 +150,7 @@ def add_aggregate_noise(
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
-) -> List[tf.Tensor]:
+) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.

   The magnitude of the noise depends on the aggregation strategy of the

@@ -16,7 +16,6 @@ from typing import Any

 from absl.testing import parameterized
 import tensorflow as tf
-
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils


@@ -90,7 +89,7 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
     num_dims = 3
     num_inputs = 1 if input_packing_type is None else 2
     num_outputs = 1 if output_packing_type is None else 2
-    sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
+    sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
     temp_sum = tf.stack(sample_inputs, axis=0)
     sample_outputs = [
         tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Dense`."""

-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases

 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.

@@ -19,11 +19,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense


-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_dense_layer_generators():

   def sigmoid_dense_layer(units):
     return tf.keras.layers.Dense(units, activation='sigmoid')
@@ -50,14 +46,12 @@ def get_dense_layer_registries():
   }


-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False

   @parameterized.product(
       model_name=list(get_dense_model_generators().keys()),
@@ -112,8 +106,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )

     # TPUs can only run `tf.function`-decorated functions.
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, jit_compile=True, autograph=False)

     # TPUs use lower precision than CPUs, so we relax our criterion.
@@ -127,8 +120,8 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # which is a reasonable level of error for computing gradient norms.
     # Other trials also give an absolute (resp. relative) error of around
     # 0.05 (resp. 0.0015).
-    rtol = 1e-2 if using_tpu else 1e-3
-    atol = 1e-1 if using_tpu else 1e-2
+    rtol = 1e-2 if self.using_tpu else 1e-3
+    atol = 1e-1 if self.using_tpu else 1e-2

     for x_batch, weight_batch in zip(x_batches, weight_batches):
       batch_size = x_batch.shape[0]
@@ -139,7 +132,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
           test_op, args=(x_batch, weight_batch)
       )
       # TPUs return replica contexts, which must be unwrapped.
-      if using_tpu:
+      if self.using_tpu:
        common_test_utils.assert_replica_values_are_close(self, computed_norms)
        common_test_utils.assert_replica_values_are_close(self, true_norms)
        computed_norms = computed_norms.values[0]

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(dense_test.GradNormTest):

   def setUp(self):
-    super().setUp()
+    super(dense_test.GradNormTest, self).setUp()
     self.strategy = ctu.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True


 if __name__ == '__main__':

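The TPU test variants (like the hunk above) share one pattern: the base test's `setUp` records the default strategy and `using_tpu = False`, while the TPU subclass bypasses the base `setUp` via `super(Base..., self).setUp()` so it can install a TPU strategy and flip the flag. A rough sketch of that pattern with made-up class names (the TPU initialization lines are the standard TensorFlow recipe, not code from this change):

import tensorflow as tf


class BaseGradNormTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()  # default (no-op) strategy
    self.using_tpu = False


class TpuGradNormTest(BaseGradNormTest):

  def setUp(self):
    # Skip BaseGradNormTest.setUp() and call tf.test.TestCase.setUp() directly,
    # so the default strategy is never installed on TPU runs.
    super(BaseGradNormTest, self).setUp()
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    self.strategy = tf.distribute.TPUStrategy(resolver)
    self.using_tpu = True
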
@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Embedding`."""

-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import

 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.

@@ -20,9 +20,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding


-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_embedding_model_generators():
   return {
       'bow1': common_test_utils.make_bow_model,
@@ -59,14 +56,12 @@ def get_embedding_layer_registries():
   }


-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False

   # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
@@ -101,13 +96,12 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):

     # The following are invalid test combinations and, hence, are skipped.
     batch_size = embed_indices.shape[0]
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
     if (
         (num_microbatches is not None and batch_size % num_microbatches != 0)
         or (model_name == 'weighted_bow1' and is_ragged)
         or (
             # Current clipping ops do not have corresponding TPU kernels.
-            using_tpu
+            self.using_tpu
             and is_ragged
         )
     ):
@@ -139,7 +133,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )

     # TPUs can only run `tf.function`-decorated functions.
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, autograph=False)

     # Set up the device ops and run the test.
@@ -147,7 +141,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
         test_op, args=(embed_indices,)
     )
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(embedding_test.GradNormTest):

   def setUp(self):
-    super().setUp()
+    super(embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True


 if __name__ == '__main__':

@@ -13,21 +13,19 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.LayerNormalization`."""

-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-# ==============================================================================
-# Supported Keras layers
-# ==============================================================================
 def layer_normalization_computation(
     layer_instance: tf.keras.layers.LayerNormalization,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.LayerNormalization`.

@@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization


-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_layer_norm_layer_generators():
   return {
       'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
@@ -73,14 +70,12 @@ def get_layer_norm_registries():
   }


-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False

   @parameterized.product(
       model_name=list(get_layer_norm_model_generators().keys()),
@@ -130,14 +125,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )

     # TPUs can only run `tf.function`-decorated functions.
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, jit_compile=True, autograph=False)

     # TPUs use lower precision than CPUs, so we relax our criterion (see
     # `dense_test.py` for additional discussions).
-    rtol = 1e-2 if using_tpu else 1e-3
-    atol = 1e-1 if using_tpu else 1e-2
+    rtol = 1e-2 if self.using_tpu else 1e-3
+    atol = 1e-1 if self.using_tpu else 1e-2

     # Each batched input is a reshape of a `tf.range()` call.
     batch_size = 2
@@ -148,7 +142,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # Set up the device ops and run the test.
     computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(layer_normalization_test.GradNormTest):

   def setUp(self):
-    super().setUp()
+    super(layer_normalization_test.GradNormTest, self).setUp()
     self.strategy = ctu.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True


 if __name__ == '__main__':

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""

-from typing import Any, Dict, Optional, Tuple
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@@ -21,8 +22,8 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import

 def nlp_on_device_embedding_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    input_args: Tuple[Any, ...],
-    input_kwargs: Dict[str, Any],
+    input_args: Sequence[Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:

@@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding


-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_nlp_on_device_embedding_model_generators():
   return {
       'bow1': common_test_utils.make_bow_model,
@@ -49,30 +46,26 @@ def get_nlp_on_device_embedding_layer_registries():
   dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
   dbl_registry.insert(
       tfm.nlp.layers.OnDeviceEmbedding,
-      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
+      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
   )
   return {
       'embed_and_dense': dbl_registry,
   }


-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False

   # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
       input_data=get_nlp_on_device_embedding_inputs(),
       scale_factor=[None, 0.5, 1.0],
-      model_name=list(
-          get_nlp_on_device_embedding_model_generators().keys()
-      ),
+      model_name=list(get_nlp_on_device_embedding_model_generators().keys()),
       output_dim=[2],
       layer_registry_name=list(
           get_nlp_on_device_embedding_layer_registries().keys()
@@ -97,7 +90,6 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):

     # The following are invalid test combinations and, hence, are skipped.
     batch_size = embed_indices.shape[0]
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
     if num_microbatches is not None and batch_size % num_microbatches != 0:
       return

@@ -106,9 +98,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):

     def embed_layer_generator(_, output_dims):
       return tfm.nlp.layers.OnDeviceEmbedding(
-          10,
-          *output_dims,
-          scale_factor=scale_factor
+          10, *output_dims, scale_factor=scale_factor
       )

     model = common_test_utils.get_model_from_generator(
@@ -137,7 +127,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )

     # TPUs can only run `tf.function`-decorated functions.
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, autograph=False)

     # Set up the device ops and run the test.
@@ -145,7 +135,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
         test_op, args=(embed_indices,)
     )
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):

   def setUp(self):
-    super().setUp()
+    super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True


 if __name__ == '__main__':

@@ -13,12 +13,13 @@
 # limitations under the License.
 """A collection of type aliases used throughout the clipping library."""

-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any, Optional, Union
 import tensorflow as tf


 # Tensorflow aliases.
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
+PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]

 InputTensors = PackedTensors

@@ -31,12 +32,12 @@ LossFn = Callable[..., tf.Tensor]
 # Layer Registry aliases.
 SquareNormFunction = Callable[[OutputTensors], tf.Tensor]

-RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
+RegistryFunctionOutput = tuple[Any, OutputTensors, SquareNormFunction]

 RegistryFunction = Callable[
     [
         Any,
-        Tuple[Any, ...],
+        tuple[Any, ...],
         Mapping[str, Any],
         tf.GradientTape,
         Union[tf.Tensor, None],
@@ -45,11 +46,11 @@ RegistryFunction = Callable[
 ]

 # Clipping aliases.
-GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
+GeneratorFunction = Optional[Callable[[Any, tuple, Mapping], tuple[Any, Any]]]

 # Testing aliases.
 LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]

 ModelGenerator = Callable[
-    [LayerGenerator, List[int], List[int]], tf.keras.Model
+    [LayerGenerator, Sequence[int], Sequence[int]], tf.keras.Model
 ]

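With the updated aliases, a registry function takes a `Sequence` of positional args, a `Mapping` of keyword args, and an `Optional[tf.Tensor]` microbatch count, and returns a `RegistryFunctionOutput` triple. A hedged sketch of how such a function plugs into a `LayerRegistry` (mirroring the `dbl_registry.insert(...)` calls in the tests above; `my_layer_computation` and its body are placeholders, not part of this change):

from collections.abc import Mapping, Sequence
from typing import Any, Optional

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def my_layer_computation(
    layer_instance: tf.keras.layers.Layer,
    input_args: Sequence[Any],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Placeholder registry function; real logic is layer-specific."""
  base_vars = []  # intermediates the caller differentiates through
  outputs = layer_instance(*input_args, **input_kwargs)

  def sqr_norm_fn(grads):
    # Placeholder: a real implementation maps output gradients to
    # per-example squared gradient norms for this layer's weights.
    return tf.zeros([tf.shape(grads)[0]])

  return base_vars, outputs, sqr_norm_fn


registry = layer_registry.LayerRegistry()
registry.insert(tf.keras.layers.Dense, my_layer_computation)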