diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py index dd98a45..839a559 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py @@ -16,8 +16,9 @@ For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357. """ -from typing import Optional, Sequence +from typing import Mapping, Optional, Sequence +from scipy import stats import tensorflow as tf import tensorflow_probability as tfp @@ -169,3 +170,191 @@ def sample_true_positive_indices( noised_contribution_count_values >= threshold ][:, 0] return tf.reshape(noised_contribution_counts_indices, (-1,)) + + +@tf.function +def _remap_indices( + indices: tf.Tensor, + skip_indices: tf.Tensor, +) -> tf.Tensor: + """Remaps the indices while skipping the skip indices. + + As an example, if skip_indices = [1, 3], then the indices will be remapped as + follows: + 0 -> 0 + 1 -> 2 + 2 -> 4 + 3 -> 5 + 4 -> 6 + 5 -> 7 + ... + + This is useful for merging the true positive and false positive indices. + + Args: + indices: The indices to remap. + skip_indices: The indices to skip while remapping. Assumed to be sorted. + + Returns: + The remapped indices. + """ + + def piecewise_map(skip_indices): + map_counts = tf.range(tf.size(skip_indices) + 1, dtype=tf.int64) + map_idx = skip_indices - map_counts[:-1] + skip_indices = tf.concat([[-1], skip_indices], axis=0) + gaps = skip_indices[1:] - skip_indices[:-1] - 1 + + map_idx = tf.concat([map_idx, [tf.int64.max]], axis=0) + gaps = tf.concat([gaps, [1]], axis=0) + + return map_idx[gaps > 0], map_counts[gaps > 0] + + map_idx, map_count = piecewise_map(skip_indices) + idx = tf.searchsorted(map_idx, indices, side='right') + offset = tf.gather(map_count, idx) + return indices + offset + + +def sparse_private_partition_selection( + contribution_counts: tf.SparseTensor, + noise_multiplier: float, + threshold: int, +) -> tf.Tensor: + """Differentially private partition selection. + + Uses the sparse sampling algorithm to sample false positive indices. Also + assumes that the contribution counts are clipped to a per example contribution + of 1. + + Args: + contribution_counts: The contribution counts for each index. + noise_multiplier: The noise multiplier to use for the gaussian noise. + threshold: The threshold to use for the selection. + + Returns: + A tensor of selected indices. + """ + if threshold < 0: + raise ValueError(f'Threshold must be positive, got {threshold}.') + + true_positive_indices = sample_true_positive_indices( + contribution_counts, noise_multiplier, threshold + ) + + if noise_multiplier <= 0.0: + return true_positive_indices + + # probability of selecting an index with zero contribution count. + prob = stats.norm.sf(threshold / noise_multiplier).item() + + num_total_indices = tf.cast(contribution_counts.dense_shape[0], tf.int32) + num_non_zero_indices = tf.shape(contribution_counts.values)[0] + max_index = tf.cast(num_total_indices - num_non_zero_indices - 1, tf.int32) + false_positive_indices = sample_false_positive_indices(max_index, prob) + remapped_false_positive_indices = _remap_indices( + false_positive_indices, tf.reshape(contribution_counts.indices, (-1,)) + ) + merged_indices = tf.sort( + tf.concat( + [remapped_false_positive_indices, true_positive_indices], axis=0 + ) + ) + return merged_indices + + +def add_sparse_gradient_noise( + grad: tf.IndexedSlices, indices: tf.Tensor, noise_stddev: float +) -> tf.IndexedSlices: + """Adds sparse gradient noise. + + Args: + grad: A sparse gradient of type `tf.IndexedSlices`. + indices: The selected indices to keep. + noise_stddev: The standard deviation of the noise to add. + + Returns: + A sparse gradient of type `tf.IndexedSlices` with the noise added. + """ + filtered_grad_values = tf.gather(grad, indices) + sparse_noise_values = tf.random.normal( + filtered_grad_values.shape, mean=0.0, stddev=noise_stddev + ) + filtered_noised_grad_values = filtered_grad_values + sparse_noise_values + return tf.IndexedSlices( + indices=indices, + values=filtered_noised_grad_values, + dense_shape=grad.dense_shape, + ) + + +def get_contribution_counts( + trainable_vars: list[tf.Variable], + grads: list[tf.Tensor], + varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor], +) -> list[tf.Tensor | None]: + """Gets the contribution counts for each variable in the Model. + + Args: + trainable_vars: A list of the trainable variables in the Model. + grads: A corresponding list of gradients for each trainable variable. + varname_to_contribution_counts_fns: A mapping from variable name to a list + of functions to get the contribution counts for that variable. + + Returns: + A list of contribution counts for each variable and None for variables that + do not have contribution counts function. + + Raises: + NotImplementedError: If there are more than one contribution counts function + for a variable. + """ + contribution_counts_list = [] + for var, grad in zip(trainable_vars, grads): + if var.name not in varname_to_contribution_counts_fns: + contribution_counts_list.append(None) + continue + contribution_counts_fns = varname_to_contribution_counts_fns[var.name] + if not contribution_counts_fns or not contribution_counts_fns[0]: + contribution_counts_list.append(None) + continue + if len(contribution_counts_fns) > 1: + raise NotImplementedError( + 'Sparse noise is not supported for shared weight variables.' + ) + contribution_counts_fn = contribution_counts_fns[0] + contribution_counts = contribution_counts_fn(grad) + contribution_counts_list.append(contribution_counts) + + return contribution_counts_list + + +def add_sparse_noise( + grad: tf.IndexedSlices, + contribution_counts: tf.SparseTensor, + noise_multiplier: float, + noise_multiplier_sparse: float, + l2_norm_clip: float, + threshold: int, +) -> tf.IndexedSlices: + """Adds sparse noise to a gradient. + + Args: + grad: A sparse gradient of type `tf.IndexedSlices`. + contribution_counts: The contribution counts for each index of grad. + noise_multiplier: The noise multiplier to use for the gradient noise. + noise_multiplier_sparse: The noise multiplier to use for the partition + selection. + l2_norm_clip: The l2 norm clip at which the gradient is clipped. + threshold: The threshold to use for the partition selection. + + Returns: + A sparse gradient of type `tf.IndexedSlices` with the noise added. + """ + privately_selected_indices = sparse_private_partition_selection( + contribution_counts, noise_multiplier_sparse, threshold + ) + noised_grad = add_sparse_gradient_noise( + grad, privately_selected_indices, noise_multiplier * l2_norm_clip + ) + return noised_grad diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py index 5941e74..38e11b2 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py @@ -284,6 +284,158 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(batch_size, 0) self.assertLess(batch_size, max_index + 1) + def test_sparse_private_partition_selection_without_noise(self): + contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[2.0, 1.0, 1.0], + dense_shape=[8], + ) + noise_multiplier = 0.0 + threshold = 2 + sampled_indices = ( + sparse_noise_utils.sparse_private_partition_selection( + contribution_counts, noise_multiplier, threshold + ) + .numpy() + .tolist() + ) + expected_indices = [0] + self.assertEqual(sampled_indices, expected_indices) + + def test_sparse_private_partition_selection_with_noise(self): + contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[50.0, 1.0, 1.0], + dense_shape=[1000], + ) + noise_multiplier = 1.0 + threshold = 1 + sampled_indices = ( + sparse_noise_utils.sparse_private_partition_selection( + contribution_counts, noise_multiplier, threshold + ) + .numpy() + .tolist() + ) + expected_indices = [0] + self.assertContainsSubset(expected_indices, sampled_indices) + self.assertGreater(len(sampled_indices), 1) + + def test_remap_indices(self): + expected_indices = [4, 9, 14] + indices = tf.constant([1, 5, 10], tf.int64) + skip_indices = tf.constant([0, 1, 2, 5], tf.int64) + remapped_indices = sparse_noise_utils._remap_indices(indices, skip_indices) + self.assertEqual(remapped_indices.numpy().tolist(), expected_indices) + + def test_remap_indices_no_skip(self): + expected_indices = [1, 5, 10] + indices = tf.constant([1, 5, 10], tf.int64) + skip_indices = tf.constant([], tf.int64) + remapped_indices = sparse_noise_utils._remap_indices(indices, skip_indices) + self.assertEqual(remapped_indices.numpy().tolist(), expected_indices) + + def test_add_sparse_gradient_noise(self): + grad = tf.IndexedSlices( + values=tf.ones((1, 2)), + indices=tf.constant([0]), + dense_shape=tf.constant([2, 2]), + ) + indices = tf.constant([1], dtype=tf.int64) + noise_stddev = 1.0 + noised_grad = sparse_noise_utils.add_sparse_gradient_noise( + grad, indices, noise_stddev + ) + self.assertListEqual( + noised_grad.indices.numpy().tolist(), indices.numpy().tolist() + ) + one_index_values = noised_grad.values[0].numpy().tolist() + self.assertNotEqual(one_index_values, [0.0, 0.0]) + + def test_get_contribution_counts(self): + trainable_vars = [ + tf.Variable(tf.ones((1, 2)), name='var1'), + tf.Variable(tf.ones((1, 2)), name='var2'), + tf.Variable(tf.ones((1, 2)), name='var3'), + ] + grads = [ + tf.IndexedSlices( + values=tf.ones((1, 2)), + indices=tf.constant([0]), + dense_shape=tf.constant([2, 2]), + ), + tf.ones((1, 2)), + tf.ones((1, 2)), + ] + varname_to_contribution_counts_fns = { + 'var1:0': [lambda grad: 1.0], + 'var2:0': None, + } + contribution_counts = sparse_noise_utils.get_contribution_counts( + trainable_vars, grads, varname_to_contribution_counts_fns + ) + expected_contribution_counts = [1.0, None, None] + self.assertEqual(contribution_counts, expected_contribution_counts) + + def test_add_sparse_noise_without_noise(self): + grad = tf.IndexedSlices( + values=tf.ones((3, 4)), + indices=tf.constant([0, 3, 5]), + dense_shape=tf.constant([8, 4]), + ) + contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[3.0, 1.0, 2.0], + dense_shape=[8], + ) + noised_grad = sparse_noise_utils.add_sparse_noise( + grad, + contribution_counts, + noise_multiplier=0.0, + noise_multiplier_sparse=0.0, + l2_norm_clip=1.0, + threshold=1, + ) + self.assertEqual( + noised_grad.indices.numpy().tolist(), grad.indices.numpy().tolist() + ) + self.assertEqual( + noised_grad.values.numpy().tolist(), grad.values.numpy().tolist() + ) + + def test_add_sparse_noise_with_noise(self): + grad = tf.IndexedSlices( + values=tf.ones((3, 4)), + indices=tf.constant([0, 3, 5]), + dense_shape=tf.constant([8, 4]), + ) + contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[10.0, 10.0, 20.0], + dense_shape=[8], + ) + noised_grad = sparse_noise_utils.add_sparse_noise( + grad, + contribution_counts, + noise_multiplier=1.0, + noise_multiplier_sparse=1.0, + l2_norm_clip=1.0, + threshold=5, + ) + self.assertContainsSubset( + grad.indices.numpy().tolist(), + noised_grad.indices.numpy().tolist(), + ) + noised_grad_dense = tf.scatter_nd( + tf.reshape(noised_grad.indices, (-1, 1)), + noised_grad.values, + shape=(8, 4), + ).numpy() + noised_grad_valid_indices = noised_grad_dense[grad.indices.numpy()] + self.assertTrue( + np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy())) + ) + if __name__ == '__main__': tf.test.main()