From bcc0d4927ece8f2499e99de01fd9d440c729d045 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 11 Sep 2023 13:17:34 -0700
Subject: [PATCH] Implement and test a registry function for
 `tfm.nlp.layers.OnDeviceEmbedding`.

This CL also moves the embedding `sqr_norm_fn` logic shared by
`tf.keras.layers.Embedding` and `tfm.nlp.layers.OnDeviceEmbedding` into a new
registry function utility file.

PiperOrigin-RevId: 564481407
---
 requirements.txt                          |   1 +
 setup.py                                  |   1 +
 .../common_test_utils.py                  |  23 +--
 .../registry_functions/BUILD              |  37 +++-
 .../registry_functions/embedding.py       | 108 +-----------
 .../registry_functions/embedding_test.py  |   6 +-
 .../nlp_on_device_embedding.py            |  70 ++++++++
 .../nlp_on_device_embedding_test.py       | 159 ++++++++++++++++++
 .../nlp_on_device_embedding_tpu_test.py   |  29 ++++
 .../registry_function_utils.py            | 131 +++++++++++++++
 10 files changed, 443 insertions(+), 122 deletions(-)
 create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding.py
 create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_test.py
 create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py
 create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/registry_function_utils.py

diff --git a/requirements.txt b/requirements.txt
index 9dc7787..7d73be8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,3 +37,4 @@ tensorflow-datasets~=4.5
 tensorflow-estimator~=2.4
 tensorflow-probability~=0.20.0
 tensorflow~=2.4
+tf-models-official~=2.13

diff --git a/setup.py b/setup.py
index debab5f..376e345 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@ setup(
         'tensorflow-estimator~=2.4',
         'tensorflow-probability~=0.20.0',
         'tensorflow~=2.4',
+        'tf-models-official~=2.13',
     ],
     packages=find_packages(),
 )

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py
index e5f5e09..804e877 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py
@@ -294,15 +294,13 @@ def make_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates a simple embedding bow model."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
+  emb_layer = layer_generator(input_dims, output_dims)
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to their inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -321,18 +319,13 @@ def make_dense_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates an embedding bow model with a `Dense` layer."""
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
-  output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Embeddings add one extra dimension to their inputs, which combined with the
   # batch dimension at dimension 0, equals two additional dimensions compared
@@ -353,18 +346,14 @@ def make_weighted_bow_model(
     output_dims: List[int],
 ) -> tf.keras.Model:
   """Creates a weighted embedding bow model."""
   # NOTE: This model only accepts dense input tensors.
-  del layer_generator
-  inputs = tf.keras.Input(shape=input_dims)
+  inputs = tf.keras.Input(shape=input_dims, dtype=tf.int32)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in each example.
-  cardinality = 10
+  emb_layer = layer_generator(input_dims, output_dims)
   if len(output_dims) != 1:
     raise ValueError('Expected `output_dims` to be of size 1.')
   output_dim = output_dims[0]
-  emb_layer = tf.keras.layers.Embedding(
-      input_dim=cardinality, output_dim=output_dim
-  )
   feature_embs = emb_layer(inputs)
   # Use deterministic weights to avoid seeding issues on TPUs.
   feature_shape = input_dims + [output_dim]

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
index e6a948f..358d8c3 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD
@@ -2,6 +2,13 @@ package(
     default_visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "registry_function_utils",
+    srcs = ["registry_function_utils.py"],
+    srcs_version = "PY3",
+    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+)
+
 py_library(
     name = "einsum_utils",
     srcs = ["einsum_utils.py"],
@@ -45,7 +52,10 @@ py_library(
     name = "embedding",
     srcs = ["embedding.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
 )
 
 py_test(
@@ -63,6 +73,31 @@ py_test(
     ],
 )
 
+py_library(
+    name = "nlp_on_device_embedding",
+    srcs = ["nlp_on_device_embedding.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":registry_function_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases",
+    ],
+)
+
+py_test(
+    name = "nlp_on_device_embedding_test",
+    srcs = ["nlp_on_device_embedding_test.py"],
+    python_version = "PY3",
+    shard_count = 6,
+    srcs_version = "PY3",
+    deps = [
+        ":dense",
+        ":nlp_on_device_embedding",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+    ],
+)
+
 py_library(
     name = "layer_normalization",
     srcs = ["layer_normalization.py"],

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py
index 13ab5f5..495d74e 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py
@@ -16,6 +16,7 @@ from typing import Any, Mapping, Tuple, Union
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
 
 
 def embedding_layer_computation(
@@ -71,110 +72,11 @@ def embedding_layer_computation(
     tape.watch(base_vars)
   outputs = tf.nn.embedding_lookup(base_vars, input_ids)
 
-  def sqr_norm_fn(base_vars_grads):
-    """Fast square norm function for Keras embedding layers.
-
-    Args:
-      base_vars_grads: A list of batched embedding gradients.
-
-    Returns:
-      A 1D `tf.Tensor` of squared gradient norms.
-
-    NOTE: to help understand the code, we document in the function body what
-    the expected intermediate variables are for the below running example:
-
-      input_ids = [[1, 1, 2], [0], [2, 0]]
-      base_vars_grads.indices = [1, 1, 2, 0, 2, 0]
-      base_vars_grads.values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
-
-    For ease of reference, we also list these variables below:
-
-      row_indices = [[0], [0], [0], [1], [2], [2]]
-      slice_indices = [[1], [1], [2], [0], [2], [0]]
-      paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-      new_index_positions = [0, 0, 1, 2, 3, 4]
-      num_unique_paired_indices = 5
-      unique_batch_ids = [0, 0, 1, 2, 2]
-      summed_gradients
-        = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-        = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-      sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-    """
-    # We first get a 1D tensor of the row indices.
-    nrows = tf.shape(input_ids)[0]
-    if isinstance(input_ids, tf.RaggedTensor):
-      row_indices = tf.expand_dims(
-          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
-      )
-    elif isinstance(input_ids, tf.Tensor):
-      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
-      repeats = tf.repeat(ncols, nrows)
-      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
-    else:
-      raise NotImplementedError(
-          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
-      )
-    row_indices = tf.cast(row_indices, tf.int64)
-    if num_microbatches is not None:
-      microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
-      nrows = num_microbatches
-      row_indices = tf.cast(
-          tf.math.floordiv(row_indices, microbatch_size), tf.int64
-      )
-    # NOTE: expected values for the running example above are
-    #   row_indices = [[0], [0], [0], [1], [2], [2]]
-
-    # Now, sum-reduce the `IndexedSlices` that is the result of a
-    # `tape.gradient()` call. The sum is reduced by the repeated embedding
-    # indices and batch index. It is adapted from the logic in:
-    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
-    if not isinstance(base_vars_grads, tf.IndexedSlices):
-      raise NotImplementedError(
-          "Cannot parse embedding gradients of type: %s"
-          % base_vars_grads.__class__.__name__
-      )
-    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
-    paired_indices = tf.concat(
-        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
-        axis=1,
-    )
-    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
-        x=paired_indices, axis=[0]
-    )
-    # NOTE: expected values for the running example above are
-    #   slice_indices = [[1], [1], [2], [0], [2], [0]]
-    #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
-    #   new_index_positions = [0, 0, 1, 2, 3, 4]
-
-    # Next, sum according to the new positions and compute the squared
-    # gradient norms. Oddly enough, not sorting
-    # these indices will break tensor shape inference logic on TPUs.
-    num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
-    unique_batch_ids = unique_paired_indices[:, 0]
-    summed_gradients = tf.math.unsorted_segment_sum(
+  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
+    return registry_function_utils.embedding_sqr_norm_fn(
         base_vars_grads.values,
-        new_index_positions,
-        num_unique_paired_indices,
+        input_ids,
+        num_microbatches,
     )
-    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
-    # NOTE: expected values for the running example above are
-    #   num_unique_paired_indices = 5
-    #   unique_batch_ids = [0, 0, 1, 2, 2]
-    #   summed_gradients
-    #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
-    #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
-    #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
-
-    # Use a scatter-add strategy to ensure TPU compatibility.
-    result = tf.zeros([nrows])
-    return tf.tensor_scatter_nd_add(
-        result,
-        tf.expand_dims(unique_batch_ids, axis=-1),
-        sqr_gradient_sum,
-    )
-    # NOTE: the expected output for the running example is
-    #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
 
   return base_vars, outputs, sqr_norm_fn

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py
index 44de2e4..374afd6 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py
@@ -115,9 +115,13 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
 
     # Load shared assets to all devices.
     with self.strategy.scope():
+
+      def embed_layer_generator(_, output_dims):
+        return tf.keras.layers.Embedding(10, *output_dims)
+
       model = common_test_utils.get_model_from_generator(
           model_generator=get_embedding_model_generators()[model_name],
-          layer_generator=None,
+          layer_generator=embed_layer_generator,
           input_dims=embed_indices.shape[1:],
           output_dims=[output_dim],
           is_eager=is_eager,

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding.py
new file mode 100644
index 0000000..084c4eb
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding.py
@@ -0,0 +1,70 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast clipping function for `tfm.nlp.layers.OnDeviceEmbedding`."""
+
+from typing import Any, Dict, Optional, Tuple
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import registry_function_utils
+
+
+def nlp_on_device_embedding_layer_computation(
+    layer_instance: tf.keras.layers.Layer,
+    input_args: Tuple[Any, ...],
+    input_kwargs: Dict[str, Any],
+    tape: tf.GradientTape,
+    num_microbatches: Optional[tf.Tensor] = None,
+) -> type_aliases.RegistryFunctionOutput:
+  """Registry function for `tfm.nlp.layers.OnDeviceEmbedding`.
+
+  Args:
+    layer_instance: A `tfm.nlp.layers.OnDeviceEmbedding` instance.
+    input_args: See `dense_layer_computation()` in `dense.py`.
+    input_kwargs: See `dense_layer_computation()` in `dense.py`.
+    tape: See `dense_layer_computation()` in `dense.py`.
+    num_microbatches: See `dense_layer_computation()` in `dense.py`.
+
+  Returns:
+    See `dense_layer_computation()` in `dense.py`.
+  """
+  if input_kwargs:
+    raise ValueError("Embedding layer calls should not receive kwargs.")
+  del input_kwargs
+  if len(input_args) != 1:
+    raise ValueError("Only layer inputs of length 1 are permitted.")
+  if hasattr(layer_instance, "_use_one_hot"):
+    if layer_instance._use_one_hot:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "The embedding feature '_use_one_hot' is not supported."
+      )
+  # NOTE: Since the implementation of `tfm.nlp.layers.OnDeviceEmbedding` uses
+  # `.set_shape()`, we can assume that inputs are not ragged.
+  input_ids = tf.cast(*input_args, tf.int32)
+  if len(layer_instance.trainable_variables) != 1:
+    raise ValueError(
+        "Only layer instances with exactly one set of trainable variables "
+        "are permitted."
+    )
+  base_vars = layer_instance.trainable_variables[0]
+  tape.watch(base_vars)
+  outputs = layer_instance(input_ids)
+
+  def sqr_norm_fn(base_vars_grads: tf.IndexedSlices):
+    return registry_function_utils.embedding_sqr_norm_fn(
+        base_vars_grads.values,
+        input_ids,
+        num_microbatches,
+    )
+
+  return base_vars, outputs, sqr_norm_fn

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_test.py
new file mode 100644
index 0000000..41f5914
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_test.py
@@ -0,0 +1,159 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import tensorflow as tf
+import tensorflow_models as tfm
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding
+
+
+# ==============================================================================
+# Helper functions.
+# ==============================================================================
+def get_nlp_on_device_embedding_model_generators():
+  return {
+      'bow1': common_test_utils.make_bow_model,
+      'bow2': common_test_utils.make_dense_bow_model,
+      'weighted_bow1': common_test_utils.make_weighted_bow_model,
+  }
+
+
+def get_nlp_on_device_embedding_inputs():
+  """Generates input_data."""
+  return [
+      # 2D inputs.
+      [[0, 1]],
+      [[0, 1], [1, 1], [0, 0]],
+      # 3D inputs.
+      [[[0, 1]]],
+      [[[0, 1]], [[1, 1]], [[0, 0]]],
+  ]
+
+
+def get_nlp_on_device_embedding_layer_registries():
+  dbl_registry = layer_registry.LayerRegistry()
+  dbl_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
+  dbl_registry.insert(
+      tfm.nlp.layers.OnDeviceEmbedding,
+      nlp_on_device_embedding.nlp_on_device_embedding_layer_computation
+  )
+  return {
+      'embed_and_dense': dbl_registry,
+  }
+
+
+# ==============================================================================
+# Main tests.
+# ==============================================================================
+class GradNormTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = tf.distribute.get_strategy()
+
+  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
+  # supports them for embeddings.
+  @parameterized.product(
+      input_data=get_nlp_on_device_embedding_inputs(),
+      scale_factor=[None, 0.5, 1.0],
+      model_name=list(
+          get_nlp_on_device_embedding_model_generators().keys()
+      ),
+      output_dim=[2],
+      layer_registry_name=list(
+          get_nlp_on_device_embedding_layer_registries().keys()
+      ),
+      num_microbatches=[None, 2],
+      is_eager=[True, False],
+      partial=[True, False],
+  )
+  def test_gradient_norms_on_various_models(
+      self,
+      input_data,
+      scale_factor,
+      model_name,
+      output_dim,
+      layer_registry_name,
+      num_microbatches,
+      is_eager,
+      partial,
+  ):
+    # Parse inputs to generate test data.
+    embed_indices = tf.convert_to_tensor(input_data, dtype_hint=tf.int32)
+
+    # The following are invalid test combinations and, hence, are skipped.
+    batch_size = embed_indices.shape[0]
+    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
+    if num_microbatches is not None and batch_size % num_microbatches != 0:
+      return
+
+    # Load shared assets to all devices.
+    with self.strategy.scope():
+
+      def embed_layer_generator(_, output_dims):
+        return tfm.nlp.layers.OnDeviceEmbedding(
+            10,
+            *output_dims,
+            scale_factor=scale_factor
+        )
+
+      model = common_test_utils.get_model_from_generator(
+          model_generator=(
+              get_nlp_on_device_embedding_model_generators()[model_name]
+          ),
+          layer_generator=embed_layer_generator,
+          input_dims=embed_indices.shape[1:],
+          output_dims=[output_dim],
+          is_eager=is_eager,
+      )
+
+    # Define the main testing ops. These may later be compiled to a Graph op.
+    def test_op(x_batch):
+      return common_test_utils.get_computed_and_true_norms_from_model(
+          model=model,
+          per_example_loss_fn=None,
+          num_microbatches=num_microbatches,
+          x_batch=x_batch,
+          registry=(
+              get_nlp_on_device_embedding_layer_registries()[
+                  layer_registry_name
+              ]
+          ),
+          partial=partial,
+      )
+
+    # TPUs can only run `tf.function`-decorated functions.
+    if using_tpu:
+      test_op = tf.function(test_op, autograph=False)
+
+    # Set up the device ops and run the test.
+    computed_norms, true_norms = self.strategy.run(
+        test_op, args=(embed_indices,)
+    )
+    # TPUs return replica contexts, which must be unwrapped.
+    if using_tpu:
+      common_test_utils.assert_replica_values_are_close(self, computed_norms)
+      common_test_utils.assert_replica_values_are_close(self, true_norms)
+      computed_norms = computed_norms.values[0]
+      true_norms = true_norms.values[0]
+    expected_size = num_microbatches or batch_size
+    self.assertEqual(tf.shape(computed_norms)[0], expected_size)
+    self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
+
+
+if __name__ == '__main__':
+  tf.test.main()

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py
new file mode 100644
index 0000000..2a447b4
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py
@@ -0,0 +1,29 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding_test
+
+
+class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
+
+  def setUp(self):
+    super().setUp()
+    self.strategy = common_test_utils.create_tpu_strategy()
+    self.assertIn('TPU', self.strategy.extended.worker_devices[0])
+
+
+if __name__ == '__main__':
+  tf.test.main()

diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/registry_function_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/registry_function_utils.py
new file mode 100644
index 0000000..0fa88c6
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/registry_function_utils.py
@@ -0,0 +1,131 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Useful utility functions for implementing registry functions."""
+
+from typing import Union
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+
+
+def embedding_sqr_norm_fn(
+    grad_values: tf.Tensor,
+    input_ids: Union[tf.Tensor, tf.RaggedTensor],
+    num_microbatches: Union[tf.Tensor, None] = None,
+) -> tf.Tensor:
+  """Fast square norm function for general embedding layers.
+
+  Args:
+    grad_values: Batched embedding gradient values. These come from the
+      'values' field of the `IndexedSlices` object representing the embedding
+      gradient.
+    input_ids: A `tf.Tensor` or `tf.RaggedTensor` of queried embedding ids.
+      The values in this input are the same as the `indices` field of the
+      `IndexedSlices` object representing the embedding gradient.
+    num_microbatches: An optional number of microbatches.
+
+  Returns:
+    A 1D `tf.Tensor` of squared gradient norms.
+
+  NOTE: to help understand the code, we document in the function body what
+  the expected intermediate variables are for the running example below:
+
+    input_ids = [[1, 1, 2], [0], [2, 0]]
+    grad_values = [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]]
+
+  For ease of reference, we also list these variables below:
+
+    row_indices = [[0], [0], [0], [1], [2], [2]]
+    flattened_indices = [[1], [1], [2], [0], [2], [0]]
+    paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+    new_index_positions = [0, 0, 1, 2, 3, 4]
+    num_unique_paired_indices = 5
+    unique_batch_ids = [0, 0, 1, 2, 2]
+    summed_gradients
+      = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+      = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+    sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+  """
+  # We first get a 1D tensor of the row indices.
+  nrows = tf.shape(input_ids)[0]
+  if isinstance(input_ids, tf.RaggedTensor):
+    row_indices = tf.expand_dims(
+        input_ids.merge_dims(1, -1).value_rowids(), axis=-1
+    )
+  elif isinstance(input_ids, tf.Tensor):
+    ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
+    repeats = tf.repeat(ncols, nrows)
+    row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
+  else:
+    raise NotImplementedError(
+        "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
+    )
+  row_indices = tf.cast(row_indices, tf.int64)
+  if num_microbatches is not None:
+    microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
+    nrows = num_microbatches
+    row_indices = tf.cast(
+        tf.math.floordiv(row_indices, microbatch_size), tf.int64
+    )
+  # NOTE: expected values for the running example above are
+  #   row_indices = [[0], [0], [0], [1], [2], [2]]
+
+  # Now, sum-reduce the `IndexedSlices` that is the result of a
+  # `tape.gradient()` call. The sum is reduced by the repeated embedding
+  # indices and batch index. It is adapted from the logic in:
+  #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
+  flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
+  paired_indices = tf.concat(
+      [
+          tf.cast(row_indices, tf.int64),
+          tf.cast(flattened_indices, tf.int64)
+      ],
+      axis=1,
+  )
+  (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
+      x=paired_indices, axis=[0]
+  )
+  # NOTE: expected values for the running example above are
+  #   flattened_indices = [[1], [1], [2], [0], [2], [0]]
+  #   paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+  #   unique_paired_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
+  #   new_index_positions = [0, 0, 1, 2, 3, 4]
+
+  # Next, sum according to the new positions and compute the squared
+  # gradient norms. Oddly enough, not sorting these indices will break
+  # tensor shape inference logic on TPUs.
+  num_unique_paired_indices = tf.shape(unique_paired_indices)[0]
+  unique_batch_ids = unique_paired_indices[:, 0]
+  summed_gradients = tf.math.unsorted_segment_sum(
+      grad_values,
+      new_index_positions,
+      num_unique_paired_indices,
+  )
+  sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
+  # NOTE: expected values for the running example above are
+  #   num_unique_paired_indices = 5
+  #   unique_batch_ids = [0, 0, 1, 2, 2]
+  #   summed_gradients
+  #     = [0.2 + 0.2, 0.3, 0.1, 0.3, 0.1]
+  #     = [[0.4], [0.3], [0.1], [0.3], [0.1]]
+  #   sqr_gradient_sum = [0.16, 0.09, 0.01, 0.09, 0.01]
+
+  # Use a scatter-add strategy to ensure TPU compatibility.
+  result = tf.zeros([nrows])
+  return tf.tensor_scatter_nd_add(
+      result,
+      tf.expand_dims(unique_batch_ids, axis=-1),
+      sqr_gradient_sum,
+  )
+  # NOTE: the expected output for the running example is
+  #   [0.16 + 0.09, 0.01, 0.09 + 0.01] = [0.25, 0.01, 0.1]
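
Usage sketch (illustrative; not part of the patch itself). The registry wiring
below mirrors `get_nlp_on_device_embedding_layer_registries()` from the new
test file; the model is a hypothetical bag-of-words network in the spirit of
`make_dense_bow_model`, and the vocabulary size, embedding width, and input
shape are made up for the example.

    import tensorflow as tf
    import tensorflow_models as tfm
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
    from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import nlp_on_device_embedding

    # Map each layer type to its fast-clipping registry function.
    registry = layer_registry.LayerRegistry()
    registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
    registry.insert(
        tfm.nlp.layers.OnDeviceEmbedding,
        nlp_on_device_embedding.nlp_on_device_embedding_layer_computation,
    )

    # A small bag-of-words model over integer ids (shapes are illustrative).
    inputs = tf.keras.Input(shape=(2,), dtype=tf.int32)
    emb_layer = tfm.nlp.layers.OnDeviceEmbedding(vocab_size=10, embedding_width=4)
    pooled = tf.reduce_sum(emb_layer(inputs), axis=1)
    outputs = tf.keras.layers.Dense(1)(pooled)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

A registry assembled this way is what the new test supplies (via the
`registry` argument) to `common_test_utils.get_computed_and_true_norms_from_model`
when comparing fast-clipped per-example gradient norms against ground truth.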