Sparsity Preserving DP-SGD in TF Privacy
Add support for adding sparsity preserving noise in add_aggregate_noise. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 662148309
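For orientation, here is a minimal usage sketch of the API surface this change introduces, based on the signatures in the diff below. It is not code from this commit: the `noise_utils` import path and all parameter values are assumptions, and the contribution counts are made-up placeholders.

import tensorflow as tf
# Assumed import path; adjust to wherever noise_utils lives in your checkout.
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# One dense gradient and one sparse (embedding-style) gradient, already clipped.
dense_grad = tf.ones((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 2)),
    indices=tf.constant([1, 3]),
    dense_shape=tf.constant([4, 2]),
)

# Per-row contribution counts for the sparse gradient; None for dense entries.
counts = tf.SparseTensor(indices=[[1], [3]], values=[5.0, 7.0], dense_shape=[4])

config = noise_utils.SparsityPreservingNoiseConfig(
    sparse_noise_multiplier=1.0,
    sparse_selection_threshold=4,
    sparse_contribution_counts=[None, counts],
)

# With sparse_noise_config=None, every gradient would get dense Gaussian noise.
noised_grads = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=32,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',
    sparse_noise_config=config,
)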
parent 09c68750d7
commit e42b574465

3 changed files with 172 additions and 10 deletions
BUILD
@@ -84,6 +84,7 @@ py_library(
 py_library(
     name = "noise_utils",
     srcs = ["noise_utils.py"],
+    deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
 )

 py_test(
noise_utils.py
@@ -14,10 +14,21 @@
 """Utility functions that help in adding noise to gradients."""

 from collections.abc import Sequence
+import dataclasses
 from typing import Literal, Optional

 from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
+
+
+@dataclasses.dataclass
+class SparsityPreservingNoiseConfig:
+  """Configuration for adding noise to gradients."""
+
+  sparse_noise_multiplier: float = 0.0
+  sparse_selection_threshold: int = 0
+  sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


 def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
   )


+def _add_dense_aggregate_noise(
+    grad: tf.Tensor,
+    noise_multiplier: float,
+    sensitivity: float,
+) -> tf.Tensor:
+  """Adds dense noise to a dense gradient."""
+  return grad + tf.random.normal(
+      tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
+  )
+
+
+def _add_sparse_aggregate_noise(
+    grad: tf.IndexedSlices,
+    contribution_counts: tf.SparseTensor,
+    noise_multiplier: float,
+    noise_multiplier_sparse: float,
+    sensitivity: float,
+    sparse_selection_threshold: int,
+) -> tf.IndexedSlices:
+  """Adds sparse noise to a sparse gradient."""
+  return sparse_noise_utils.add_sparse_noise(
+      grad=grad,
+      contribution_counts=contribution_counts,
+      noise_multiplier=noise_multiplier,
+      noise_multiplier_sparse=noise_multiplier_sparse,
+      l2_norm_clip=sensitivity,
+      threshold=sparse_selection_threshold,
+  )
+
+
 def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
+    clipped_grads: list[tf.Tensor | tf.IndexedSlices],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
     loss_reduction: Optional[Literal['mean', 'sum']] = None,
     loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
+    sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
+) -> Sequence[tf.Tensor | tf.IndexedSlices]:
   """Adds noise to a collection of clipped gradients.

   The magnitude of the noise depends on the aggregation strategy of the
   input model's loss function.

   Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
+    clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
+      the clipped gradients.
     batch_size: The batch size. Used for normalizing the noise when
       `loss_reduction` is 'sum'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@@ -68,11 +111,14 @@ def add_aggregate_noise(
     aggregation type must be inferred from `input_model.loss`.
     loss_model: An optional `tf.keras.Model` used to infer the loss reduction
       strategy from if `loss_reduction` is `None`.
+    sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
+      the configuration for adding sparse noise. If None, all noise added is
+      dense.

   Returns:
     A list of tensors containing the clipped gradients, but with the right
-    amount of Gaussian noise added to them (depending on the reduction
-    strategy of the loss function).
+    amount of Gaussian or sparse Gaussian noise added to them (depending on
+    the reduction strategy of the loss function).

   Raises:
     ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@@ -103,13 +149,36 @@ def add_aggregate_noise(
         'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
     )

+  if sparse_noise_config is None:
+    sparse_contribution_counts = tf.nest.map_structure(
+        lambda x: None, clipped_grads
+    )
+  else:
+    sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts
+
   scale = l2_norm_clip
   if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)

-  def add_noise(g):
-    return g + tf.random.normal(
-        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
-    )
+  def add_noise(grad, contribution_counts):
+    if (
+        sparse_noise_config is not None
+        and isinstance(grad, tf.IndexedSlices)
+        and contribution_counts is not None
+    ):
+      return _add_sparse_aggregate_noise(
+          grad=grad,
+          contribution_counts=contribution_counts,
+          noise_multiplier=noise_multiplier,
+          noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
+          sensitivity=scale,
+          sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
+      )
+    else:
+      return _add_dense_aggregate_noise(
+          grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
+      )

-  return tf.nest.map_structure(add_noise, clipped_grads)
+  return tf.nest.map_structure(
+      add_noise, clipped_grads, sparse_contribution_counts
+  )
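The sparse branch above delegates to `sparse_noise_utils.add_sparse_noise`, whose implementation is not part of this diff. The sketch below is only an illustration of the idea from the linked blog post, under stated assumptions, and not the library's actual code: privately select rows whose noised contribution count clears the threshold, then add Gaussian noise only on the union of the gradient's own rows and the selected rows, leaving every other row exactly zero.

import tensorflow as tf

def add_sparse_noise_sketch(
    grad: tf.IndexedSlices,
    contribution_counts: tf.SparseTensor,
    noise_multiplier: float,
    noise_multiplier_sparse: float,
    l2_norm_clip: float,
    threshold: int,
) -> tf.IndexedSlices:
  """Illustrative sketch only; the real add_sparse_noise may differ."""
  # Differentially private partition selection: noise the per-row contribution
  # counts and keep the rows whose noisy count clears the threshold.
  noisy_counts = contribution_counts.values + tf.random.normal(
      tf.shape(contribution_counts.values), stddev=noise_multiplier_sparse
  )
  selected_rows = tf.boolean_mask(
      tf.cast(contribution_counts.indices[:, 0], grad.indices.dtype),
      noisy_counts > float(threshold),
  )
  # Always keep the rows present in the gradient itself.
  out_indices = tf.sort(
      tf.unique(tf.concat([grad.indices, selected_rows], axis=0)).y
  )
  # Gather the (possibly zero) gradient rows at the output indices and add
  # Gaussian noise there only; all other rows stay exactly zero, i.e. sparse.
  dense = tf.scatter_nd(
      tf.expand_dims(grad.indices, axis=-1), grad.values, grad.dense_shape
  )
  values = tf.gather(dense, out_indices)
  values += tf.random.normal(
      tf.shape(values), stddev=noise_multiplier * l2_norm_clip
  )
  return tf.IndexedSlices(
      values=values, indices=out_indices, dense_shape=grad.dense_shape
  )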
noise_utils_test.py
@@ -70,3 +70,95 @@ class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
     computed_std = np.std(noised_grads[0] - clipped_grads[0])
     expected_std = l2_norm_clip * noise_multiplier * scale
     self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      sparse_noise_multiplier=[1.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_sparse_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      sparse_noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    sparse_grad = tf.IndexedSlices(
+        values=tf.ones((3, 4)),
+        indices=tf.constant([0, 3, 5]),
+        dense_shape=tf.constant([8, 4]),
+    )
+    sparse_grad_contribution_counts = tf.SparseTensor(
+        indices=[[0], [3], [5]],
+        values=[10.0, 10.0, 20.0],
+        dense_shape=[8],
+    )
+
+    sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+        sparse_noise_multiplier=sparse_noise_multiplier,
+        sparse_selection_threshold=8,
+        sparse_contribution_counts=[None, sparse_grad_contribution_counts],
+    )
+
+    dense_noised_grad, sparse_noised_grad = noise_utils.add_aggregate_noise(
+        clipped_grads=[dense_grad, sparse_grad],
+        batch_size=batch_size,
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        loss_reduction=noise_fn_reduction,
+        loss_model=linear_model,
+        sparse_noise_config=sparse_noise_config,
+    )
+    self.assertContainsSubset(
+        sparse_grad.indices.numpy().tolist(),
+        sparse_noised_grad.indices.numpy().tolist(),
+    )
+    sparse_noised_grad_dense = tf.scatter_nd(
+        tf.reshape(sparse_noised_grad.indices, (-1, 1)),
+        sparse_noised_grad.values,
+        shape=(8, 4),
+    ).numpy()
+    sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
+        sparse_grad.indices.numpy()
+    ]
+    sparse_grad_values = sparse_grad.values.numpy()
+    self.assertTrue(
+        np.all(
+            np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
+        )
+    )
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    # The only measure that varies is the standard deviation of the variation.
+    expected_std = l2_norm_clip * noise_multiplier * scale
+
+    sparse_computed_std = np.std(
+        sparse_noised_grad_valid_indices - sparse_grad_values
+    )
+    self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)
+
+    dense_computed_std = np.std(dense_noised_grad - dense_grad)
+    self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)
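As a quick arithmetic check of the test's expectation, here is one worked combination from the `@parameterized.product` grid above (assumed choice of values from that grid):

# One combination from the product grid: 'mean' reduction divides the noise
# scale by the batch size, while 'sum' reduction keeps scale = 1.0.
l2_norm_clip, noise_multiplier, batch_size = 3.0, 2.0, 2
scale = 1.0 / batch_size
expected_std = l2_norm_clip * noise_multiplier * scale
assert expected_std == 3.0  # the stddev that assertNear compares against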