diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD index 912b439..8698b51 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD @@ -21,5 +21,8 @@ py_library( py_library( name = "layer_registry", srcs = ["layer_registry.py"], - deps = [":type_aliases"], + deps = [ + ":type_aliases", + "//tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions:embedding", + ], ) diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py index 6c7d9a4..ed015f1 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py @@ -17,6 +17,7 @@ from typing import Type import tensorflow as tf from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases +from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding # ============================================================================== @@ -49,3 +50,15 @@ class LayerRegistry: layer_key = hash(layer_class) self._layer_class_dict[layer_key] = layer_class self._registry[layer_key] = layer_registry_function + + +# ============================================================================== +# Main factory methods +# ============================================================================== +def make_default_layer_registry() -> LayerRegistry: + registry = LayerRegistry() + registry.insert( + tf.keras.layers.Embedding, + embedding.embedding_layer_contribution_histogram, + ) + return registry diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/BUILD b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/BUILD new file mode 100644 index 0000000..f998797 --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +py_library( + name = "embedding", + srcs = ["embedding.py"], + deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases"], +) + +py_test( + name = "embedding_test", + srcs = ["embedding_test.py"], + deps = [":embedding"], +) diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding.py new file mode 100644 index 0000000..fed47cd --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding.py @@ -0,0 +1,228 @@ +# Copyright 2024, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compute the contribution histogram for an embedding layer.""" + +from typing import Optional +import tensorflow as tf +from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases + + +def embedding_layer_contribution_histogram( + layer_instance: tf.keras.layers.Embedding, + input_args: type_aliases.InputArgs, + input_kwargs: type_aliases.InputKwargs, + num_microbatches: Optional[tf.Tensor] = None, +) -> dict[str, type_aliases.ContributionCountHistogramFn]: + """Registry function for `tf.keras.layers.Embedding`. + + Args: + layer_instance: A `tf.keras.layers.Embedding` instance. + input_args: A `tuple` containing the first part of `layer_instance` input. + Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return + a valid output. + input_kwargs: A `tuple` containing the second part of `layer_instance` + input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should + return a valid output. + num_microbatches: An optional numeric value or scalar `tf.Tensor` for + indicating whether and how the losses are grouped into microbatches. If + not None, num_microbatches must divide the batch size. + + Returns: + A dict mapping the name of the trainable variable to a function with + signature `(tf.IndexedSlices) -> tf.SparseTensor`. The function takes a + `tf.IndexedSlices` object representing the gradient for that variable and + returns a `tf.SparseTensor` representing the normalized (so that each user + contributes 1) contribution counts histogram per user for each embedding + vector. + """ + if input_kwargs: + raise ValueError("Embedding layer calls should not receive kwargs.") + del input_kwargs # Unused in embedding layer calls. + if not input_args or len(input_args) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") + if hasattr(layer_instance, "sparse"): # for backwards compatibility + if layer_instance.sparse: + raise NotImplementedError("Sparse output tensors are not supported.") + if isinstance(input_args[0], tf.SparseTensor): + raise NotImplementedError("Sparse input tensors are not supported.") + + # Disable experimental features. + if hasattr(layer_instance, "_use_one_hot_matmul"): + if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access + raise NotImplementedError( + "The experimental embedding feature " + "'_use_one_hot_matmul' is not supported." + ) + input_ids = tf.squeeze(tf.cast(*input_args, tf.int32)) + + def count_contributions_fn( + grad: type_aliases.SparseGradient, + ) -> type_aliases.ContributionCountHistogram: + return embedding_layer_contribution_histogram_fn( + grad, + input_ids, + layer_instance.input_dim, + num_microbatches, + ) + + if ( + not layer_instance.trainable_variables + or len(layer_instance.trainable_variables) != 1 + ): + raise ValueError( + "Embedding layer must have exactly one trainable variable." + ) + return {layer_instance.trainable_variables[0].name: count_contributions_fn} + + +def embedding_layer_contribution_histogram_fn( + grad: type_aliases.SparseGradient, + input_ids: tf.Tensor, + vocab_size: Optional[tf.Tensor], + num_microbatches: Optional[tf.Tensor] = None, +) -> type_aliases.ContributionCountHistogram: + """Computes the normalized contribution counts histogram for embedding layer. + + NOTE: to help understand the code, we document in the function body what the + expected intermediate variables are for the below running example: + + grad = None + input_ids = [[1, 1, 2], [0], [2, 0]] + vocab_size = 3 + num_microbatches = None + + For ease of reference, we also list these variables below: + + row_indices = [[0], [0], [0], [1], [2], [2]] + flattened_indices = [[1], [1], [2], [0], [2], [0]] + paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + linearized_pair_indices = [1 1 2 3 8 6] + contribution_counts_linearized_indices = [1 2 3 8 6] + contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + contribution_counts_values = [2 1 1 1 1] + user_normalized_contribution_counts = tf.SparseTensor( + indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]], + values=[0.67, 0.33, 1., 0.5, 0.5,] + shape=(3, 3) + ) + contribution_histogram = tf.SparseTensor( + indices=[[0], [1], [2]], + values=[1.5, 0.67, 0.83], + shape=(3,) + ) + + + Args: + grad: The gradient of the layer. (unused for embedding layer) + input_ids: The input ids used to compute the embeddings. + vocab_size: The vocabulary size of the embedding layer. + num_microbatches: An optional numeric value or scalar `tf.Tensor` for + indicating whether and how the losses are grouped into microbatches. If + not None, num_microbatches must divide the batch size. + + Returns: + A `tf.SparseTensor` representing the normalized (so that each user + contributes 1) contribution counts histogram per user for each embedding + vector. + + Raises: + NotImplementedError: If the input_ids is not a `tf.Tensor` or + `tf.RaggedTensor`. + """ + del grad # unused. + + nrows = tf.shape(input_ids)[0] + if isinstance(input_ids, tf.RaggedTensor): + row_indices = tf.expand_dims( + input_ids.merge_dims(1, -1).value_rowids(), axis=-1 + ) + elif isinstance(input_ids, tf.Tensor): + ncols = tf.reduce_prod(tf.shape(input_ids)[1:]) + repeats = tf.repeat(ncols, nrows) + row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1]) + row_indices = tf.cast(row_indices, tf.int64) + else: + raise NotImplementedError( + "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ + ) + + if num_microbatches is not None: + tf.debugging.assert_equal( + nrows % num_microbatches, + 0, + "num_microbatches must divide the batch size.", + ) + microbatch_size = tf.cast(nrows / num_microbatches, tf.int64) + nrows = num_microbatches + row_indices = tf.cast( + tf.math.floordiv(row_indices, microbatch_size), tf.int64 + ) + # NOTE: expected values for the running example above are + # row_indices = [[0], [0], [0], [1], [2], [2]] + + flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1) + paired_indices = tf.concat( + [tf.cast(row_indices, tf.int64), tf.cast(flattened_indices, tf.int64)], + axis=1, + ) + # NOTE: expected values for the running example above are + # flattened_indices = [[1], [1], [2], [0], [2], [0]] + # paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + + transform = tf.cast(tf.stack([[vocab_size], [1]], axis=0), tf.int64) + linearized_pair_indices = tf.reshape( + tf.matmul(paired_indices, transform), (-1,) + ) + contribution_counts_linearized_indices, _, contribution_counts_values = ( + tf.unique_with_counts(linearized_pair_indices) + ) + contribution_counts_indices = tf.stack( + [ + contribution_counts_linearized_indices // vocab_size, + contribution_counts_linearized_indices % vocab_size, + ], + axis=1, + ) + contribution_counts = tf.sparse.SparseTensor( + contribution_counts_indices, + contribution_counts_values, + (nrows, vocab_size), + ) + contribution_counts = tf.sparse.reorder(contribution_counts) + # NOTE: expected values for the running example above are + # linearized_pair_indices = [1 1 2 3 8 6] + # contribution_counts_linearized_indices = [1 2 3 8 6] + # contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] + # contribution_counts_values = [2 1 1 1 1] + + user_normalized_contribution_counts = ( + contribution_counts + / tf.sparse.reduce_sum(contribution_counts, axis=-1, keepdims=True) + ) + contribution_histogram = tf.sparse.reduce_sum( + user_normalized_contribution_counts, axis=0, output_is_sparse=True + ) + # NOTE: expected values for the running example above are + # user_normalized_contribution_counts = tf.SparseTensor( + # indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]], + # values=[0.67, 0.33, 1., 0.5, 0.5,] + # shape=(3, 3) + # ) + # contribution_histogram = tf.SparseTensor( + # indices=[[0], [1], [2]], + # values=[1.5, 0.67, 0.83], + # shape=(3,) + # ) + + return tf.sparse.reshape(contribution_histogram, (-1,)) diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding_test.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding_test.py new file mode 100644 index 0000000..668a964 --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding_test.py @@ -0,0 +1,270 @@ +# Copyright 2024, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for embedding.""" + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding + + +class EmbeddingLayerWithMultipleTrainableVariables(tf.keras.layers.Embedding): + + def build(self, input_shape): + self.some_other_variable = self.add_weight( + name="some_other_variable", + shape=(10, 10), + trainable=True, + ) + super().build(input_shape) + + +class EmbeddingTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="rank2_input", + input_ids=tf.constant([[0], [0], [4], [2]]), + num_microbatches=None, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [4]], + values=tf.constant([2.0, 1.0, 1.0], dtype=tf.float64), + dense_shape=[8], + ), + ), + dict( + testcase_name="rank2_multi_input", + input_ids=tf.constant([[0, 2], [0, 2], [4, 5], [2, 3]]), + num_microbatches=None, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [3], [4], [5]], + values=tf.constant([1.0, 1.5, 0.5, 0.5, 0.5], dtype=tf.float64), + dense_shape=[8], + ), + ), + dict( + testcase_name="rank3_input", + input_ids=tf.constant( + [[[0], [2]], [[0], [2]], [[4], [5]], [[2], [3]]] + ), + num_microbatches=None, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [3], [4], [5]], + values=tf.constant([1.0, 1.5, 0.5, 0.5, 0.5], dtype=tf.float64), + dense_shape=[8], + ), + ), + dict( + testcase_name="ragged_input", + input_ids=tf.ragged.constant([[0, 2], [2], [2, 3, 4], [4, 5]]), + num_microbatches=None, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [3], [4], [5]], + values=tf.constant( + [0.5, 1.5 + 1.0 / 3, 1.0 / 3, 0.5 + 1.0 / 3, 0.5], + dtype=tf.float64, + ), + dense_shape=[8], + ), + ), + dict( + testcase_name="rank2_input_num_microbatches_2", + input_ids=tf.constant([[0], [0], [4], [2]]), + num_microbatches=2, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [4]], + values=tf.constant([1.0, 0.5, 0.5], dtype=tf.float64), + dense_shape=[8], + ), + ), + dict( + testcase_name="ragged_input_num_microbatches_2", + input_ids=tf.ragged.constant([[0, 2], [2], [2, 3, 4], [4, 5]]), + num_microbatches=2, + vocab_size=8, + expected_contribution_counts=tf.SparseTensor( + indices=[[0], [2], [3], [4], [5]], + values=tf.constant( + [1.0 / 3, 2.0 / 3 + 1.0 / 5, 1.0 / 5, 2.0 / 5, 1.0 / 5], + dtype=tf.float64, + ), + dense_shape=[8], + ), + ), + ) + def test_embedding_layer_contribution_histogram_fn( + self, + input_ids, + expected_contribution_counts, + vocab_size, + num_microbatches, + ): + grad = None + contribution_counts = embedding.embedding_layer_contribution_histogram_fn( + grad, input_ids, vocab_size, num_microbatches + ) + tf.debugging.assert_equal( + tf.sparse.to_dense(contribution_counts), + tf.sparse.to_dense(expected_contribution_counts), + ) + + @parameterized.named_parameters( + dict( + testcase_name="input_None", + input_ids=None, + ), + dict( + testcase_name="input_SparseTensor", + input_ids=tf.SparseTensor( + indices=[[0, 0]], + values=[0], + dense_shape=(3, 8), + ), + ), + dict( + testcase_name="input_list", + input_ids=[[0], [0], [1], [2]], + ), + dict( + testcase_name="num_microbatches_not_divisible", + input_ids=tf.constant([[0], [0], [4], [2]]), + num_microbatches=3, + ), + ) + def test_embedding_layer_contribution_histogram_fn_errors( + self, input_ids, num_microbatches=None + ): + with self.assertRaises( + (NotImplementedError, ValueError, tf.errors.InvalidArgumentError) + ): + embedding.embedding_layer_contribution_histogram_fn( + None, input_ids, 8, num_microbatches + ) + + @parameterized.named_parameters( + dict( + testcase_name="input_kwargs", + error_message="Embedding layer calls should not receive kwargs.", + input_kwargs={"foo": "bar"}, + ), + dict( + testcase_name="input_args_more_than_one", + error_message="Only layer inputs of length 1 are permitted.", + input_args=[tf.constant([0]), tf.constant([0])], + ), + dict( + testcase_name="input_args_none", + error_message="Only layer inputs of length 1 are permitted.", + ), + dict( + testcase_name="input_sparse", + error_message="Sparse input tensors are not supported.", + input_args=[ + tf.SparseTensor(indices=[[0, 0]], values=[0], dense_shape=(3, 8)) + ], + ), + dict( + testcase_name="layer_one_hot_matmul", + error_message=( + "The experimental embedding feature '_use_one_hot_matmul' is not" + " supported." + ), + input_args=[tf.constant([0])], + layer_kwargs={"use_one_hot_matmul": True}, + ), + dict( + testcase_name="layer_sparse", + error_message="Sparse output tensors are not supported.", + input_args=[tf.constant([0])], + layer_kwargs={"sparse": True}, + ), + ) + def test_embedding_layer_contribution_histogram_errors( + self, + error_message, + input_args=None, + input_kwargs=None, + layer_kwargs=None, + ): + layer_kwargs = layer_kwargs or {} + layer = tf.keras.layers.Embedding(input_dim=8, output_dim=4, **layer_kwargs) + layer.build(input_shape=(None, 1)) + with self.assertRaisesRegex( + (NotImplementedError, ValueError), + error_message, + ): + embedding.embedding_layer_contribution_histogram( + layer, input_args, input_kwargs, None + ) + + @parameterized.named_parameters( + dict( + testcase_name="zero_variables", + build=False, # unbuilt layer has no trainable variables + ), + dict( + testcase_name="two_variables", + build=True, + ), + ) + def test_embedding_layer_contribution_histogram_embedding_layer_invalid_trainable_variables( + self, build + ): + input_args = [tf.constant([0])] + input_kwargs = {} + layer = EmbeddingLayerWithMultipleTrainableVariables( + input_dim=8, output_dim=4 + ) + if build: + layer.build(input_shape=(None, 1)) + with self.assertRaisesRegex( + ValueError, + "Embedding layer must have exactly one trainable variable.", + ): + embedding.embedding_layer_contribution_histogram( + layer, input_args, input_kwargs + ) + + def test_embedding_layer_contribution_histogram_embedding(self): + input_args = [tf.constant([[0], [0], [4], [2]])] + input_kwargs = {} + layer = tf.keras.layers.Embedding(input_dim=8, output_dim=4) + layer.build(input_shape=(None, 1)) + + contribution_counts_fn_dict = ( + embedding.embedding_layer_contribution_histogram( + layer, input_args, input_kwargs + ) + ) + self.assertEqual(list(contribution_counts_fn_dict.keys()), ["embeddings:0"]) + contribution_counts_fn = contribution_counts_fn_dict["embeddings:0"] + dummy_gradient = None + contribution_counts = contribution_counts_fn(dummy_gradient) + expected_contribution_counts = tf.SparseTensor( + indices=[[0], [2], [4]], + values=tf.constant([2.0, 1.0, 1.0], dtype=tf.float64), + dense_shape=[8], + ) + tf.debugging.assert_equal( + tf.sparse.to_dense(contribution_counts), + tf.sparse.to_dense(expected_contribution_counts), + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py index c63d083..82c7e76 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py @@ -27,5 +27,5 @@ ContributionCountHistogramFn = Callable[ NumMicrobatches = int | tf.Tensor SparsityPreservingNoiseLayerRegistryFunction = Callable[ [tf.keras.layers.Layer, InputArgs, InputKwargs, NumMicrobatches | None], - ContributionCountHistogramFn, + dict[str, ContributionCountHistogramFn], ]