Improve the readability of the fast gradient clipping library.

PiperOrigin-RevId: 566961891
A. Unique TensorFlower 2023-09-20 07:47:12 -07:00
parent 97aaf302eb
commit c037070e50
18 changed files with 84 additions and 111 deletions

View file

@@ -21,7 +21,8 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
-from typing import List, Optional, Tuple
+from collections.abc import Sequence
+from typing import Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
@@ -74,7 +75,7 @@ def compute_gradient_norms(
     weight_batch: Optional[tf.Tensor] = None,
     per_example_loss_fn: Optional[type_aliases.LossFn] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
-    trainable_vars: Optional[List[tf.Variable]] = None,
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
 ):
   """Computes the per-example loss gradient norms for given data.
@@ -219,7 +220,7 @@ def compute_clipped_gradients_and_outputs(
     weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     clipping_loss: Optional[type_aliases.LossFn] = None,
-) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
+) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.
   Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
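
Note: this file sets the convention the rest of the commit follows — parameters widen to `collections.abc` ABCs (`Sequence` instead of `typing.List`) and return annotations move to builtin generics (lowercase `tuple`), which require Python 3.9+. A minimal sketch of the convention; the helper name and body below are illustrative, not from the library:

from collections.abc import Sequence
from typing import Optional

import tensorflow as tf


def scale_all(
    tensors: Sequence[tf.Tensor],  # accepts a list, tuple, or any sequence
    factor: Optional[float] = None,
) -> tuple[list[tf.Tensor], int]:
  # Hypothetical helper: scale every tensor, report how many were scaled.
  f = 1.0 if factor is None else factor
  return [t * f for t in tensors], len(tensors)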

View file

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Text, Tuple
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 from absl.testing import parameterized
 import tensorflow as tf
@@ -22,9 +23,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
-# ==============================================================================
-# Helper functions and classes.
-# ==============================================================================
 class DoubleDense(tf.keras.layers.Layer):
   """Generates two dense layers nested together."""
@@ -40,8 +38,8 @@ class DoubleDense(tf.keras.layers.Layer):
 def double_dense_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    input_args: Tuple[Any, ...],
-    input_kwargs: Dict[Text, Any],
+    input_args: Sequence[Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[int],
 ) -> type_aliases.RegistryFunctionOutput:
@@ -61,9 +59,6 @@ def double_dense_layer_computation(
   return [vars1, vars2], outputs, sqr_norm_fn
-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.product(

View file

@@ -13,7 +13,8 @@
 # limitations under the License.
 """A collection of common utility functions for unit testing."""
-from typing import Callable, List, Optional, Tuple
+from collections.abc import Callable, MutableSequence, Sequence
+from typing import Optional
 import numpy as np
 import tensorflow as tf
@@ -108,8 +109,8 @@ def compute_true_gradient_norms(
 def get_model_from_generator(
     model_generator: type_aliases.ModelGenerator,
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
     is_eager: bool,
 ) -> tf.keras.Model:
   """Creates a simple model from input specifications."""
@@ -171,8 +172,8 @@ def get_computed_and_true_norms_from_model(
 def get_computed_and_true_norms(
     model_generator: type_aliases.ModelGenerator,
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
     per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
     num_microbatches: Optional[int],
     is_eager: bool,
@@ -181,7 +182,7 @@ def get_computed_and_true_norms(
     rng_seed: int = 777,
     registry: layer_registry.LayerRegistry = None,
     partial: bool = False,
-) -> Tuple[tf.Tensor, tf.Tensor]:
+) -> tuple[tf.Tensor, tf.Tensor]:
   """Obtains the true and computed gradient norms for a model and batch input.
   Helpful testing wrapper function used to avoid code duplication.
@@ -245,8 +246,8 @@ def reshape_and_sum(tensor: tf.Tensor) -> tf.Tensor:
 # ==============================================================================
 def make_one_layer_functional_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 1-layer sequential model."""
   inputs = tf.keras.Input(shape=input_dims)
@@ -258,8 +259,8 @@ def make_one_layer_functional_model(
 def make_two_layer_functional_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 2-layer sequential model."""
   inputs = tf.keras.Input(shape=input_dims)
@@ -272,8 +273,8 @@ def make_two_layer_functional_model(
 def make_two_tower_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a 2-layer 2-input functional model."""
   inputs1 = tf.keras.Input(shape=input_dims)
@@ -290,8 +291,8 @@ def make_two_tower_model(
 def make_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates a simple embedding bow model."""
   inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@@ -315,8 +316,8 @@ def make_bow_model(
 def make_dense_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: Sequence[int],
+    output_dims: Sequence[int],
 ) -> tf.keras.Model:
   """Creates an embedding bow model with a `Dense` layer."""
   inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
@@ -341,8 +342,8 @@ def make_dense_bow_model(
 def make_weighted_bow_model(
     layer_generator: type_aliases.LayerGenerator,
-    input_dims: List[int],
-    output_dims: List[int],
+    input_dims: MutableSequence[int],
+    output_dims: MutableSequence[int],
 ) -> tf.keras.Model:
   """Creates a weighted embedding bow model."""
   # NOTE: This model only accepts dense input tensors.
@@ -353,10 +354,9 @@ def make_weighted_bow_model(
   emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
   feature_embs = emb_layer(inputs)
   # Use deterministic weights to avoid seeding issues on TPUs.
-  feature_shape = input_dims + [output_dim]
+  feature_shape = input_dims + output_dims
   feature_weights = tf.expand_dims(
       tf.reshape(
           tf.range(np.product(feature_shape), dtype=tf.float32),
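
A note on the two changes just above: `output_dims` is checked to hold exactly one element, so `input_dims + output_dims` equals the old `input_dims + [output_dims[0]]` and the temporary can be dropped. The `MutableSequence` annotation is presumably there because neither `Sequence` nor `MutableSequence` formally defines `+`; in practice the tests pass plain lists, for which concatenation behaves as before:

# Assuming list arguments, as the tests supply:
input_dims = [3, 4]
output_dims = [5]
assert input_dims + output_dims == input_dims + [output_dims[0]] == [3, 4, 5]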

View file

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Utility functions that help in the computation of per-example gradient norms."""
-from typing import Any, List, Optional, Set, Tuple
+from collections.abc import Sequence, Set
+from typing import Any, Optional
 from absl import logging
 import tensorflow as tf
@@ -36,7 +37,7 @@ def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,
     generator_fn: type_aliases.GeneratorFunction = None,
-) -> Tuple[type_aliases.PackedTensors, List[Any]]:
+) -> tuple[type_aliases.PackedTensors, Sequence[Any]]:
   """Does a forward pass of a model and returns useful intermediates.
   NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -149,7 +150,7 @@ def add_aggregate_noise(
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
-) -> List[tf.Tensor]:
+) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
   The magnitude of the noise depends on the aggregation strategy of the

View file

@@ -16,7 +16,6 @@ from typing import Any
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -90,7 +89,7 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
     num_dims = 3
     num_inputs = 1 if input_packing_type is None else 2
     num_outputs = 1 if output_packing_type is None else 2
-    sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
+    sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
     temp_sum = tf.stack(sample_inputs, axis=0)
     sample_outputs = [
         tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)

View file

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Dense`."""
-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 def dense_layer_computation(
     layer_instance: tf.keras.layers.Dense,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Dense`.

View file

@@ -19,11 +19,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_dense_layer_generators():
   def sigmoid_dense_layer(units):
     return tf.keras.layers.Dense(units, activation='sigmoid')
@@ -50,14 +46,12 @@ def get_dense_layer_registries():
   }
-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False
   @parameterized.product(
       model_name=list(get_dense_model_generators().keys()),
@@ -112,8 +106,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )
     # TPUs can only run `tf.function`-decorated functions.
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, jit_compile=True, autograph=False)
     # TPUs use lower precision than CPUs, so we relax our criterion.
@@ -127,8 +120,8 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # which is a reasonable level of error for computing gradient norms.
     # Other trials also give an absolute (resp. relative) error of around
     # 0.05 (resp. 0.0015).
-    rtol = 1e-2 if using_tpu else 1e-3
-    atol = 1e-1 if using_tpu else 1e-2
+    rtol = 1e-2 if self.using_tpu else 1e-3
+    atol = 1e-1 if self.using_tpu else 1e-2
     for x_batch, weight_batch in zip(x_batches, weight_batches):
       batch_size = x_batch.shape[0]
@@ -139,7 +132,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
           test_op, args=(x_batch, weight_batch)
       )
      # TPUs return replica contexts, which must be unwrapped.
-      if using_tpu:
+      if self.using_tpu:
        common_test_utils.assert_replica_values_are_close(self, computed_norms)
        common_test_utils.assert_replica_values_are_close(self, true_norms)
        computed_norms = computed_norms.values[0]

View file

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(dense_test.GradNormTest):
   def setUp(self):
-    super().setUp()
+    super(dense_test.GradNormTest, self).setUp()
     self.strategy = ctu.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True
 if __name__ == '__main__':
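
The `super(dense_test.GradNormTest, self).setUp()` call is deliberate: it skips the parent test's `setUp` (which installs the default strategy and `self.using_tpu = False`) and resumes at `tf.test.TestCase.setUp`, so the TPU variant can install its own strategy. A stripped-down, runnable sketch of the pattern with stand-in classes:

class FrameworkTestCase:  # stands in for tf.test.TestCase
  def setUp(self):
    self.initialized = True


class GradNormTest(FrameworkTestCase):
  def setUp(self):
    super().setUp()
    self.strategy = 'default'
    self.using_tpu = False


class GradNormTpuTest(GradNormTest):
  def setUp(self):
    # Skip GradNormTest.setUp and run FrameworkTestCase.setUp directly,
    # then install TPU-specific state.
    super(GradNormTest, self).setUp()
    self.strategy = 'tpu'
    self.using_tpu = True


t = GradNormTpuTest()
t.setUp()
assert t.initialized and t.using_tpu and t.strategy == 'tpu'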

View file

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.Embedding`."""
-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@@ -21,10 +22,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 def embedding_layer_computation(
     layer_instance: tf.keras.layers.Embedding,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.Embedding`.

View file

@@ -20,9 +20,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding
-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_embedding_model_generators():
   return {
       'bow1': common_test_utils.make_bow_model,
@@ -59,14 +56,12 @@ def get_embedding_layer_registries():
   }
-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False
   # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
@@ -101,13 +96,12 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # The following are invalid test combinations and, hence, are skipped.
     batch_size = embed_indices.shape[0]
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
     if (
         (num_microbatches is not None and batch_size % num_microbatches != 0)
         or (model_name == 'weighted_bow1' and is_ragged)
         or (
            # Current clipping ops do not have corresponding TPU kernels.
-            using_tpu
+            self.using_tpu
            and is_ragged
        )
    ):
@@ -139,7 +133,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )
     # TPUs can only run `tf.function`-decorated functions.
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, autograph=False)
     # Set up the device ops and run the test.
@@ -147,7 +141,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
         test_op, args=(embed_indices,)
     )
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

View file

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(embedding_test.GradNormTest):
   def setUp(self):
-    super().setUp()
+    super(embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True
 if __name__ == '__main__':

View file

@@ -13,21 +13,19 @@
 # limitations under the License.
 """Fast clipping function for `tf.keras.layers.LayerNormalization`."""
-from typing import Any, Mapping, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
-# ==============================================================================
-# Supported Keras layers
-# ==============================================================================
 def layer_normalization_computation(
     layer_instance: tf.keras.layers.LayerNormalization,
-    input_args: Tuple[Any, ...],
+    input_args: Sequence[Any],
     input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
-    num_microbatches: Union[tf.Tensor, None] = None,
+    num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:
   """Registry function for `tf.keras.layers.LayerNormalization`.

View file

@@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization
-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_layer_norm_layer_generators():
   return {
       'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
@@ -73,14 +70,12 @@ def get_layer_norm_registries():
   }
-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False
   @parameterized.product(
       model_name=list(get_layer_norm_model_generators().keys()),
@@ -130,14 +125,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )
     # TPUs can only run `tf.function`-decorated functions.
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, jit_compile=True, autograph=False)
     # TPUs use lower precision than CPUs, so we relax our criterion (see
     # `dense_test.py` for additional discussions).
-    rtol = 1e-2 if using_tpu else 1e-3
-    atol = 1e-1 if using_tpu else 1e-2
+    rtol = 1e-2 if self.using_tpu else 1e-3
+    atol = 1e-1 if self.using_tpu else 1e-2
     # Each batched input is a reshape of a `tf.range()` call.
     batch_size = 2
@@ -148,7 +142,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # Set up the device ops and run the test.
     computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

View file

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(layer_normalization_test.GradNormTest):
   def setUp(self):
-    super().setUp()
+    super(layer_normalization_test.GradNormTest, self).setUp()
     self.strategy = ctu.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True
 if __name__ == '__main__':

View file

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
-from typing import Any, Dict, Optional, Tuple
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
@@ -21,8 +22,8 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 def nlp_on_device_embedding_layer_computation(
     layer_instance: tf.keras.layers.Layer,
-    input_args: Tuple[Any, ...],
-    input_kwargs: Dict[str, Any],
+    input_args: Sequence[Any],
+    input_kwargs: Mapping[str, Any],
     tape: tf.GradientTape,
     num_microbatches: Optional[tf.Tensor] = None,
 ) -> type_aliases.RegistryFunctionOutput:

View file

@@ -21,9 +21,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
-# ==============================================================================
-# Helper functions.
-# ==============================================================================
 def get_nlp_on_device_embedding_model_generators():
   return {
       'bow1': common_test_utils.make_bow_model,
@@ -49,30 +46,26 @@ def get_nlp_on_device_embedding_layer_registries():
   dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
   dbl_registry.insert(
       tfm.nlp.layers.OnDeviceEmbedding,
-      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
+      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
   )
   return {
       'embed_and_dense': dbl_registry,
   }
-# ==============================================================================
-# Main tests.
-# ==============================================================================
 class GradNormTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.strategy = tf.distribute.get_strategy()
+    self.using_tpu = False
   # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
       input_data=get_nlp_on_device_embedding_inputs(),
       scale_factor=[None, 0.5, 1.0],
-      model_name=list(
-          get_nlp_on_device_embedding_model_generators().keys()
-      ),
+      model_name=list(get_nlp_on_device_embedding_model_generators().keys()),
       output_dim=[2],
       layer_registry_name=list(
           get_nlp_on_device_embedding_layer_registries().keys()
@@ -97,7 +90,6 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     # The following are invalid test combinations and, hence, are skipped.
     batch_size = embed_indices.shape[0]
-    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
     if num_microbatches is not None and batch_size % num_microbatches != 0:
       return
@@ -106,9 +98,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     def embed_layer_generator(_, output_dims):
       return tfm.nlp.layers.OnDeviceEmbedding(
-          10,
-          *output_dims,
-          scale_factor=scale_factor
+          10, *output_dims, scale_factor=scale_factor
       )
     model = common_test_utils.get_model_from_generator(
@@ -137,7 +127,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     )
     # TPUs can only run `tf.function`-decorated functions.
-    if using_tpu:
+    if self.using_tpu:
       test_op = tf.function(test_op, autograph=False)
     # Set up the device ops and run the test.
@@ -145,7 +135,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
         test_op, args=(embed_indices,)
     )
     # TPUs return replica contexts, which must be unwrapped.
-    if using_tpu:
+    if self.using_tpu:
       common_test_utils.assert_replica_values_are_close(self, computed_norms)
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]

View file

@@ -20,9 +20,10 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
   def setUp(self):
-    super().setUp()
+    super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+    self.using_tpu = True
 if __name__ == '__main__':

View file

@@ -13,12 +13,13 @@
 # limitations under the License.
 """A collection of type aliases used throughout the clipping library."""
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any, Optional, Union
 import tensorflow as tf
 # Tensorflow aliases.
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]
+PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
 InputTensors = PackedTensors
@@ -31,12 +32,12 @@ LossFn = Callable[..., tf.Tensor]
 # Layer Registry aliases.
 SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
-RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
+RegistryFunctionOutput = tuple[Any, OutputTensors, SquareNormFunction]
 RegistryFunction = Callable[
     [
         Any,
-        Tuple[Any, ...],
+        tuple[Any, ...],
         Mapping[str, Any],
         tf.GradientTape,
         Union[tf.Tensor, None],
@@ -45,11 +46,11 @@ RegistryFunction = Callable[
     ]
 # Clipping aliases.
-GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
+GeneratorFunction = Optional[Callable[[Any, tuple, Mapping], tuple[Any, Any]]]
 # Testing aliases.
 LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
 ModelGenerator = Callable[
-    [LayerGenerator, List[int], List[int]], tf.keras.Model
+    [LayerGenerator, Sequence[int], Sequence[int]], tf.keras.Model
 ]
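
One compatibility note on the aliases above (an observation, not part of the commit): subscripted builtins such as `tuple[Any, ...]` and subscripted `collections.abc` classes such as `Mapping[str, tf.Tensor]` only evaluate at runtime on Python 3.9+, so the commit implicitly assumes at least that interpreter version. A tiny self-contained demo of the style, with a hypothetical function:

from collections.abc import Mapping, Sequence

def tensor_names(tensors: Mapping[str, object]) -> tuple[Sequence[str], int]:
  # Return the mapping's keys and how many there are.
  names = list(tensors)
  return names, len(names)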