Sparsity Preserving DP-SGD in TF Privacy [3 of 5]

Adds sparse noise utilities that privately select sparse indices from gradients and add sparse noise to them.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 654902527
This commit is contained in:
A. Unique TensorFlower 2024-07-22 14:46:20 -07:00
parent 8747858b5b
commit a56f33c4c5
2 changed files with 342 additions and 1 deletions

View file

@ -16,8 +16,9 @@
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""
from typing import Optional, Sequence
from typing import Mapping, Optional, Sequence
from scipy import stats
import tensorflow as tf
import tensorflow_probability as tfp
@ -169,3 +170,191 @@ def sample_true_positive_indices(
noised_contribution_count_values >= threshold
][:, 0]
return tf.reshape(noised_contribution_counts_indices, (-1,))
@tf.function
def _remap_indices(
    indices: tf.Tensor,
    skip_indices: tf.Tensor,
) -> tf.Tensor:
  """Shifts indices forward so that they step over the skip indices.

  Every input index is increased by the number of skip indices that have to be
  jumped over to reach it. For instance, with skip_indices = [1, 3] the mapping
  is:

    0 -> 0
    1 -> 2
    2 -> 4
    3 -> 5
    4 -> 6
    5 -> 7
    ...

  This is useful when merging true positive and false positive indices.

  Args:
    indices: The indices to remap.
    skip_indices: The indices to step over. Assumed to be sorted.
      NOTE(review): the arithmetic below also appears to rely on the entries
      being unique and non-negative int64 values; confirm with callers.

  Returns:
    The remapped indices.
  """
  # offsets[k] == k is the shift applied once k skip indices have been passed.
  offsets = tf.range(tf.size(skip_indices) + 1, dtype=tf.int64)
  # Input indices at or beyond boundaries[k] must move past the k-th skip index.
  boundaries = skip_indices - offsets[:-1]
  # Number of free positions between consecutive skip indices; the leading -1
  # makes the first gap start at position 0.
  padded_skips = tf.concat([[-1], skip_indices], axis=0)
  gap_sizes = padded_skips[1:] - padded_skips[:-1] - 1
  # Open-ended final segment covering everything after the last skip index.
  boundaries = tf.concat([boundaries, [tf.int64.max]], axis=0)
  gap_sizes = tf.concat([gap_sizes, [1]], axis=0)
  # Runs of adjacent skip indices leave empty segments; drop them so that every
  # remaining boundary carries the full shift for its segment.
  non_empty = gap_sizes > 0
  segment_ends = boundaries[non_empty]
  segment_offsets = offsets[non_empty]
  segment_ids = tf.searchsorted(segment_ends, indices, side='right')
  return indices + tf.gather(segment_offsets, segment_ids)
def sparse_private_partition_selection(
    contribution_counts: tf.SparseTensor,
    noise_multiplier: float,
    threshold: int,
) -> tf.Tensor:
  """Differentially private partition selection.

  Uses the sparse sampling algorithm to sample false positive indices. Also
  assumes that the contribution counts are clipped to a per example contribution
  of 1.

  The selection proceeds in two stages:
    1. True positives are drawn from the indices that carry a contribution
       count via `sample_true_positive_indices`.
    2. When noise is enabled, false positives are drawn from the remaining
       indices via `sample_false_positive_indices`, then remapped so that they
       do not collide with the indices stored in `contribution_counts`.

  Args:
    contribution_counts: The contribution counts for each index.
    noise_multiplier: The noise multiplier to use for the gaussian noise. When
      it is not strictly positive, only the true positive indices are returned.
    threshold: The threshold to use for the selection. Must be non-negative.

  Returns:
    A tensor of selected indices, sorted when false positives are merged in.

  Raises:
    ValueError: If `threshold` is negative.
  """
  if threshold < 0:
    # The check accepts zero, so the message says "non-negative".
    raise ValueError(f'Threshold must be non-negative, got {threshold}.')

  true_positive_indices = sample_true_positive_indices(
      contribution_counts, noise_multiplier, threshold
  )

  # Without noise no zero-count index can cross the threshold, and the division
  # below would be undefined.
  if noise_multiplier <= 0.0:
    return true_positive_indices

  # probability of selecting an index with zero contribution count.
  prob = stats.norm.sf(threshold / noise_multiplier).item()

  num_total_indices = tf.cast(contribution_counts.dense_shape[0], tf.int32)
  num_non_zero_indices = tf.shape(contribution_counts.values)[0]
  # Largest index among the zero-count positions when those positions are
  # numbered consecutively from 0.
  # NOTE(review): this is -1 when every index has a contribution count; confirm
  # that sample_false_positive_indices handles that input.
  max_index = tf.cast(num_total_indices - num_non_zero_indices - 1, tf.int32)
  false_positive_indices = sample_false_positive_indices(max_index, prob)

  # Translate the consecutive numbering back to real indices by stepping over
  # every index that has a stored contribution count.
  remapped_false_positive_indices = _remap_indices(
      false_positive_indices, tf.reshape(contribution_counts.indices, (-1,))
  )
  merged_indices = tf.sort(
      tf.concat(
          [remapped_false_positive_indices, true_positive_indices], axis=0
      )
  )
  return merged_indices
