forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy [3 of 5]
Adds sparse noise utilities to privately select sparse indices from and add sparse noise to gradients. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 654902527
This commit is contained in:
parent
8747858b5b
commit
a56f33c4c5
2 changed files with 342 additions and 1 deletions
|
@ -16,8 +16,9 @@
|
||||||
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
|
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Sequence
|
from typing import Mapping, Optional, Sequence
|
||||||
|
|
||||||
|
from scipy import stats
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_probability as tfp
|
import tensorflow_probability as tfp
|
||||||
|
|
||||||
|
@ -169,3 +170,191 @@ def sample_true_positive_indices(
|
||||||
noised_contribution_count_values >= threshold
|
noised_contribution_count_values >= threshold
|
||||||
][:, 0]
|
][:, 0]
|
||||||
return tf.reshape(noised_contribution_counts_indices, (-1,))
|
return tf.reshape(noised_contribution_counts_indices, (-1,))
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def _remap_indices(
|
||||||
|
indices: tf.Tensor,
|
||||||
|
skip_indices: tf.Tensor,
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""Remaps the indices while skipping the skip indices.
|
||||||
|
|
||||||
|
As an example, if skip_indices = [1, 3], then the indices will be remapped as
|
||||||
|
follows:
|
||||||
|
0 -> 0
|
||||||
|
1 -> 2
|
||||||
|
2 -> 4
|
||||||
|
3 -> 5
|
||||||
|
4 -> 6
|
||||||
|
5 -> 7
|
||||||
|
...
|
||||||
|
|
||||||
|
This is useful for merging the true positive and false positive indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices: The indices to remap.
|
||||||
|
skip_indices: The indices to skip while remapping. Assumed to be sorted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The remapped indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def piecewise_map(skip_indices):
|
||||||
|
map_counts = tf.range(tf.size(skip_indices) + 1, dtype=tf.int64)
|
||||||
|
map_idx = skip_indices - map_counts[:-1]
|
||||||
|
skip_indices = tf.concat([[-1], skip_indices], axis=0)
|
||||||
|
gaps = skip_indices[1:] - skip_indices[:-1] - 1
|
||||||
|
|
||||||
|
map_idx = tf.concat([map_idx, [tf.int64.max]], axis=0)
|
||||||
|
gaps = tf.concat([gaps, [1]], axis=0)
|
||||||
|
|
||||||
|
return map_idx[gaps > 0], map_counts[gaps > 0]
|
||||||
|
|
||||||
|
map_idx, map_count = piecewise_map(skip_indices)
|
||||||
|
idx = tf.searchsorted(map_idx, indices, side='right')
|
||||||
|
offset = tf.gather(map_count, idx)
|
||||||
|
return indices + offset
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_private_partition_selection(
|
||||||
|
contribution_counts: tf.SparseTensor,
|
||||||
|
noise_multiplier: float,
|
||||||
|
threshold: int,
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""Differentially private partition selection.
|
||||||
|
|
||||||
|
Uses the sparse sampling algorithm to sample false positive indices. Also
|
||||||
|
assumes that the contribution counts are clipped to a per example contribution
|
||||||
|
of 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contribution_counts: The contribution counts for each index.
|
||||||
|
noise_multiplier: The noise multiplier to use for the gaussian noise.
|
||||||
|
threshold: The threshold to use for the selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor of selected indices.
|
||||||
|
"""
|
||||||
|
if threshold < 0:
|
||||||
|
raise ValueError(f'Threshold must be positive, got {threshold}.')
|
||||||
|
|
||||||
|
true_positive_indices = sample_true_positive_indices(
|
||||||
|
contribution_counts, noise_multiplier, threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if noise_multiplier <= 0.0:
|
||||||
|
return true_positive_indices
|
||||||
|
|
||||||
|
# probability of selecting an index with zero contribution count.
|
||||||
|
prob = stats.norm.sf(threshold / noise_multiplier).item()
|
||||||
|
|
||||||
|
num_total_indices = tf.cast(contribution_counts.dense_shape[0], tf.int32)
|
||||||
|
num_non_zero_indices = tf.shape(contribution_counts.values)[0]
|
||||||
|
max_index = tf.cast(num_total_indices - num_non_zero_indices - 1, tf.int32)
|
||||||
|
false_positive_indices = sample_false_positive_indices(max_index, prob)
|
||||||
|
remapped_false_positive_indices = _remap_indices(
|
||||||
|
false_positive_indices, tf.reshape(contribution_counts.indices, (-1,))
|
||||||
|
)
|
||||||
|
merged_indices = tf.sort(
|
||||||
|
tf.concat(
|
||||||
|
[remapped_false_positive_indices, true_positive_indices], axis=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return merged_indices
|
||||||
|
|
||||||
|
|
||||||
|
def add_sparse_gradient_noise(
|
||||||
|
grad: tf.IndexedSlices, indices: tf.Tensor, noise_stddev: float
|
||||||
|
) -> tf.IndexedSlices:
|
||||||
|
"""Adds sparse gradient noise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grad: A sparse gradient of type `tf.IndexedSlices`.
|
||||||
|
indices: The selected indices to keep.
|
||||||
|
noise_stddev: The standard deviation of the noise to add.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sparse gradient of type `tf.IndexedSlices` with the noise added.
|
||||||
|
"""
|
||||||
|
filtered_grad_values = tf.gather(grad, indices)
|
||||||
|
sparse_noise_values = tf.random.normal(
|
||||||
|
filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
|
||||||
|
)
|
||||||
|
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
|
||||||
|
return tf.IndexedSlices(
|
||||||
|
indices=indices,
|
||||||
|
values=filtered_noised_grad_values,
|
||||||
|
dense_shape=grad.dense_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_contribution_counts(
|
||||||
|
trainable_vars: list[tf.Variable],
|
||||||
|
grads: list[tf.Tensor],
|
||||||
|
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
|
||||||
|
) -> list[tf.Tensor | None]:
|
||||||
|
"""Gets the contribution counts for each variable in the Model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainable_vars: A list of the trainable variables in the Model.
|
||||||
|
grads: A corresponding list of gradients for each trainable variable.
|
||||||
|
varname_to_contribution_counts_fns: A mapping from variable name to a list
|
||||||
|
of functions to get the contribution counts for that variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of contribution counts for each variable and None for variables that
|
||||||
|
do not have contribution counts function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If there are more than one contribution counts function
|
||||||
|
for a variable.
|
||||||
|
"""
|
||||||
|
contribution_counts_list = []
|
||||||
|
for var, grad in zip(trainable_vars, grads):
|
||||||
|
if var.name not in varname_to_contribution_counts_fns:
|
||||||
|
contribution_counts_list.append(None)
|
||||||
|
continue
|
||||||
|
contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
|
||||||
|
if not contribution_counts_fns or not contribution_counts_fns[0]:
|
||||||
|
contribution_counts_list.append(None)
|
||||||
|
continue
|
||||||
|
if len(contribution_counts_fns) > 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
'Sparse noise is not supported for shared weight variables.'
|
||||||
|
)
|
||||||
|
contribution_counts_fn = contribution_counts_fns[0]
|
||||||
|
contribution_counts = contribution_counts_fn(grad)
|
||||||
|
contribution_counts_list.append(contribution_counts)
|
||||||
|
|
||||||
|
return contribution_counts_list
|
||||||
|
|
||||||
|
|
||||||
|
def add_sparse_noise(
|
||||||
|
grad: tf.IndexedSlices,
|
||||||
|
contribution_counts: tf.SparseTensor,
|
||||||
|
noise_multiplier: float,
|
||||||
|
noise_multiplier_sparse: float,
|
||||||
|
l2_norm_clip: float,
|
||||||
|
threshold: int,
|
||||||
|
) -> tf.IndexedSlices:
|
||||||
|
"""Adds sparse noise to a gradient.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grad: A sparse gradient of type `tf.IndexedSlices`.
|
||||||
|
contribution_counts: The contribution counts for each index of grad.
|
||||||
|
noise_multiplier: The noise multiplier to use for the gradient noise.
|
||||||
|
noise_multiplier_sparse: The noise multiplier to use for the partition
|
||||||
|
selection.
|
||||||
|
l2_norm_clip: The l2 norm clip at which the gradient is clipped.
|
||||||
|
threshold: The threshold to use for the partition selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sparse gradient of type `tf.IndexedSlices` with the noise added.
|
||||||
|
"""
|
||||||
|
privately_selected_indices = sparse_private_partition_selection(
|
||||||
|
contribution_counts, noise_multiplier_sparse, threshold
|
||||||
|
)
|
||||||
|
noised_grad = add_sparse_gradient_noise(
|
||||||
|
grad, privately_selected_indices, noise_multiplier * l2_norm_clip
|
||||||
|
)
|
||||||
|
return noised_grad
|
||||||
|
|
|
@ -284,6 +284,158 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertGreater(batch_size, 0)
|
self.assertGreater(batch_size, 0)
|
||||||
self.assertLess(batch_size, max_index + 1)
|
self.assertLess(batch_size, max_index + 1)
|
||||||
|
|
||||||
|
def test_sparse_private_partition_selection_without_noise(self):
|
||||||
|
contribution_counts = tf.SparseTensor(
|
||||||
|
indices=[[0], [3], [5]],
|
||||||
|
values=[2.0, 1.0, 1.0],
|
||||||
|
dense_shape=[8],
|
||||||
|
)
|
||||||
|
noise_multiplier = 0.0
|
||||||
|
threshold = 2
|
||||||
|
sampled_indices = (
|
||||||
|
sparse_noise_utils.sparse_private_partition_selection(
|
||||||
|
contribution_counts, noise_multiplier, threshold
|
||||||
|
)
|
||||||
|
.numpy()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
expected_indices = [0]
|
||||||
|
self.assertEqual(sampled_indices, expected_indices)
|
||||||
|
|
||||||
|
def test_sparse_private_partition_selection_with_noise(self):
|
||||||
|
contribution_counts = tf.SparseTensor(
|
||||||
|
indices=[[0], [3], [5]],
|
||||||
|
values=[50.0, 1.0, 1.0],
|
||||||
|
dense_shape=[1000],
|
||||||
|
)
|
||||||
|
noise_multiplier = 1.0
|
||||||
|
threshold = 1
|
||||||
|
sampled_indices = (
|
||||||
|
sparse_noise_utils.sparse_private_partition_selection(
|
||||||
|
contribution_counts, noise_multiplier, threshold
|
||||||
|
)
|
||||||
|
.numpy()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
expected_indices = [0]
|
||||||
|
self.assertContainsSubset(expected_indices, sampled_indices)
|
||||||
|
self.assertGreater(len(sampled_indices), 1)
|
||||||
|
|
||||||
|
def test_remap_indices(self):
|
||||||
|
expected_indices = [4, 9, 14]
|
||||||
|
indices = tf.constant([1, 5, 10], tf.int64)
|
||||||
|
skip_indices = tf.constant([0, 1, 2, 5], tf.int64)
|
||||||
|
remapped_indices = sparse_noise_utils._remap_indices(indices, skip_indices)
|
||||||
|
self.assertEqual(remapped_indices.numpy().tolist(), expected_indices)
|
||||||
|
|
||||||
|
def test_remap_indices_no_skip(self):
|
||||||
|
expected_indices = [1, 5, 10]
|
||||||
|
indices = tf.constant([1, 5, 10], tf.int64)
|
||||||
|
skip_indices = tf.constant([], tf.int64)
|
||||||
|
remapped_indices = sparse_noise_utils._remap_indices(indices, skip_indices)
|
||||||
|
self.assertEqual(remapped_indices.numpy().tolist(), expected_indices)
|
||||||
|
|
||||||
|
def test_add_sparse_gradient_noise(self):
|
||||||
|
grad = tf.IndexedSlices(
|
||||||
|
values=tf.ones((1, 2)),
|
||||||
|
indices=tf.constant([0]),
|
||||||
|
dense_shape=tf.constant([2, 2]),
|
||||||
|
)
|
||||||
|
indices = tf.constant([1], dtype=tf.int64)
|
||||||
|
noise_stddev = 1.0
|
||||||
|
noised_grad = sparse_noise_utils.add_sparse_gradient_noise(
|
||||||
|
grad, indices, noise_stddev
|
||||||
|
)
|
||||||
|
self.assertListEqual(
|
||||||
|
noised_grad.indices.numpy().tolist(), indices.numpy().tolist()
|
||||||
|
)
|
||||||
|
one_index_values = noised_grad.values[0].numpy().tolist()
|
||||||
|
self.assertNotEqual(one_index_values, [0.0, 0.0])
|
||||||
|
|
||||||
|
def test_get_contribution_counts(self):
|
||||||
|
trainable_vars = [
|
||||||
|
tf.Variable(tf.ones((1, 2)), name='var1'),
|
||||||
|
tf.Variable(tf.ones((1, 2)), name='var2'),
|
||||||
|
tf.Variable(tf.ones((1, 2)), name='var3'),
|
||||||
|
]
|
||||||
|
grads = [
|
||||||
|
tf.IndexedSlices(
|
||||||
|
values=tf.ones((1, 2)),
|
||||||
|
indices=tf.constant([0]),
|
||||||
|
dense_shape=tf.constant([2, 2]),
|
||||||
|
),
|
||||||
|
tf.ones((1, 2)),
|
||||||
|
tf.ones((1, 2)),
|
||||||
|
]
|
||||||
|
varname_to_contribution_counts_fns = {
|
||||||
|
'var1:0': [lambda grad: 1.0],
|
||||||
|
'var2:0': None,
|
||||||
|
}
|
||||||
|
contribution_counts = sparse_noise_utils.get_contribution_counts(
|
||||||
|
trainable_vars, grads, varname_to_contribution_counts_fns
|
||||||
|
)
|
||||||
|
expected_contribution_counts = [1.0, None, None]
|
||||||
|
self.assertEqual(contribution_counts, expected_contribution_counts)
|
||||||
|
|
||||||
|
def test_add_sparse_noise_without_noise(self):
|
||||||
|
grad = tf.IndexedSlices(
|
||||||
|
values=tf.ones((3, 4)),
|
||||||
|
indices=tf.constant([0, 3, 5]),
|
||||||
|
dense_shape=tf.constant([8, 4]),
|
||||||
|
)
|
||||||
|
contribution_counts = tf.SparseTensor(
|
||||||
|
indices=[[0], [3], [5]],
|
||||||
|
values=[3.0, 1.0, 2.0],
|
||||||
|
dense_shape=[8],
|
||||||
|
)
|
||||||
|
noised_grad = sparse_noise_utils.add_sparse_noise(
|
||||||
|
grad,
|
||||||
|
contribution_counts,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
noise_multiplier_sparse=0.0,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
threshold=1,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
noised_grad.indices.numpy().tolist(), grad.indices.numpy().tolist()
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
noised_grad.values.numpy().tolist(), grad.values.numpy().tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_add_sparse_noise_with_noise(self):
|
||||||
|
grad = tf.IndexedSlices(
|
||||||
|
values=tf.ones((3, 4)),
|
||||||
|
indices=tf.constant([0, 3, 5]),
|
||||||
|
dense_shape=tf.constant([8, 4]),
|
||||||
|
)
|
||||||
|
contribution_counts = tf.SparseTensor(
|
||||||
|
indices=[[0], [3], [5]],
|
||||||
|
values=[10.0, 10.0, 20.0],
|
||||||
|
dense_shape=[8],
|
||||||
|
)
|
||||||
|
noised_grad = sparse_noise_utils.add_sparse_noise(
|
||||||
|
grad,
|
||||||
|
contribution_counts,
|
||||||
|
noise_multiplier=1.0,
|
||||||
|
noise_multiplier_sparse=1.0,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
threshold=5,
|
||||||
|
)
|
||||||
|
self.assertContainsSubset(
|
||||||
|
grad.indices.numpy().tolist(),
|
||||||
|
noised_grad.indices.numpy().tolist(),
|
||||||
|
)
|
||||||
|
noised_grad_dense = tf.scatter_nd(
|
||||||
|
tf.reshape(noised_grad.indices, (-1, 1)),
|
||||||
|
noised_grad.values,
|
||||||
|
shape=(8, 4),
|
||||||
|
).numpy()
|
||||||
|
noised_grad_valid_indices = noised_grad_dense[grad.indices.numpy()]
|
||||||
|
self.assertTrue(
|
||||||
|
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue