Sparsity Preserving DP-SGD in TF Privacy

Add support for sparsity preserving noise in add_aggregate_noise.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 662148309
A. Unique TensorFlower 2024-08-12 10:44:57 -07:00
parent 09c68750d7
commit e42b574465
3 changed files with 172 additions and 10 deletions
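
Before the diffs, a minimal usage sketch of the new API surface. This is illustrative, not from the commit: the import path assumes noise_utils lives under tensorflow_privacy.privacy.fast_gradient_clipping (not shown on this page), and all tensors are made up.

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils  # assumed path

    # One dense gradient and one embedding-style sparse gradient, both
    # already clipped to l2_norm_clip.
    dense_grad = tf.ones((4, 2))
    sparse_grad = tf.IndexedSlices(
        values=tf.ones((2, 3)),
        indices=tf.constant([1, 4]),
        dense_shape=tf.constant([10, 3]),
    )

    # Per-row contribution counts for the sparse gradient; the dense
    # gradient gets None, so it receives ordinary dense Gaussian noise.
    counts = tf.SparseTensor(indices=[[1], [4]], values=[5.0, 3.0], dense_shape=[10])

    config = noise_utils.SparsityPreservingNoiseConfig(
        sparse_noise_multiplier=1.0,
        sparse_selection_threshold=4,
        sparse_contribution_counts=[None, counts],
    )

    noised_dense, noised_sparse = noise_utils.add_aggregate_noise(
        clipped_grads=[dense_grad, sparse_grad],
        batch_size=tf.constant(8),
        l2_norm_clip=1.0,
        noise_multiplier=2.0,
        loss_reduction='mean',
        sparse_noise_config=config,
    )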


@@ -84,6 +84,7 @@ py_library(
 py_library(
     name = "noise_utils",
     srcs = ["noise_utils.py"],
+    deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
 )

 py_test(


@@ -14,10 +14,21 @@
 """Utility functions that help in adding noise to gradients."""

 from collections.abc import Sequence
+import dataclasses
 from typing import Literal, Optional

 from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
+
+
+@dataclasses.dataclass
+class SparsityPreservingNoiseConfig:
+  """Configuration for adding noise to gradients."""
+
+  sparse_noise_multiplier: float = 0.0
+  sparse_selection_threshold: int = 0
+  sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


 def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
   )


+def _add_dense_aggregate_noise(
+    grad: tf.Tensor,
+    noise_multiplier: float,
+    sensitivity: float,
+) -> tf.Tensor:
+  """Adds dense noise to a dense gradient."""
+  return grad + tf.random.normal(
+      tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
+  )
+
+
+def _add_sparse_aggregate_noise(
+    grad: tf.IndexedSlices,
+    contribution_counts: tf.SparseTensor,
+    noise_multiplier: float,
+    noise_multiplier_sparse: float,
+    sensitivity: float,
+    sparse_selection_threshold: int,
+) -> tf.IndexedSlices:
+  """Adds sparse noise to a sparse gradient."""
+  return sparse_noise_utils.add_sparse_noise(
+      grad=grad,
+      contribution_counts=contribution_counts,
+      noise_multiplier=noise_multiplier,
+      noise_multiplier_sparse=noise_multiplier_sparse,
+      l2_norm_clip=sensitivity,
+      threshold=sparse_selection_threshold,
+  )
+
+
 def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
+    clipped_grads: list[tf.Tensor | tf.IndexedSlices],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
     loss_reduction: Optional[Literal['mean', 'sum']] = None,
     loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
+    sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
+) -> Sequence[tf.Tensor | tf.IndexedSlices]:
   """Adds noise to a collection of clipped gradients.

   The magnitude of the noise depends on the aggregation strategy of the
   input model's loss function.

   Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
+    clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
+      the clipped gradients.
     batch_size: The batch size. Used for normalizing the noise when
       `loss_reduction` is 'sum'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@@ -68,11 +111,14 @@ def add_aggregate_noise(
       aggregation type must be inferred from `input_model.loss`.
     loss_model: An optional `tf.keras.Model` used to infer the loss reduction
       strategy from if `loss_reduction` is `None`.
+    sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
+      the configuration for adding sparse noise. If None, all noise added is
+      dense.

   Returns:
     A list of tensors containing the clipped gradients, but with the right
-    amount of Gaussian noise added to them (depending on the reduction
-    strategy of the loss function).
+    amount of Gaussian or sparse Gaussian noise added to them (depending on
+    the reduction strategy of the loss function).

   Raises:
     ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@@ -103,13 +149,36 @@ def add_aggregate_noise(
         'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
     )

+  if sparse_noise_config is None:
+    sparse_contribution_counts = tf.nest.map_structure(
+        lambda x: None, clipped_grads
+    )
+  else:
+    sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts
+
   scale = l2_norm_clip
   if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)

-  def add_noise(g):
-    return g + tf.random.normal(
-        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
-    )
+  def add_noise(grad, contribution_counts):
+    if (
+        sparse_noise_config is not None
+        and isinstance(grad, tf.IndexedSlices)
+        and contribution_counts is not None
+    ):
+      return _add_sparse_aggregate_noise(
+          grad=grad,
+          contribution_counts=contribution_counts,
+          noise_multiplier=noise_multiplier,
+          noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
+          sensitivity=scale,
+          sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
+      )
+    else:
+      return _add_dense_aggregate_noise(
+          grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
+      )

-  return tf.nest.map_structure(add_noise, clipped_grads)
+  return tf.nest.map_structure(
+      add_noise, clipped_grads, sparse_contribution_counts
+  )
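
The dispatch above calibrates both paths to the same sensitivity: the per-coordinate Gaussian stddev is noise_multiplier * l2_norm_clip, divided by batch_size when the loss reduction is 'mean'. A tiny worked check (illustrative numbers, not from the commit):

    l2_norm_clip = 3.0
    noise_multiplier = 2.0
    batch_size = 10

    scale = l2_norm_clip / batch_size   # sensitivity after 'mean' reduction
    stddev = noise_multiplier * scale   # stddev passed to tf.random.normal
    assert abs(stddev - 0.6) < 1e-12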


@@ -70,3 +70,95 @@ class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
     computed_std = np.std(noised_grads[0] - clipped_grads[0])
     expected_std = l2_norm_clip * noise_multiplier * scale
     self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      sparse_noise_multiplier=[1.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_sparse_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      sparse_noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    sparse_grad = tf.IndexedSlices(
+        values=tf.ones((3, 4)),
+        indices=tf.constant([0, 3, 5]),
+        dense_shape=tf.constant([8, 4]),
+    )
+    sparse_grad_contribution_counts = tf.SparseTensor(
+        indices=[[0], [3], [5]],
+        values=[10.0, 10.0, 20.0],
+        dense_shape=[8],
+    )
+
+    sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+        sparse_noise_multiplier=sparse_noise_multiplier,
+        sparse_selection_threshold=8,
+        sparse_contribution_counts=[None, sparse_grad_contribution_counts],
+    )
+
+    dense_noised_grad, sparse_noised_grad = noise_utils.add_aggregate_noise(
+        clipped_grads=[dense_grad, sparse_grad],
+        batch_size=batch_size,
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        loss_model=linear_model,
+        sparse_noise_config=sparse_noise_config,
+    )
+
+    self.assertContainsSubset(
+        sparse_grad.indices.numpy().tolist(),
+        sparse_noised_grad.indices.numpy().tolist(),
+    )
+    sparse_noised_grad_dense = tf.scatter_nd(
+        tf.reshape(sparse_noised_grad.indices, (-1, 1)),
+        sparse_noised_grad.values,
+        shape=(8, 4),
+    ).numpy()
+    sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
+        sparse_grad.indices.numpy()
+    ]
+    sparse_grad_values = sparse_grad.values.numpy()
+    self.assertTrue(
+        np.all(
+            np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
+        )
+    )
+
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    # The added noise is the only source of variation, so check its
+    # standard deviation against the expected Gaussian scale.
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    sparse_computed_std = np.std(
+        sparse_noised_grad_valid_indices - sparse_grad_values
+    )
+    self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)
+
+    dense_computed_std = np.std(dense_noised_grad - dense_grad)
+    self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)
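
For a quick standalone check of the dense path outside the test harness, the sketch below (same assumed import path as in the earlier sketch; illustrative numbers) verifies the noise stddev empirically:

    import numpy as np
    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils  # assumed path

    grad = tf.zeros((1000, 100))
    noised, = noise_utils.add_aggregate_noise(
        clipped_grads=[grad],
        batch_size=tf.constant(10),
        l2_norm_clip=3.0,
        noise_multiplier=2.0,
        loss_reduction='mean',
    )
    # Expected stddev: l2_norm_clip * noise_multiplier / batch_size = 0.6.
    np.testing.assert_allclose(np.std(noised - grad), 0.6, rtol=0.1)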