def add_sparse_gradient_noise(
    grad: tf.IndexedSlices, indices: tf.Tensor, noise_stddev: float
) -> tf.IndexedSlices:
  """Adds sparse gradient noise.

  Only the rows of `grad` named by `indices` are kept; zero-mean Gaussian noise
  is added to each of those rows.

  Args:
    grad: A sparse gradient of type `tf.IndexedSlices`.
    indices: The selected indices to keep.
    noise_stddev: The standard deviation of the noise to add.

  Returns:
    A sparse gradient of type `tf.IndexedSlices` with the noise added. Its
    indices are `indices` and its dense shape is that of `grad`.
  """
  filtered_grad_values = tf.gather(grad, indices)
  # Use the dynamic shape: the number of selected indices is generally unknown
  # at trace time, so the static shape may contain an undefined dimension.
  # The noise dtype follows the gradient so that the addition below is valid
  # for non-float32 gradients as well.
  sparse_noise_values = tf.random.normal(
      tf.shape(filtered_grad_values),
      mean=0.0,
      stddev=noise_stddev,
      dtype=filtered_grad_values.dtype,
  )
  filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
  return tf.IndexedSlices(
      indices=indices,
      values=filtered_noised_grad_values,
      dense_shape=grad.dense_shape,
  )
def get_contribution_counts(
    trainable_vars: list[tf.Variable],
    grads: list[tf.Tensor],
    varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
) -> list[tf.Tensor | None]:
  """Computes the contribution counts of every variable in the Model.

  Variables and gradients are paired positionally. For each pair, the entry
  registered under the variable's name is looked up and, if usable, its single
  function is called with the gradient.

  Args:
    trainable_vars: A list of the trainable variables in the Model.
    grads: A corresponding list of gradients for each trainable variable.
    varname_to_contribution_counts_fns: A mapping from variable name to a list
      of functions to get the contribution counts for that variable. Each value
      is indexed and its first element is called, so it is used as a sequence
      of callables; a missing, empty or `None` value means "no counts".

  Returns:
    A list with one entry per variable: the contribution counts, or None when
    the variable has no contribution counts function.

  Raises:
    NotImplementedError: If more than one contribution counts function is
      registered for a variable.
  """

  def _counts_for(var, grad):
    # Unregistered variables have nothing to report.
    if var.name not in varname_to_contribution_counts_fns:
      return None
    fns = varname_to_contribution_counts_fns[var.name]
    # An empty/None entry, or a falsy first function, also means "no counts".
    if not fns or not fns[0]:
      return None
    if len(fns) > 1:
      raise NotImplementedError(
          'Sparse noise is not supported for shared weight variables.'
      )
    return fns[0](grad)

  return [_counts_for(var, grad) for var, grad in zip(trainable_vars, grads)]
def add_sparse_noise(
    grad: tf.IndexedSlices,
    contribution_counts: tf.SparseTensor,
    noise_multiplier: float,
    noise_multiplier_sparse: float,
    l2_norm_clip: float,
    threshold: int,
) -> tf.IndexedSlices:
  """Privately selects indices of a gradient and adds noise to them.

  Args:
    grad: A sparse gradient of type `tf.IndexedSlices`.
    contribution_counts: The contribution counts for each index of grad.
    noise_multiplier: The noise multiplier to use for the gradient noise.
    noise_multiplier_sparse: The noise multiplier to use for the partition
      selection.
    l2_norm_clip: The l2 norm clip at which the gradient is clipped.
    threshold: The threshold to use for the partition selection.

  Returns:
    A sparse gradient of type `tf.IndexedSlices` with the noise added.
  """
  # Step 1: choose which indices survive.
  selected_indices = sparse_private_partition_selection(
      contribution_counts, noise_multiplier_sparse, threshold
  )
  # Step 2: noise the surviving rows; the stddev scales with the clip norm.
  gradient_noise_stddev = noise_multiplier * l2_norm_clip
  return add_sparse_gradient_noise(
      grad, selected_indices, gradient_noise_stddev
  )

View file

@ -284,6 +284,158 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(batch_size, 0)
self.assertLess(batch_size, max_index + 1)
def test_sparse_private_partition_selection_without_noise(self):
  """With zero noise, only the index whose count reaches the threshold is kept."""
  counts = tf.SparseTensor(
      indices=[[0], [3], [5]],
      values=[2.0, 1.0, 1.0],
      dense_shape=[8],
  )

  selected = sparse_noise_utils.sparse_private_partition_selection(
      counts, 0.0, 2
  )

  # Only index 0 has a count (2.0) that is at least the threshold of 2.
  self.assertEqual(selected.numpy().tolist(), [0])
def test_sparse_private_partition_selection_with_noise(self):
  """With noise, the heavy index is kept and further indices are selected."""
  counts = tf.SparseTensor(
      indices=[[0], [3], [5]],
      values=[50.0, 1.0, 1.0],
      dense_shape=[1000],
  )

  selected = sparse_noise_utils.sparse_private_partition_selection(
      counts, 1.0, 1
  )
  selected = selected.numpy().tolist()

  # Index 0 sits far above the threshold, so it is expected to be present.
  self.assertContainsSubset([0], selected)
  # At least one more index is expected besides index 0.
  self.assertGreater(len(selected), 1)
