Sparsity Preserving DP-SGD in TF Privacy [4 of 5]

Add contribution count function for embedding layer.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 656091009
This commit is contained in:
A. Unique TensorFlower 2024-07-25 14:23:42 -07:00
parent 80802c248d
commit fc6f1dc5d1
6 changed files with 531 additions and 2 deletions

View file

@ -21,5 +21,8 @@ py_library(
py_library(
name = "layer_registry",
srcs = ["layer_registry.py"],
deps = [":type_aliases"],
deps = [
":type_aliases",
"//tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions:embedding",
],
)

View file

@ -17,6 +17,7 @@ from typing import Type
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding
# ==============================================================================
@ -49,3 +50,15 @@ class LayerRegistry:
layer_key = hash(layer_class)
self._layer_class_dict[layer_key] = layer_class
self._registry[layer_key] = layer_registry_function
# ==============================================================================
# Main factory methods
# ==============================================================================
def make_default_layer_registry() -> LayerRegistry:
  """Builds a `LayerRegistry` pre-populated with the supported layers.

  Returns:
    A new `LayerRegistry` in which `tf.keras.layers.Embedding` is mapped to
    `embedding.embedding_layer_contribution_histogram`.
  """
  # (layer class, registry function) pairs installed by default.
  default_entries = (
      (
          tf.keras.layers.Embedding,
          embedding.embedding_layer_contribution_histogram,
      ),
  )
  default_registry = LayerRegistry()
  for layer_class, registry_fn in default_entries:
    default_registry.insert(layer_class, registry_fn)
  return default_registry

View file

@ -0,0 +1,15 @@
# Build targets for the embedding registry function and its unit test.
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
# Library target wrapping embedding.py; depends on the shared type aliases.
py_library(
    name = "embedding",
    srcs = ["embedding.py"],
    deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases"],
)
# Test target for embedding_test.py.
py_test(
    name = "embedding_test",
    srcs = ["embedding_test.py"],
    deps = [":embedding"],
)

View file

@ -0,0 +1,228 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute the contribution histogram for an embedding layer."""
from typing import Optional
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
def embedding_layer_contribution_histogram(
    layer_instance: tf.keras.layers.Embedding,
    input_args: type_aliases.InputArgs,
    input_kwargs: type_aliases.InputKwargs,
    num_microbatches: Optional[tf.Tensor] = None,
) -> dict[str, type_aliases.ContributionCountHistogramFn]:
  """Registry function for `tf.keras.layers.Embedding`.

  Args:
    layer_instance: A `tf.keras.layers.Embedding` instance.
    input_args: A `tuple` containing the positional part of the
      `layer_instance` input. Specifically,
      `layer_instance(*input_args, **input_kwargs)` should return a valid
      output. Must contain exactly one element (the input ids); its first
      axis is treated as the batch axis.
    input_kwargs: A `dict` containing the keyword part of the `layer_instance`
      input. Specifically, `layer_instance(*input_args, **input_kwargs)`
      should return a valid output. Must be empty.
    num_microbatches: An optional numeric value or scalar `tf.Tensor` for
      indicating whether and how the losses are grouped into microbatches. If
      not None, num_microbatches must divide the batch size.

  Returns:
    A dict mapping the name of the trainable variable to a function with
    signature `(tf.IndexedSlices) -> tf.SparseTensor`. The function takes a
    `tf.IndexedSlices` object representing the gradient for that variable and
    returns a `tf.SparseTensor` representing the normalized (so that each user
    contributes 1) contribution counts histogram per user for each embedding
    vector.

  Raises:
    ValueError: If `input_kwargs` is non-empty, if `input_args` does not
      contain exactly one element, or if `layer_instance` does not have
      exactly one trainable variable.
    NotImplementedError: If the layer produces sparse outputs, if the input is
      a `tf.SparseTensor`, or if the experimental `_use_one_hot_matmul`
      feature is enabled on the layer.
  """
  if input_kwargs:
    raise ValueError("Embedding layer calls should not receive kwargs.")
  if not input_args or len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  # `sparse` may be missing on some Keras versions (backwards compatibility),
  # hence the `getattr` default.
  if getattr(layer_instance, "sparse", False):
    raise NotImplementedError("Sparse output tensors are not supported.")
  if isinstance(input_args[0], tf.SparseTensor):
    raise NotImplementedError("Sparse input tensors are not supported.")
  # Disable experimental features.
  if getattr(layer_instance, "_use_one_hot_matmul", False):
    raise NotImplementedError(
        "The experimental embedding feature "
        "'_use_one_hot_matmul' is not supported."
    )

  # Keep the shape of the ids intact. The histogram helper treats axis 0 as the
  # batch axis and flattens all remaining axes itself, so no squeezing is
  # needed. An axis-less `tf.squeeze` here would also drop a batch axis of
  # size 1, turning a `[1, n]` batch into `n` separate rows (or a `[1, 1]`
  # batch into a scalar).
  input_ids = tf.cast(input_args[0], tf.int32)

  def count_contributions_fn(
      grad: type_aliases.SparseGradient,
  ) -> type_aliases.ContributionCountHistogram:
    return embedding_layer_contribution_histogram_fn(
        grad,
        input_ids,
        layer_instance.input_dim,
        num_microbatches,
    )

  # An unbuilt layer has no variables and a customized one may have several;
  # in both cases there is no single variable name to key the result by.
  if len(layer_instance.trainable_variables) != 1:
    raise ValueError(
        "Embedding layer must have exactly one trainable variable."
    )
  return {layer_instance.trainable_variables[0].name: count_contributions_fn}
def embedding_layer_contribution_histogram_fn(
    grad: type_aliases.SparseGradient,
    input_ids: tf.Tensor,
    vocab_size: Optional[tf.Tensor],
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.ContributionCountHistogram:
  """Computes the normalized contribution counts histogram for embedding layer.

  Every batch row (or every microbatch, when `num_microbatches` is given) is
  one "group". For each group the occurrences of every id are counted and
  divided by the group's total number of ids, so each group contributes a
  total weight of 1. The per-group weights are then summed over all groups.

  Running example used by the comments in the function body:

    grad = None
    input_ids = [[1, 1, 2], [0], [2, 0]]
    vocab_size = 3
    num_microbatches = None

  Intermediate values for that example:

    group_of_id = [[0], [0], [0], [1], [2], [2]]
    flat_ids = [[1], [1], [2], [0], [2], [0]]
    group_id_pairs = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    linear_keys = [1 1 2 3 8 6]
    distinct_keys = [1 2 3 8 6]
    count_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    multiplicities = [2 1 1 1 1]
    normalized_counts = tf.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]],
        values=[0.67, 0.33, 1., 0.5, 0.5,]
        shape=(3, 3)
    )
    histogram = tf.SparseTensor(
        indices=[[0], [1], [2]],
        values=[1.5, 0.67, 0.83],
        shape=(3,)
    )

  Args:
    grad: The gradient of the layer. (unused for embedding layer)
    input_ids: The input ids used to compute the embeddings.
    vocab_size: The vocabulary size of the embedding layer.
    num_microbatches: An optional numeric value or scalar `tf.Tensor` for
      indicating whether and how the losses are grouped into microbatches. If
      not None, num_microbatches must divide the batch size.

  Returns:
    A rank-1 `tf.SparseTensor` representing the normalized (so that each user
    contributes 1) contribution counts histogram per user for each embedding
    vector.

  Raises:
    NotImplementedError: If the input_ids is not a `tf.Tensor` or
      `tf.RaggedTensor`.
  """
  del grad  # unused.

  batch_size = tf.shape(input_ids)[0]

  # For every id (in flattened order), find the batch row it came from.
  if isinstance(input_ids, tf.RaggedTensor):
    merged_ids = input_ids.merge_dims(1, -1)
    group_of_id = tf.expand_dims(merged_ids.value_rowids(), axis=-1)
  elif isinstance(input_ids, tf.Tensor):
    ids_per_row = tf.reduce_prod(tf.shape(input_ids)[1:])
    per_row_repeats = tf.repeat(ids_per_row, batch_size)
    repeated_rows = tf.repeat(tf.range(batch_size), per_row_repeats)
    group_of_id = tf.cast(tf.reshape(repeated_rows, [-1, 1]), tf.int64)
  else:
    raise NotImplementedError(
        "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
    )

  # Without microbatching every batch row is its own group.
  num_groups = batch_size
  if num_microbatches is not None:
    tf.debugging.assert_equal(
        batch_size % num_microbatches,
        0,
        "num_microbatches must divide the batch size.",
    )
    rows_per_microbatch = tf.cast(batch_size / num_microbatches, tf.int64)
    # Consecutive rows are merged into the same microbatch.
    group_of_id = tf.cast(
        tf.math.floordiv(group_of_id, rows_per_microbatch), tf.int64
    )
    num_groups = num_microbatches
  # Running example: group_of_id = [[0], [0], [0], [1], [2], [2]]

  flat_ids = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
  group_id_pairs = tf.concat(
      [tf.cast(group_of_id, tf.int64), tf.cast(flat_ids, tf.int64)],
      axis=1,
  )
  # Running example:
  #   flat_ids = [[1], [1], [2], [0], [2], [0]]
  #   group_id_pairs = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]

  # Encode each (group, id) pair as the scalar `group * vocab_size + id` so
  # that duplicates can be counted with a 1-D unique operation.
  pair_to_key = tf.cast(tf.stack([[vocab_size], [1]], axis=0), tf.int64)
  linear_keys = tf.reshape(tf.matmul(group_id_pairs, pair_to_key), (-1,))
  distinct_keys, _, multiplicities = tf.unique_with_counts(linear_keys)
  # Decode the distinct keys back into (group, id) index pairs.
  count_indices = tf.stack(
      [distinct_keys // vocab_size, distinct_keys % vocab_size],
      axis=1,
  )
  group_counts = tf.sparse.reorder(
      tf.sparse.SparseTensor(
          count_indices,
          multiplicities,
          (num_groups, vocab_size),
      )
  )
  # Running example:
  #   linear_keys = [1 1 2 3 8 6]
  #   distinct_keys = [1 2 3 8 6]
  #   count_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
  #   multiplicities = [2 1 1 1 1]

  # Scale each group's counts so that they sum to one, then add up the groups.
  group_totals = tf.sparse.reduce_sum(group_counts, axis=-1, keepdims=True)
  normalized_counts = group_counts / group_totals
  histogram = tf.sparse.reduce_sum(
      normalized_counts, axis=0, output_is_sparse=True
  )
  # Running example:
  #   normalized_counts = tf.SparseTensor(
  #       indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]],
  #       values=[0.67, 0.33, 1., 0.5, 0.5,]
  #       shape=(3, 3)
  #   )
  #   histogram = tf.SparseTensor(
  #       indices=[[0], [1], [2]],
  #       values=[1.5, 0.67, 0.83],
  #       shape=(3,)
  #   )
  return tf.sparse.reshape(histogram, (-1,))

View file

@ -0,0 +1,270 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for embedding."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding
class EmbeddingLayerWithMultipleTrainableVariables(tf.keras.layers.Embedding):
  """Embedding layer that registers one extra trainable weight when built."""

  def build(self, input_shape):
    """Adds a second trainable weight, then runs the regular Embedding build."""
    extra_weight = self.add_weight(
        name="some_other_variable", shape=(10, 10), trainable=True
    )
    self.some_other_variable = extra_weight
    super().build(input_shape)
class EmbeddingTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the embedding contribution histogram registry functions."""

  @parameterized.named_parameters(
      # One id per row: each id receives the full weight of its row.
      dict(
          testcase_name="rank2_input",
          input_ids=tf.constant([[0], [0], [4], [2]]),
          num_microbatches=None,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [4]],
              values=tf.constant([2.0, 1.0, 1.0], dtype=tf.float64),
              dense_shape=[8],
          ),
      ),
      # Two distinct ids per row: each id receives half of its row's weight.
      dict(
          testcase_name="rank2_multi_input",
          input_ids=tf.constant([[0, 2], [0, 2], [4, 5], [2, 3]]),
          num_microbatches=None,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [3], [4], [5]],
              values=tf.constant([1.0, 1.5, 0.5, 0.5, 0.5], dtype=tf.float64),
              dense_shape=[8],
          ),
      ),
      # Same ids as "rank2_multi_input" with an extra trailing axis.
      dict(
          testcase_name="rank3_input",
          input_ids=tf.constant(
              [[[0], [2]], [[0], [2]], [[4], [5]], [[2], [3]]]
          ),
          num_microbatches=None,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [3], [4], [5]],
              values=tf.constant([1.0, 1.5, 0.5, 0.5, 0.5], dtype=tf.float64),
              dense_shape=[8],
          ),
      ),
      # Rows of different lengths: weights are 1 / (number of ids in the row).
      dict(
          testcase_name="ragged_input",
          input_ids=tf.ragged.constant([[0, 2], [2], [2, 3, 4], [4, 5]]),
          num_microbatches=None,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [3], [4], [5]],
              values=tf.constant(
                  [0.5, 1.5 + 1.0 / 3, 1.0 / 3, 0.5 + 1.0 / 3, 0.5],
                  dtype=tf.float64,
              ),
              dense_shape=[8],
          ),
      ),
      # Rows {0, 1} form the first microbatch and rows {2, 3} the second;
      # normalization happens per microbatch instead of per row.
      dict(
          testcase_name="rank2_input_num_microbatches_2",
          input_ids=tf.constant([[0], [0], [4], [2]]),
          num_microbatches=2,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [4]],
              values=tf.constant([1.0, 0.5, 0.5], dtype=tf.float64),
              dense_shape=[8],
          ),
      ),
      # First microbatch holds 3 ids in total, the second holds 5.
      dict(
          testcase_name="ragged_input_num_microbatches_2",
          input_ids=tf.ragged.constant([[0, 2], [2], [2, 3, 4], [4, 5]]),
          num_microbatches=2,
          vocab_size=8,
          expected_contribution_counts=tf.SparseTensor(
              indices=[[0], [2], [3], [4], [5]],
              values=tf.constant(
                  [1.0 / 3, 2.0 / 3 + 1.0 / 5, 1.0 / 5, 2.0 / 5, 1.0 / 5],
                  dtype=tf.float64,
              ),
              dense_shape=[8],
          ),
      ),
  )
  def test_embedding_layer_contribution_histogram_fn(
      self,
      input_ids,
      expected_contribution_counts,
      vocab_size,
      num_microbatches,
  ):
    """Checks the histogram values for dense and ragged id inputs."""
    grad = None
    contribution_counts = embedding.embedding_layer_contribution_histogram_fn(
        grad, input_ids, vocab_size, num_microbatches
    )
    # Compare densified tensors so the check does not depend on index order.
    tf.debugging.assert_equal(
        tf.sparse.to_dense(contribution_counts),
        tf.sparse.to_dense(expected_contribution_counts),
    )

  @parameterized.named_parameters(
      dict(
          testcase_name="input_None",
          input_ids=None,
      ),
      dict(
          testcase_name="input_SparseTensor",
          input_ids=tf.SparseTensor(
              indices=[[0, 0]],
              values=[0],
              dense_shape=(3, 8),
          ),
      ),
      dict(
          testcase_name="input_list",
          input_ids=[[0], [0], [1], [2]],
      ),
      # A batch of 4 rows cannot be split evenly into 3 microbatches.
      dict(
          testcase_name="num_microbatches_not_divisible",
          input_ids=tf.constant([[0], [0], [4], [2]]),
          num_microbatches=3,
      ),
  )
  def test_embedding_layer_contribution_histogram_fn_errors(
      self, input_ids, num_microbatches=None
  ):
    """Checks that unsupported inputs to the histogram function raise."""
    # The exact exception type differs between cases, so any of these is
    # accepted.
    with self.assertRaises(
        (NotImplementedError, ValueError, tf.errors.InvalidArgumentError)
    ):
      embedding.embedding_layer_contribution_histogram_fn(
          None, input_ids, 8, num_microbatches
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="input_kwargs",
          error_message="Embedding layer calls should not receive kwargs.",
          input_kwargs={"foo": "bar"},
      ),
      dict(
          testcase_name="input_args_more_than_one",
          error_message="Only layer inputs of length 1 are permitted.",
          input_args=[tf.constant([0]), tf.constant([0])],
      ),
      # `input_args` falls back to the test method's default of None.
      dict(
          testcase_name="input_args_none",
          error_message="Only layer inputs of length 1 are permitted.",
      ),
      dict(
          testcase_name="input_sparse",
          error_message="Sparse input tensors are not supported.",
          input_args=[
              tf.SparseTensor(indices=[[0, 0]], values=[0], dense_shape=(3, 8))
          ],
      ),
      dict(
          testcase_name="layer_one_hot_matmul",
          error_message=(
              "The experimental embedding feature '_use_one_hot_matmul' is not"
              " supported."
          ),
          input_args=[tf.constant([0])],
          layer_kwargs={"use_one_hot_matmul": True},
      ),
      dict(
          testcase_name="layer_sparse",
          error_message="Sparse output tensors are not supported.",
          input_args=[tf.constant([0])],
          layer_kwargs={"sparse": True},
      ),
  )
  def test_embedding_layer_contribution_histogram_errors(
      self,
      error_message,
      input_args=None,
      input_kwargs=None,
      layer_kwargs=None,
  ):
    """Checks the error raised for each invalid registry-function call."""
    layer_kwargs = layer_kwargs or {}
    layer = tf.keras.layers.Embedding(input_dim=8, output_dim=4, **layer_kwargs)
    layer.build(input_shape=(None, 1))
    with self.assertRaisesRegex(
        (NotImplementedError, ValueError),
        error_message,
    ):
      embedding.embedding_layer_contribution_histogram(
          layer, input_args, input_kwargs, None
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="zero_variables",
          build=False,  # unbuilt layer has no trainable variables
      ),
      # Building the custom layer adds an extra weight besides the embeddings.
      dict(
          testcase_name="two_variables",
          build=True,
      ),
  )
  def test_embedding_layer_contribution_histogram_embedding_layer_invalid_trainable_variables(
      self, build
  ):
    """Checks that layers without exactly one trainable variable are rejected."""
    input_args = [tf.constant([0])]
    input_kwargs = {}
    layer = EmbeddingLayerWithMultipleTrainableVariables(
        input_dim=8, output_dim=4
    )
    if build:
      layer.build(input_shape=(None, 1))
    with self.assertRaisesRegex(
        ValueError,
        "Embedding layer must have exactly one trainable variable.",
    ):
      embedding.embedding_layer_contribution_histogram(
          layer, input_args, input_kwargs
      )

  def test_embedding_layer_contribution_histogram_embedding(self):
    """Checks the returned dict key and the histogram its function produces."""
    input_args = [tf.constant([[0], [0], [4], [2]])]
    input_kwargs = {}
    layer = tf.keras.layers.Embedding(input_dim=8, output_dim=4)
    layer.build(input_shape=(None, 1))
    contribution_counts_fn_dict = (
        embedding.embedding_layer_contribution_histogram(
            layer, input_args, input_kwargs
        )
    )
    self.assertEqual(list(contribution_counts_fn_dict.keys()), ["embeddings:0"])
    contribution_counts_fn = contribution_counts_fn_dict["embeddings:0"]
    # The gradient argument is ignored by the embedding histogram function.
    dummy_gradient = None
    contribution_counts = contribution_counts_fn(dummy_gradient)
    expected_contribution_counts = tf.SparseTensor(
        indices=[[0], [2], [4]],
        values=tf.constant([2.0, 1.0, 1.0], dtype=tf.float64),
        dense_shape=[8],
    )
    tf.debugging.assert_equal(
        tf.sparse.to_dense(contribution_counts),
        tf.sparse.to_dense(expected_contribution_counts),
    )
# Run the tests via the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()

View file

@ -27,5 +27,5 @@ ContributionCountHistogramFn = Callable[
NumMicrobatches = int | tf.Tensor
SparsityPreservingNoiseLayerRegistryFunction = Callable[
[tf.keras.layers.Layer, InputArgs, InputKwargs, NumMicrobatches | None],
ContributionCountHistogramFn,
dict[str, ContributionCountHistogramFn],
]