Sparsity Preserving DP-SGD in TF Privacy [2 of 5]

Adds sparse noise utilities to privately select sparse indices from contribution counts.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 654782588
This commit is contained in:
A. Unique TensorFlower 2024-07-22 09:25:02 -07:00
parent 348895a7a3
commit 8747858b5b
3 changed files with 471 additions and 0 deletions

View file

@ -2,6 +2,17 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"])
py_library(
name = "sparse_noise_utils",
srcs = ["sparse_noise_utils.py"],
)
py_test(
name = "sparse_noise_utils_test",
srcs = ["sparse_noise_utils_test.py"],
deps = [":sparse_noise_utils"],
)
py_library(
name = "type_aliases",
srcs = ["type_aliases.py"],

View file

@ -0,0 +1,171 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for adding sparse noise to gradients.
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""
from typing import Optional, Sequence
import tensorflow as tf
import tensorflow_probability as tfp
def split_noise_multiplier(
noise_multiplier: float,
sparse_selection_ratio: float,
sparse_selection_contribution_counts: Sequence[Optional[tf.SparseTensor]],
) -> tuple[float, float]:
"""Splits noise multiplier between partition selection and gradient noise.
Returns one noise multiplier for gradient noise and one noise multiplier
for each sparse partition selection layer such that composing all gaussian
mechanisms with these noise multipliers is equivalent to applying a single
gaussian mechanism with the original noise multiplier.
Args:
noise_multiplier: The original noise multiplier.
sparse_selection_ratio: The ratio of partition selection noise and gradient
noise.
sparse_selection_contribution_counts: The contribution counts for each
sparse selection variable. If a sparse selection count is None, it will be
ignored.
Returns:
A tuple of noise multipliers for sparse selection and gradient noise.
Raises:
ValueError: If the sparse selection ratio is not between 0 and 1, if the
sparse selection contribution counts is None, or if there are no sparse
selection contribution counts.
"""
if sparse_selection_ratio <= 0.0 or sparse_selection_ratio >= 1.0:
raise ValueError('Sparse selection ratio must be between 0 and 1.')
num_sparse_selections = sum(
1 for c in sparse_selection_contribution_counts if c is not None
)
if num_sparse_selections == 0:
raise ValueError('No sparse selections contribution counts found.')
ratio = (1.0 + sparse_selection_ratio**2.0) ** 0.5
total_noise_multiplier_sparse = noise_multiplier * ratio
noise_multiplier_partition_selection = (
num_sparse_selections**0.5 * total_noise_multiplier_sparse
)
noise_multiplier_gradient_noise = (
noise_multiplier * ratio / sparse_selection_ratio
)
return noise_multiplier_partition_selection, noise_multiplier_gradient_noise
def _sample_sparse_indices_batch_size_heuristic(
max_index: tf.Tensor,
probability: float,
) -> tf.Tensor:
"""Returns a batch size using a rough heuristic to use for sampling.
This heuristic should roughly allow for the sampling to only use a single
batch to sample all indices >95% of the time.
Args:
max_index: The maximum index to sample.
probability: The probability of sampling each index.
Returns:
The batch size to use for sampling.
"""
max_num_samples = tf.cast(max_index + 1, tf.float32)
expected_num_samples = max_num_samples * probability
# For expected samples > 50, choosing a batch size of 1.2 * expected samples
# will allow for sampling only once to get all indices >95% of the time.
min_batch_size = 50.0
return tf.cast(
tf.maximum(min_batch_size, 1.2 * expected_num_samples), tf.int32
)
@tf.function
def sample_false_positive_indices(
max_index: tf.Tensor, probability: float, batch_size: Optional[int] = None
) -> tf.Tensor:
"""Samples indices with probability `probability` iid sparsely.
This function generates a list of indices in the range of [0, max_index]
where each index is sampled with probability `probability` independently. To
achieve this efficiently, we use the geometric distribution to sample a batch
of indices at a time and repeat this process until all indices are sampled.
Args:
max_index: The maximum index to sample.
probability: The probability of sampling each index.
batch_size: The batch size to use for sampling. If None, a heuristic will be
used to determine the batch size.
Returns:
A tensor of sampled indices.
"""
if probability <= 0.0:
return tf.constant([], dtype=tf.int64)
sampled_indices = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
batch_size = batch_size or _sample_sparse_indices_batch_size_heuristic(
max_index, probability
)
geom = tfp.distributions.geometric.Geometric(probs=probability)
i, current_max = tf.constant(0), tf.constant(-1)
while current_max < max_index:
sample = tf.cast(geom.sample(batch_size) + 1, tf.int32)
indices = current_max + tf.cumsum(sample)
current_max = indices[-1]
sampled_indices = sampled_indices.write(i, indices)
i += 1
indices = tf.cast(sampled_indices.concat(), tf.int32)
indices = indices[indices <= max_index]
return tf.cast(indices, tf.int64)
def sample_true_positive_indices(
contribution_counts: tf.SparseTensor,
noise_multiplier: float,
threshold: int,
) -> tf.Tensor:
"""Samples indices where the count + Gaussian noise is above a threshold.
Args:
contribution_counts: The contribution counts for each index.
noise_multiplier: The noise multiplier to use for the gaussian noise.
threshold: The threshold to use for the selection.
Returns:
A tensor of sampled indices.
"""
contribution_count_values = tf.reshape(contribution_counts.values, (-1,))
noised_contribution_count_values = (
contribution_count_values
+ tf.random.normal(
tf.shape(contribution_count_values),
mean=0.0,
stddev=noise_multiplier,
dtype=tf.float32,
)
)
noised_contribution_counts_indices = contribution_counts.indices[
noised_contribution_count_values >= threshold
][:, 0]
return tf.reshape(noised_contribution_counts_indices, (-1,))

View file

@ -0,0 +1,289 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse_noise_utils."""
from absl.testing import parameterized
import numpy as np
from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='one_sparse_layer',
noise_multiplier=1.0,
sparse_selection_ratio=0.8,
sparse_selection_contribution_counts=[
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
)
],
),
dict(
testcase_name='multiple_sparse_layer',
noise_multiplier=1.0,
sparse_selection_ratio=0.1,
sparse_selection_contribution_counts=[
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
),
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
),
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
),
],
),
)
def test_split_noise_multiplier(
self,
noise_multiplier,
sparse_selection_ratio,
sparse_selection_contribution_counts,
):
noise_multiplier_sparse, noise_multiplier_dense = (
sparse_noise_utils.split_noise_multiplier(
noise_multiplier,
sparse_selection_ratio,
sparse_selection_contribution_counts,
)
)
num_sparse_layers = len(sparse_selection_contribution_counts)
total_noise_multiplier_sparse = (
noise_multiplier_sparse / num_sparse_layers**0.5
)
self.assertAlmostEqual(
total_noise_multiplier_sparse,
sparse_selection_ratio * noise_multiplier_dense,
)
total_noise_multiplier = (
1.0
/ (
1.0 / total_noise_multiplier_sparse**2
+ 1.0 / noise_multiplier_dense**2
)
** 0.5
)
self.assertAlmostEqual(total_noise_multiplier, noise_multiplier)
@parameterized.named_parameters(
dict(
testcase_name='no_sparse_layers',
noise_multiplier=1.0,
sparse_selection_ratio=0.5,
sparse_selection_contribution_counts=[],
error_message='No sparse selections contribution counts found.',
),
dict(
testcase_name='sparse_layers_none',
noise_multiplier=1.0,
sparse_selection_ratio=0.5,
sparse_selection_contribution_counts=[None],
error_message='No sparse selections contribution counts found.',
),
dict(
testcase_name='zero_ratio',
noise_multiplier=1.0,
sparse_selection_ratio=0.0,
sparse_selection_contribution_counts=[
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
)
],
error_message='Sparse selection ratio must be between 0 and 1.',
),
dict(
testcase_name='one_ratio',
noise_multiplier=1.0,
sparse_selection_ratio=1.0,
sparse_selection_contribution_counts=[
tf.SparseTensor(
indices=[[0]],
values=[1],
dense_shape=[3],
)
],
error_message='Sparse selection ratio must be between 0 and 1.',
),
)
def test_split_noise_multiplier_errors(
self,
noise_multiplier,
sparse_selection_ratio,
sparse_selection_contribution_counts,
error_message,
):
with self.assertRaisesRegex(ValueError, error_message):
sparse_noise_utils.split_noise_multiplier(
noise_multiplier,
sparse_selection_ratio,
sparse_selection_contribution_counts,
)
@parameterized.named_parameters(
dict(
testcase_name='max_index_0',
max_index=0,
),
dict(
testcase_name='max_index_10',
max_index=10,
),
)
def test_sample_false_positive_indices_one_prob(self, max_index):
sampled_indices = (
sparse_noise_utils.sample_false_positive_indices(max_index, 1.0)
.numpy()
.tolist()
)
expected_indices = list(range(max_index + 1))
self.assertEqual(sampled_indices, expected_indices)
@parameterized.named_parameters(
dict(
testcase_name='max_index_0',
max_index=0,
),
dict(
testcase_name='max_index_10',
max_index=10,
),
)
def test_sample_false_positive_indices_zero_prob(self, max_index):
sampled_indices = (
sparse_noise_utils.sample_false_positive_indices(max_index, 0.0)
.numpy()
.tolist()
)
self.assertEmpty(sampled_indices)
@parameterized.named_parameters(
dict(
testcase_name='max_index_10_prob_50',
prob=0.5,
max_index=10,
),
dict(
testcase_name='max_index_20_prob_25',
prob=0.25,
max_index=20,
),
dict(
testcase_name='max_index_20_prob_75',
prob=0.75,
max_index=20,
),
)
def test_sample_false_positive_indices_random(self, max_index, prob):
sampled_indices = sparse_noise_utils.sample_false_positive_indices(
max_index, prob
)
sampled_indices = sampled_indices.numpy()
self.assertLessEqual(np.max(sampled_indices), max_index)
self.assertGreaterEqual(np.min(sampled_indices), 0)
self.assertGreater(
stats.binomtest(k=len(sampled_indices), n=max_index, p=prob).pvalue,
1e-10,
)
bins = np.arange(max_index + 1) + 1
histogram, _ = np.histogram(sampled_indices, bins=bins)
num_trials = 10000
for _ in range(num_trials):
sampled_indices = sparse_noise_utils.sample_false_positive_indices(
max_index, prob
).numpy()
histogram += np.histogram(sampled_indices, bins=bins)[0]
min_pvalue = min(
stats.binomtest(k=h.item(), n=num_trials, p=prob).pvalue
for h in histogram
)
self.assertGreater(min_pvalue, 1e-10)
def test_sample_true_positive_indices_empty(self):
contribution_counts = tf.SparseTensor(
indices=np.zeros((0, 1), dtype=np.int64),
values=[],
dense_shape=[8],
)
noise_multiplier = 10.0
threshold = 2
sampled_indices = sparse_noise_utils.sample_true_positive_indices(
contribution_counts, noise_multiplier, threshold
)
sampled_indices = list(sampled_indices.numpy())
expected_indices = []
self.assertEqual(sampled_indices, expected_indices)
def test_sample_true_positive_indices_without_noise(self):
contribution_counts = tf.SparseTensor(
indices=[[0], [3], [5], [7]],
values=[3.0, 1.0, 1.0, 2.0],
dense_shape=[8],
)
noise_multiplier = 0.0
threshold = 2
sampled_indices = sparse_noise_utils.sample_true_positive_indices(
contribution_counts, noise_multiplier, threshold
)
sampled_indices = list(sampled_indices.numpy())
expected_indices = [0, 7]
self.assertEqual(sampled_indices, expected_indices)
def test_sample_true_positive_indices_with_noise(self):
contribution_counts = tf.SparseTensor(
indices=[[0], [3], [5], [7]],
values=[30.0, 1.0, 1.0, 20.0],
dense_shape=[8],
)
noise_multiplier = 1.0
threshold = 10
sampled_indices = sparse_noise_utils.sample_true_positive_indices(
contribution_counts, noise_multiplier, threshold
)
sampled_indices = list(sampled_indices.numpy())
expected_indices = [0, 7]
self.assertEqual(sampled_indices, expected_indices)
def test_batch_size_heuristic(self):
max_index = 100
prob = 0.5
batch_size = sparse_noise_utils._sample_sparse_indices_batch_size_heuristic(
max_index, prob
)
self.assertGreater(batch_size, 0)
self.assertLess(batch_size, max_index + 1)
if __name__ == '__main__':
tf.test.main()