def test_remap_indices(self):
  """Checks remapping over a run of adjacent skip indices and a lone one."""
  remapped = sparse_noise_utils._remap_indices(
      tf.constant([1, 5, 10], tf.int64),
      tf.constant([0, 1, 2, 5], tf.int64),
  )
  self.assertEqual(remapped.numpy().tolist(), [4, 9, 14])
def test_remap_indices_no_skip(self):
  """With nothing to skip, the indices come back unchanged."""
  remapped = sparse_noise_utils._remap_indices(
      tf.constant([1, 5, 10], tf.int64),
      tf.constant([], tf.int64),
  )
  self.assertEqual(remapped.numpy().tolist(), [1, 5, 10])
def test_add_sparse_gradient_noise(self):
  """Noise is added at a selected index that the gradient did not contain."""
  grad = tf.IndexedSlices(
      values=tf.ones((1, 2)),
      indices=tf.constant([0]),
      dense_shape=tf.constant([2, 2]),
  )
  kept_indices = tf.constant([1], dtype=tf.int64)

  noised = sparse_noise_utils.add_sparse_gradient_noise(
      grad, kept_indices, 1.0
  )

  # The result carries exactly the requested indices.
  self.assertListEqual(
      noised.indices.numpy().tolist(), kept_indices.numpy().tolist()
  )
  # Row 1 is absent from the gradient, so any non-zero value there is noise.
  self.assertNotEqual(noised.values[0].numpy().tolist(), [0.0, 0.0])
def test_get_contribution_counts(self):
  """Covers a registered function, a None entry and a missing entry."""
  trainable_vars = [
      tf.Variable(tf.ones((1, 2)), name=var_name)
      for var_name in ('var1', 'var2', 'var3')
  ]
  sparse_grad = tf.IndexedSlices(
      values=tf.ones((1, 2)),
      indices=tf.constant([0]),
      dense_shape=tf.constant([2, 2]),
  )
  grads = [sparse_grad, tf.ones((1, 2)), tf.ones((1, 2))]
  # var1 has one function, var2 is explicitly None, var3 is not registered.
  fns_by_varname = {
      'var1:0': [lambda grad: 1.0],
      'var2:0': None,
  }

  counts = sparse_noise_utils.get_contribution_counts(
      trainable_vars, grads, fns_by_varname
  )

  self.assertEqual(counts, [1.0, None, None])
def test_add_sparse_noise_without_noise(self):
  """With both noise multipliers at zero, the gradient passes through as is."""
  grad = tf.IndexedSlices(
      values=tf.ones((3, 4)),
      indices=tf.constant([0, 3, 5]),
      dense_shape=tf.constant([8, 4]),
  )
  counts = tf.SparseTensor(
      indices=[[0], [3], [5]],
      values=[3.0, 1.0, 2.0],
      dense_shape=[8],
  )

  noised = sparse_noise_utils.add_sparse_noise(
      grad,
      counts,
      noise_multiplier=0.0,
      noise_multiplier_sparse=0.0,
      l2_norm_clip=1.0,
      threshold=1,
  )

  # Every count reaches the threshold of 1, so no index is dropped.
  self.assertEqual(
      noised.indices.numpy().tolist(), grad.indices.numpy().tolist()
  )
  # A zero noise stddev leaves the values untouched.
  self.assertEqual(
      noised.values.numpy().tolist(), grad.values.numpy().tolist()
  )
def test_add_sparse_noise_with_noise(self):
  """With noise on, the original indices are kept and their values change."""
  grad = tf.IndexedSlices(
      values=tf.ones((3, 4)),
      indices=tf.constant([0, 3, 5]),
      dense_shape=tf.constant([8, 4]),
  )
  counts = tf.SparseTensor(
      indices=[[0], [3], [5]],
      values=[10.0, 10.0, 20.0],
      dense_shape=[8],
  )

  noised = sparse_noise_utils.add_sparse_noise(
      grad,
      counts,
      noise_multiplier=1.0,
      noise_multiplier_sparse=1.0,
      l2_norm_clip=1.0,
      threshold=5,
  )

  original_rows = grad.indices.numpy()
  # All counts lie well above the threshold, so the original rows must remain.
  self.assertContainsSubset(
      original_rows.tolist(),
      noised.indices.numpy().tolist(),
  )
  # Densify the result so that rows can be looked up by their original index.
  dense_noised = tf.scatter_nd(
      tf.reshape(noised.indices, (-1, 1)),
      noised.values,
      shape=(8, 4),
  ).numpy()
  changed = np.not_equal(dense_noised[original_rows], grad.values.numpy())
  # Every entry of the original rows is expected to differ after noising.
  self.assertTrue(np.all(changed))
if __name__ == '__main__':
  # Run the test cases through TensorFlow's test runner when executed directly.
  tf.test.main()