diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py index c680407..e99698b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py @@ -15,7 +15,9 @@ from typing import Callable, List, Optional, Tuple, Union +import numpy as np import tensorflow as tf +import tensorflow.compat.v2 as tf_compat from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -24,6 +26,23 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases # ============================================================================== # Helper functions # ============================================================================== +def create_tpu_strategy(): + """Initializes a TPU environment.""" + # Done to avoid transferring data between CPUs and TPUs. + tf_compat.config.set_soft_device_placement(False) + resolver = tf_compat.distribute.cluster_resolver.TPUClusterResolver(tpu='') + tf_compat.config.experimental_connect_to_cluster(resolver) + tf_compat.tpu.experimental.initialize_tpu_system(resolver) + return tf_compat.distribute.TPUStrategy(resolver) + + +def assert_replica_values_are_close(test_case_obj, replica_context): + """Checks if all replica context tensors are near each other.""" + base_tensor = replica_context.values[0] + for t in replica_context.values[1:]: + test_case_obj.assertAllClose(base_tensor, t) + + def get_nd_test_batches(n: int): """Returns a list of input batches of dimension n.""" # The first two batches have a single element, the last batch has 2 elements. @@ -79,13 +98,76 @@ def compute_true_gradient_norms( trainable_vars = trainable_vars or input_model.trainable_variables for var in trainable_vars: jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) - reduction_axes = tf.range(1, len(jacobian.shape)) + reduction_axes = tf.range(1, tf.rank(jacobian)) sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) sqr_norms.append(sqr_norm) sqr_norm_tsr = tf.stack(sqr_norms, axis=1) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) +def get_model_from_generator( + model_generator: type_aliases.ModelGenerator, + layer_generator: type_aliases.LayerGenerator, + input_dims: Union[int, List[int]], + output_dim: int, + is_eager: bool, +) -> tf.keras.Model: + """Creates a simple model from input specifications.""" + model = model_generator(layer_generator, input_dims, output_dim) + model.compile( + optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), + loss=tf.keras.losses.MeanSquaredError( + reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE + ), + run_eagerly=is_eager, + ) + return model + + +def get_computed_and_true_norms_from_model( + model: tf.keras.Model, + per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], + num_microbatches: Optional[int], + x_batch: tf.Tensor, + weight_batch: Optional[tf.Tensor] = None, + rng_seed: int = 777, + registry: layer_registry.LayerRegistry = None, + partial: bool = False, +): + """Generates relevant norms from an input model and other specs.""" + trainable_vars = None + if partial: + # Gets the first layer with variables. 
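+      # When such a layer is found, both the computed and the true gradient
+      # norms below are restricted to that layer's variables only.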
+ for l in model.layers: + trainable_vars = l.trainable_variables + if trainable_vars: + break + y_pred = model(x_batch) + y_batch = tf.ones_like(y_pred) + tf.keras.utils.set_random_seed(rng_seed) + computed_norms = clip_grads.compute_gradient_norms( + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + weight_batch=weight_batch, + layer_registry=registry, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + trainable_vars=trainable_vars, + ) + tf.keras.utils.set_random_seed(rng_seed) + true_norms = compute_true_gradient_norms( + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + weight_batch=weight_batch, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + trainable_vars=trainable_vars, + ) + return computed_norms, true_norms + + def get_computed_and_true_norms( model_generator: type_aliases.ModelGenerator, layer_generator: type_aliases.LayerGenerator, @@ -133,45 +215,23 @@ def get_computed_and_true_norms( model and layer generators. The second element contains the true clipped gradient norms under the aforementioned setting. """ - model = model_generator(layer_generator, input_dims, output_dim) - model.compile( - optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), - loss=tf.keras.losses.MeanSquaredError( - reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE - ), - run_eagerly=is_eager, + model = get_model_from_generator( + model_generator=model_generator, + layer_generator=layer_generator, + input_dims=input_dims, + output_dim=output_dim, + is_eager=is_eager, ) - trainable_vars = None - if partial: - # Gets the first layer with variables. - for l in model.layers: - trainable_vars = l.trainable_variables - if trainable_vars: - break - y_pred = model(x_batch) - y_batch = tf.ones_like(y_pred) - tf.keras.utils.set_random_seed(rng_seed) - computed_norms = clip_grads.compute_gradient_norms( - input_model=model, - x_batch=x_batch, - y_batch=y_batch, - weight_batch=weight_batch, - layer_registry=registry, + return get_computed_and_true_norms_from_model( + model=model, per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, - trainable_vars=trainable_vars, + x_batch=x_batch, + weight_batch=weight_batch, + rng_seed=rng_seed, + registry=registry, + partial=partial, ) - tf.keras.utils.set_random_seed(rng_seed) - true_norms = compute_true_gradient_norms( - model, - x_batch, - y_batch, - weight_batch, - per_example_loss_fn, - num_microbatches, - trainable_vars=trainable_vars, - ) - return (computed_norms, true_norms) # ============================================================================== @@ -234,7 +294,11 @@ def make_bow_model(layer_generator, input_dims, output_dim): # in eache example. emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim) feature_embs = emb_layer(inputs) - reduction_axes = tf.range(1, len(feature_embs.shape)) + # Embeddings add one extra dimension to its inputs, which combined with the + # batch dimension at dimension 0, equals two additional dimensions compared + # to the number of input dimensions. Here, we want to reduce over the output + # space, but exclude the batch dimension. 
+ reduction_axes = range(1, len(input_dims) + 2) example_embs = tf.expand_dims( tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 ) @@ -253,7 +317,11 @@ def make_dense_bow_model(layer_generator, input_dims, output_dim): input_dim=cardinality, output_dim=output_dim ) feature_embs = emb_layer(inputs) - reduction_axes = tf.range(1, len(feature_embs.shape)) + # Embeddings add one extra dimension to its inputs, which combined with the + # batch dimension at dimension 0, equals two additional dimensions compared + # to the number of input dimensions. Here, we want to reduce over the output + # space, but exclude the batch dimension. + reduction_axes = range(1, len(input_dims) + 2) example_embs = tf.expand_dims( tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 ) @@ -274,9 +342,21 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim): input_dim=cardinality, output_dim=output_dim ) feature_embs = emb_layer(inputs) - feature_weights = tf.random.uniform(tf.shape(feature_embs)) + # Use deterministic weights to avoid seeding issues on TPUs. + feature_shape = input_dims + [output_dim] + feature_weights = tf.expand_dims( + tf.reshape( + tf.range(np.product(feature_shape), dtype=tf.float32), + feature_shape, + ), + axis=0, + ) weighted_embs = feature_embs * feature_weights - reduction_axes = tf.range(1, len(weighted_embs.shape)) + # Embeddings add one extra dimension to its inputs, which combined with the + # batch dimension at dimension 0, equals two additional dimensions compared + # to the number of input dimensions. Here, we want to reduce over the output + # space, but exclude the batch dimension. + reduction_axes = range(1, len(input_dims) + 2) example_embs = tf.expand_dims( tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1 ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD index a377287..5e5260a 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -19,9 +19,10 @@ py_test( size = "large", srcs = ["dense_test.py"], python_version = "PY3", - shard_count = 8, + shard_count = 12, srcs_version = "PY3", deps = [ + ":dense", "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", @@ -37,12 +38,13 @@ py_library( py_test( name = "embedding_test", - size = "large", srcs = ["embedding_test.py"], python_version = "PY3", - shard_count = 8, + shard_count = 12, srcs_version = "PY3", deps = [ + ":dense", + ":embedding", "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py index 0f4451c..c0b7f7b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from 
tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense # ============================================================================== @@ -40,16 +41,30 @@ def get_dense_model_generators(): } +def get_dense_layer_registries(): + dense_registry = layer_registry.LayerRegistry() + dense_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation) + return { + 'dense_only': dense_registry, + 'default': layer_registry.make_default_layer_registry(), + } + + # ============================================================================== # Main tests. # ============================================================================== class GradNormTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.strategy = tf.distribute.get_strategy() + @parameterized.product( model_name=list(get_dense_model_generators().keys()), layer_name=list(get_dense_layer_generators().keys()), input_dim=[4], output_dim=[2], + layer_registry_name=list(get_dense_layer_registries().keys()), per_example_loss_fn=[None, common_test_utils.test_loss_fn], num_microbatches=[None, 1, 2], is_eager=[True, False], @@ -62,38 +77,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase): layer_name, input_dim, output_dim, + layer_registry_name, per_example_loss_fn, num_microbatches, is_eager, partial, weighted, ): - model_generator = get_dense_model_generators()[model_name] - layer_generator = get_dense_layer_generators()[layer_name] + # Parse inputs to generate test data. x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim) - default_registry = layer_registry.make_default_layer_registry() + + # Load shared assets to all devices. + with self.strategy.scope(): + model = common_test_utils.get_model_from_generator( + model_generator=get_dense_model_generators()[model_name], + layer_generator=get_dense_layer_generators()[layer_name], + input_dims=input_dim, + output_dim=output_dim, + is_eager=is_eager, + ) + + # Define the main testing ops. These may be later compiled to a Graph op. + def test_op(x_batch, weight_batch): + return common_test_utils.get_computed_and_true_norms_from_model( + model=model, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch, + weight_batch=weight_batch if weighted else None, + registry=get_dense_layer_registries()[layer_registry_name], + partial=partial, + ) + + # TPUs can only run `tf.function`-decorated functions. + using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy) + if using_tpu: + test_op = tf.function(test_op, jit_compile=True, autograph=False) + + # TPUs use lower precision than CPUs, so we relax our criterion. + # E.g., one of the TPU runs generated the following results: + # + # computed_norm = 22.530651 + # true_norm = 22.570976 + # abs_diff = 0.04032516 + # rel_diff = 0.00178659 + # + # which is a reasonable level of error for computing gradient norms. + # Other trials also give an absolute (resp. relative) error of around + # 0.05 (resp. 0.0015). 
+ rtol = 1e-2 if using_tpu else 1e-3 + atol = 1e-1 if using_tpu else 1e-2 + for x_batch, weight_batch in zip(x_batches, weight_batches): batch_size = x_batch.shape[0] if num_microbatches is not None and batch_size % num_microbatches != 0: continue - computed_norms, true_norms = ( - common_test_utils.get_computed_and_true_norms( - model_generator, - layer_generator, - input_dim, - output_dim, - per_example_loss_fn, - num_microbatches, - is_eager, - x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch, - weight_batch=weight_batch if weighted else None, - registry=default_registry, - partial=partial, - ) + # Set up the device ops and run the test. + computed_norms, true_norms = self.strategy.run( + test_op, args=(x_batch, weight_batch) ) + # TPUs return replica contexts, which must be unwrapped. + if using_tpu: + common_test_utils.assert_replica_values_are_close(self, computed_norms) + common_test_utils.assert_replica_values_are_close(self, true_norms) + computed_norms = computed_norms.values[0] + true_norms = true_norms.values[0] expected_size = num_microbatches or batch_size - self.assertEqual(computed_norms.shape[0], expected_size) - self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + self.assertEqual(tf.shape(computed_norms)[0], expected_size) + self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_tpu_test.py new file mode 100644 index 0000000..d3ae939 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_tpu_test.py @@ -0,0 +1,29 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense_test + + +class GradNormTpuTest(dense_test.GradNormTest): + + def setUp(self): + super().setUp() + self.strategy = ctu.create_tpu_strategy() + self.assertIn('TPU', self.strategy.extended.worker_devices[0]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py index 2b0887b..057ef39 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py @@ -62,12 +62,46 @@ def embedding_layer_computation( "'_use_one_hot_matmul' is not supported." ) input_ids = tf.cast(*input_args, tf.int32) + if len(layer_instance.trainable_variables) != 1: + raise ValueError( + "Only layer instances with only one set of trainable variables" + "are permitted." 
+ ) base_vars = layer_instance.trainable_variables[0] tape.watch(base_vars) outputs = tf.nn.embedding_lookup(base_vars, input_ids) def sqr_norm_fn(base_vars_grads): - # Get a 1D tensor of the row indices. + """Fast square norm function for Keras embedding layers. + + Args: + base_vars_grads: A list of batched embedding gradients. + + Returns: + A 1D `tf.Tensor` of squared gradient norms. + + NOTE: to help understand the code, we document in the function body what + the expected intermediate variables are for the below running example: + + input_ids = [[1, 1, 2], [0], [2, 0]] + base_vars_grads.indices = [1, 1, 2, 0, 2, 0] + base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]] + + For ease of reference, we also list these variables below: + + row_indices = [[0], [0], [0], [1], [2], [2]] + slice_indices = [[1], [1], [2], [0], [2], [0]] + paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + new_index_positions = [0, 0, 1, 2, 3, 4] + num_unique_paired_indices = 5 + unique_batch_ids = [0, 0, 1, 2, 2] + summed_gradients + = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1] + = [[0.4], [0.3], [0.1], [0.3], [0.1]] + sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01] + """ + # We first get a 1D tensor of the row indices. nrows = tf.shape(input_ids)[0] if isinstance(input_ids, tf.RaggedTensor): row_indices = tf.expand_dims( @@ -81,16 +115,19 @@ def embedding_layer_computation( raise NotImplementedError( "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ ) - row_indices = tf.cast(row_indices, tf.int32) + row_indices = tf.cast(row_indices, tf.int64) if num_microbatches is not None: - microbatch_size = tf.cast(nrows / num_microbatches, tf.int32) + microbatch_size = tf.cast(nrows / num_microbatches, tf.int64) nrows = num_microbatches row_indices = tf.cast( - tf.math.floordiv(row_indices, microbatch_size), tf.int32 + tf.math.floordiv(row_indices, microbatch_size), tf.int64 ) - # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` - # call. The sum is reduced by the repeated embedding indices and batch - # index. It is adapted from the logic in: + # NOTE: expected values for the running example above are + # row_indices = [[0], [0], [0], [1], [2], [2]] + + # Now, sum-reduce the `IndexedSlices` that is the result of a + # `tape.gradient()` call. The sum is reduced by the repeated embedding + # indices and batch index. It is adapted from the logic in: # tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices if not isinstance(base_vars_grads, tf.IndexedSlices): raise NotImplementedError( @@ -105,20 +142,39 @@ def embedding_layer_computation( (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2( x=paired_indices, axis=[0] ) + # NOTE: expected values for the running example above are + # slice_indices = [[1], [1], [2], [0], [2], [0]] + # paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + # unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + # new_index_positions = [0, 0, 1, 2, 3, 4] + + # Next, sum according to the new positions and compute the squared + # gradient norms. Oddly enough, not sorting + # these indices will break tensor shape inference logic on TPUs. 
+ num_unique_paired_indices = tf.shape(unique_paired_indices)[0] unique_batch_ids = unique_paired_indices[:, 0] summed_gradients = tf.math.unsorted_segment_sum( base_vars_grads.values, new_index_positions, - tf.shape(unique_paired_indices)[0], + num_unique_paired_indices, ) - # Compute the squared gradient norms at the per-example level. sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1) - summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0]) - return tf.sparse.segment_sum( + # NOTE: expected values for the running example above are + # num_unique_paired_indices = 5 + # unique_batch_ids = [0, 0, 1, 2, 2] + # summed_gradients + # = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1] + # = [[0.4], [0.3], [0.1], [0.3], [0.1]] + # sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01] + + # Use a scatter-add strategy to ensure TPU compatibility. + result = tf.zeros([nrows]) + return tf.tensor_scatter_nd_add( + result, + tf.expand_dims(unique_batch_ids, axis=-1), sqr_gradient_sum, - summed_data_range, - tf.sort(unique_batch_ids), - num_segments=nrows, - ) # fill in empty inputs + ) + # NOTE: the expected output for the running example is + # [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1] return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py index c818afa..f8c1cf3 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py @@ -16,6 +16,8 @@ from absl.testing import parameterized import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding # ============================================================================== @@ -29,41 +31,50 @@ def get_embedding_model_generators(): } +def get_embedding_inputs(): + """Generates test pairs of the form (is_ragged, input_data).""" + return [ + # 2D inputs. + (False, [[0, 1]]), + (False, [[0, 1], [1, 1], [0, 0]]), + (True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]]), + (True, [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]]), + # 3D inputs. + (False, [[[0, 1]]]), + (False, [[[0, 1]], [[1, 1]], [[0, 0]]]), + (True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]), + (True, [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]]), + ] + + +def get_embedding_layer_registries(): + dbl_registry = layer_registry.LayerRegistry() + dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation) + dbl_registry.insert( + tf.keras.layers.Embedding, embedding.embedding_layer_computation + ) + return { + 'embed_and_dense': dbl_registry, + 'default': layer_registry.make_default_layer_registry(), + } + + # ============================================================================== # Main tests. # ============================================================================== class GradNormTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.strategy = tf.distribute.get_strategy() + # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment # supports them for embeddings. 
@parameterized.product( - x_batch=[ - # 2D inputs. - tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32), - tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32), - tf.ragged.constant( - [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32 - ), - tf.ragged.constant( - [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]], - dtype=tf.int32, - ), - # 3D inputs. - tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32), - tf.convert_to_tensor( - [[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32 - ), - tf.ragged.constant( - [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]], - dtype=tf.int32, - ), - tf.ragged.constant( - [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]], - dtype=tf.int32, - ), - ], + x_inputs=get_embedding_inputs(), model_name=list(get_embedding_model_generators().keys()), output_dim=[2], + layer_registry_name=list(get_embedding_layer_registries().keys()), per_example_loss_fn=[None, common_test_utils.test_loss_fn], num_microbatches=[None, 2], is_eager=[True, False], @@ -71,39 +82,74 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase): ) def test_gradient_norms_on_various_models( self, - x_batch, + x_inputs, model_name, output_dim, + layer_registry_name, per_example_loss_fn, num_microbatches, is_eager, partial, ): - batch_size = x_batch.shape[0] - # The following are invalid test combinations, and are skipped. + # Parse inputs to generate test data. + is_ragged, input_data = x_inputs + embed_indices = ( + tf.ragged.constant(input_data, dtype=tf.int32) + if is_ragged + else tf.convert_to_tensor(input_data, dtype_hint=tf.int32) + ) + + # The following are invalid test combinations and, hence, are skipped. + batch_size = embed_indices.shape[0] + using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy) if ( - num_microbatches is not None and batch_size % num_microbatches != 0 - ) or ( - model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor) + (num_microbatches is not None and batch_size % num_microbatches != 0) + or (model_name == 'weighted_bow1' and is_ragged) + or ( + # Current clipping ops do not have corresponding TPU kernels. + using_tpu + and is_ragged + ) ): return - default_registry = layer_registry.make_default_layer_registry() - model_generator = get_embedding_model_generators()[model_name] - computed_norms, true_norms = ( - common_test_utils.get_computed_and_true_norms( - model_generator=model_generator, - layer_generator=None, - input_dims=x_batch.shape[1:], - output_dim=output_dim, - per_example_loss_fn=per_example_loss_fn, - num_microbatches=num_microbatches, - is_eager=is_eager, - x_batch=x_batch, - registry=default_registry, - partial=partial, - ) + + # Load shared assets to all devices. + with self.strategy.scope(): + model = common_test_utils.get_model_from_generator( + model_generator=get_embedding_model_generators()[model_name], + layer_generator=None, + input_dims=embed_indices.shape[1:], + output_dim=output_dim, + is_eager=is_eager, + ) + + # Define the main testing ops. These may be later compiled to a Graph op. + def test_op(x_batch): + return common_test_utils.get_computed_and_true_norms_from_model( + model=model, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + x_batch=x_batch, + registry=get_embedding_layer_registries()[layer_registry_name], + partial=partial, + ) + + # TPUs can only run `tf.function`-decorated functions. + if using_tpu: + test_op = tf.function(test_op, autograph=False) + + # Set up the device ops and run the test. 
+ computed_norms, true_norms = self.strategy.run( + test_op, args=(embed_indices,) ) - self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) + # TPUs return replica contexts, which must be unwrapped. + if using_tpu: + common_test_utils.assert_replica_values_are_close(self, computed_norms) + common_test_utils.assert_replica_values_are_close(self, true_norms) + computed_norms = computed_norms.values[0] + true_norms = true_norms.values[0] + expected_size = num_microbatches or batch_size + self.assertEqual(tf.shape(computed_norms)[0], expected_size) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py new file mode 100644 index 0000000..1b9b393 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py @@ -0,0 +1,29 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import embedding_test + + +class GradNormTpuTest(embedding_test.GradNormTest): + + def setUp(self): + super().setUp() + self.strategy = common_test_utils.create_tpu_strategy() + self.assertIn('TPU', self.strategy.extended.worker_devices[0]) + + +if __name__ == '__main__': + tf.test.main()
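
For reference, the following standalone sketch (not part of the patch) replays the running example documented in the new `sqr_norm_fn` docstring of embedding.py. The `input_ids` and the gradient indices/values are the exact values from the NOTE blocks above; the variable names `grad_indices` and `grad_values` are illustrative stand-ins for the `IndexedSlices` produced by `tape.gradient()`. The reduction mirrors the TPU-friendly `tf.tensor_scatter_nd_add` path that replaces the earlier `tf.sparse.segment_sum` call, and should print per-example squared norms of approximately [0.25, 0.01, 0.1].

import tensorflow as tf

# Token ids for three examples (the ragged batch from the docstring).
input_ids = tf.ragged.constant([[1, 1, 2], [0], [2, 0]], dtype=tf.int32)

# Stand-ins for the `IndexedSlices` gradient returned by `tape.gradient()`:
# one slice per looked-up id, in lookup order.
grad_indices = tf.constant([1, 1, 2, 0, 2, 0], dtype=tf.int64)
grad_values = tf.constant([[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]])

nrows = int(input_ids.nrows())  # 3 examples.

# Batch (row) index of every looked-up id: [[0], [0], [0], [1], [2], [2]].
row_indices = tf.expand_dims(
    tf.cast(input_ids.value_rowids(), tf.int64), axis=-1
)

# Pair each batch index with its embedding row and deduplicate the pairs.
slice_indices = tf.expand_dims(grad_indices, axis=-1)
paired_indices = tf.concat([row_indices, slice_indices], axis=1)
unique_paired_indices, new_index_positions = tf.raw_ops.UniqueV2(
    x=paired_indices, axis=[0]
)

# Sum duplicated gradient rows, then take squared norms per unique pair.
num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
summed_gradients = tf.math.unsorted_segment_sum(
    grad_values, new_index_positions, num_unique_paired_indices
)
sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)

# Scatter-add each pair's squared norm onto its batch index (TPU friendly).
unique_batch_ids = unique_paired_indices[:, 0]
result = tf.tensor_scatter_nd_add(
    tf.zeros([nrows]),
    tf.expand_dims(unique_batch_ids, axis=-1),
    sqr_gradient_sum,
)
print(result.numpy())  # ~[0.25, 0.01, 0.1], matching the NOTE in sqr_norm_fn.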