Sparsity Preserving DP-SGD in TF Privacy [2 of 5]
Adds sparse noise utilities to privately select sparse indices from contribution counts. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 654782588
parent 348895a7a3
commit 8747858b5b
3 changed files with 471 additions and 0 deletions
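
A minimal sketch of how these utilities are meant to fit together (the glue below is illustrative only, not part of this commit; the false-positive probability is a placeholder that the full algorithm calibrates to the selection noise and threshold):

import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils

# One contribution-count SparseTensor per sparse (e.g. embedding) variable.
contribution_counts = tf.SparseTensor(
    indices=[[0], [3], [7]], values=[30.0, 1.0, 20.0], dense_shape=[8]
)

# Split the DP-SGD noise multiplier once, up front, between partition
# selection and gradient noise.
nm_selection, nm_gradient = sparse_noise_utils.split_noise_multiplier(
    noise_multiplier=1.0,
    sparse_selection_ratio=0.5,
    sparse_selection_contribution_counts=[contribution_counts],
)

# Per step: keep indices whose noised contribution count clears a threshold...
true_pos = sparse_noise_utils.sample_true_positive_indices(
    contribution_counts, nm_selection, threshold=10
)
# ...and add indices admitted purely at random, so that the selected support
# itself is differentially private.
false_pos = sparse_noise_utils.sample_false_positive_indices(
    max_index=7, probability=0.01  # placeholder probability
)
selected = tf.unique(tf.concat([true_pos, false_pos], axis=0)).y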

tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD
@@ -2,6 +2,17 @@ package(default_visibility = ["//visibility:public"])
 
 licenses(["notice"])
 
+py_library(
+    name = "sparse_noise_utils",
+    srcs = ["sparse_noise_utils.py"],
+)
+
+py_test(
+    name = "sparse_noise_utils_test",
+    srcs = ["sparse_noise_utils_test.py"],
+    deps = [":sparse_noise_utils"],
+)
+
 py_library(
     name = "type_aliases",
     srcs = ["type_aliases.py"],

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py (new file)
@@ -0,0 +1,171 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for adding sparse noise to gradients.

For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""

from typing import Optional, Sequence

import tensorflow as tf
import tensorflow_probability as tfp

def split_noise_multiplier(
    noise_multiplier: float,
    sparse_selection_ratio: float,
    sparse_selection_contribution_counts: Sequence[Optional[tf.SparseTensor]],
) -> tuple[float, float]:
  """Splits noise multiplier between partition selection and gradient noise.

  Returns one noise multiplier for gradient noise and one noise multiplier
  for each sparse partition selection layer such that composing all Gaussian
  mechanisms with these noise multipliers is equivalent to applying a single
  Gaussian mechanism with the original noise multiplier.

  Args:
    noise_multiplier: The original noise multiplier.
    sparse_selection_ratio: The ratio of partition selection noise to gradient
      noise.
    sparse_selection_contribution_counts: The contribution counts for each
      sparse selection variable. If a sparse selection count is None, it will
      be ignored.

  Returns:
    A tuple of noise multipliers for sparse selection and gradient noise.

  Raises:
    ValueError: If the sparse selection ratio is not strictly between 0 and 1,
      or if there are no non-None sparse selection contribution counts.
  """
  if sparse_selection_ratio <= 0.0 or sparse_selection_ratio >= 1.0:
    raise ValueError('Sparse selection ratio must be between 0 and 1.')
  num_sparse_selections = sum(
      1 for c in sparse_selection_contribution_counts if c is not None
  )
  if num_sparse_selections == 0:
    raise ValueError('No sparse selections contribution counts found.')

  ratio = (1.0 + sparse_selection_ratio**2.0) ** 0.5
  total_noise_multiplier_sparse = noise_multiplier * ratio
  noise_multiplier_partition_selection = (
      num_sparse_selections**0.5 * total_noise_multiplier_sparse
  )
  noise_multiplier_gradient_noise = (
      noise_multiplier * ratio / sparse_selection_ratio
  )

  return noise_multiplier_partition_selection, noise_multiplier_gradient_noise
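
# Why this split composes back to the original noise multiplier (this is the
# identity that test_split_noise_multiplier verifies). Write
# r = sparse_selection_ratio, s = noise_multiplier, n = num_sparse_selections.
# The code above sets
#
#   sigma_selection = sqrt(n) * s * sqrt(1 + r^2)   (shared by the n layers)
#   sigma_gradient  = s * sqrt(1 + r^2) / r
#
# Composing the n per-layer selection mechanisms is equivalent to a single
# Gaussian mechanism with sigma_s = sigma_selection / sqrt(n)
# = s * sqrt(1 + r^2) = r * sigma_gradient, i.e. the requested ratio holds.
# Composing that with the gradient mechanism gives
#
#   1 / sqrt(1 / sigma_s^2 + 1 / sigma_gradient^2)
#     = 1 / sqrt((1 + r^2) / (s^2 * (1 + r^2))) = s,
#
# recovering the original multiplier.
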
def _sample_sparse_indices_batch_size_heuristic(
    max_index: tf.Tensor,
    probability: float,
) -> tf.Tensor:
  """Returns a batch size for sampling, chosen by a rough heuristic.

  The heuristic is chosen so that a single batch suffices to sample all
  indices >95% of the time.

  Args:
    max_index: The maximum index to sample.
    probability: The probability of sampling each index.

  Returns:
    The batch size to use for sampling.
  """
  max_num_samples = tf.cast(max_index + 1, tf.float32)
  expected_num_samples = max_num_samples * probability
  # For expected samples > 50, choosing a batch size of 1.2 * expected samples
  # will allow for sampling only once to get all indices >95% of the time.
  min_batch_size = 50.0
  return tf.cast(
      tf.maximum(min_batch_size, 1.2 * expected_num_samples), tf.int32
  )
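
# A quick check of the >95% claim above (illustrative numbers, not part of the
# commit): the sampler below covers every index in a single batch exactly when
# batch_size geometric gaps sum past max_index, a negative-binomial tail
# event. For max_index = 999 and probability = 0.1 the heuristic picks
# batch_size = int(1.2 * 100) = 120, and
#
#   P(one batch suffices) = 1 - scipy.stats.nbinom.cdf(999 - 120, 120, 0.1)
#                         ~= 0.97,
#
# comfortably above 95%.
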
@tf.function
def sample_false_positive_indices(
    max_index: tf.Tensor, probability: float, batch_size: Optional[int] = None
) -> tf.Tensor:
  """Samples each index independently with probability `probability`.

  This function generates a list of indices in the range [0, max_index] where
  each index is sampled with probability `probability` independently. To
  achieve this efficiently, it uses the geometric distribution to sample a
  batch of inter-index gaps at a time, repeating until all indices have been
  covered.

  Args:
    max_index: The maximum index to sample.
    probability: The probability of sampling each index.
    batch_size: The batch size to use for sampling. If None, a heuristic will
      be used to determine the batch size.

  Returns:
    A tensor of sampled indices.
  """
  if probability <= 0.0:
    return tf.constant([], dtype=tf.int64)

  sampled_indices = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

  batch_size = batch_size or _sample_sparse_indices_batch_size_heuristic(
      max_index, probability
  )

  geom = tfp.distributions.geometric.Geometric(probs=probability)

  i, current_max = tf.constant(0), tf.constant(-1)
  while current_max < max_index:
    # Geometric(probability) counts failures before the first success, so
    # adding 1 turns each draw into the gap to the next sampled index.
    sample = tf.cast(geom.sample(batch_size) + 1, tf.int32)
    indices = current_max + tf.cumsum(sample)
    current_max = indices[-1]
    sampled_indices = sampled_indices.write(i, indices)
    i += 1

  indices = tf.cast(sampled_indices.concat(), tf.int32)
  # The last batch may overshoot; drop indices past max_index.
  indices = indices[indices <= max_index]
  return tf.cast(indices, tf.int64)
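
# Note on the gap construction above: if each index is hit independently with
# probability p, the gap between successive hits is Geometric(p) on
# {1, 2, ...}, so cumulative sums of such gaps reproduce the iid Bernoulli
# process without materializing max_index + 1 coin flips. An equivalent NumPy
# sketch (illustrative only; NumPy's geometric already starts at 1, so no
# "+ 1" is needed there):
#
#   gaps = rng.geometric(p, size=k)
#   idx = np.cumsum(gaps) - 1          # shift so the first hit can be index 0
#   idx = idx[idx <= max_index]
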
def sample_true_positive_indices(
    contribution_counts: tf.SparseTensor,
    noise_multiplier: float,
    threshold: int,
) -> tf.Tensor:
  """Samples indices where the count + Gaussian noise is above a threshold.

  Args:
    contribution_counts: The contribution counts for each index.
    noise_multiplier: The noise multiplier to use for the Gaussian noise.
    threshold: The threshold to use for the selection.

  Returns:
    A tensor of sampled indices.
  """
  contribution_count_values = tf.reshape(contribution_counts.values, (-1,))
  noised_contribution_count_values = (
      contribution_count_values
      + tf.random.normal(
          tf.shape(contribution_count_values),
          mean=0.0,
          stddev=noise_multiplier,
          dtype=tf.float32,
      )
  )
  # Keep the (first-dimension) indices whose noised count clears the
  # threshold.
  noised_contribution_counts_indices = contribution_counts.indices[
      noised_contribution_count_values >= threshold
  ][:, 0]
  return tf.reshape(noised_contribution_counts_indices, (-1,))
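
For intuition, a tiny eager-mode demo of the true-positive selection (the values mirror test_sample_true_positive_indices_without_noise below; with zero noise the selection is deterministic):

import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils

counts = tf.SparseTensor(
    indices=[[0], [3], [5], [7]],
    values=[3.0, 1.0, 1.0, 2.0],
    dense_shape=[8],
)
# With noise_multiplier=0.0, exactly the indices with count >= threshold
# survive.
print(sparse_noise_utils.sample_true_positive_indices(
    counts, noise_multiplier=0.0, threshold=2
))
# tf.Tensor([0 7], shape=(2,), dtype=int64)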

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py (new file)
@@ -0,0 +1,289 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse_noise_utils."""

from absl.testing import parameterized
import numpy as np
from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='one_sparse_layer',
          noise_multiplier=1.0,
          sparse_selection_ratio=0.8,
          sparse_selection_contribution_counts=[
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              )
          ],
      ),
      dict(
          testcase_name='multiple_sparse_layer',
          noise_multiplier=1.0,
          sparse_selection_ratio=0.1,
          sparse_selection_contribution_counts=[
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              ),
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              ),
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              ),
          ],
      ),
  )
  def test_split_noise_multiplier(
      self,
      noise_multiplier,
      sparse_selection_ratio,
      sparse_selection_contribution_counts,
  ):
    noise_multiplier_sparse, noise_multiplier_dense = (
        sparse_noise_utils.split_noise_multiplier(
            noise_multiplier,
            sparse_selection_ratio,
            sparse_selection_contribution_counts,
        )
    )
    num_sparse_layers = len(sparse_selection_contribution_counts)

    # Undo the sqrt(num_layers) scaling to get the effective selection noise.
    total_noise_multiplier_sparse = (
        noise_multiplier_sparse / num_sparse_layers**0.5
    )
    self.assertAlmostEqual(
        total_noise_multiplier_sparse,
        sparse_selection_ratio * noise_multiplier_dense,
    )
    # Composing the two Gaussian mechanisms must recover the original
    # noise multiplier.
    total_noise_multiplier = (
        1.0
        / (
            1.0 / total_noise_multiplier_sparse**2
            + 1.0 / noise_multiplier_dense**2
        )
        ** 0.5
    )
    self.assertAlmostEqual(total_noise_multiplier, noise_multiplier)

  @parameterized.named_parameters(
      dict(
          testcase_name='no_sparse_layers',
          noise_multiplier=1.0,
          sparse_selection_ratio=0.5,
          sparse_selection_contribution_counts=[],
          error_message='No sparse selections contribution counts found.',
      ),
      dict(
          testcase_name='sparse_layers_none',
          noise_multiplier=1.0,
          sparse_selection_ratio=0.5,
          sparse_selection_contribution_counts=[None],
          error_message='No sparse selections contribution counts found.',
      ),
      dict(
          testcase_name='zero_ratio',
          noise_multiplier=1.0,
          sparse_selection_ratio=0.0,
          sparse_selection_contribution_counts=[
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              )
          ],
          error_message='Sparse selection ratio must be between 0 and 1.',
      ),
      dict(
          testcase_name='one_ratio',
          noise_multiplier=1.0,
          sparse_selection_ratio=1.0,
          sparse_selection_contribution_counts=[
              tf.SparseTensor(
                  indices=[[0]],
                  values=[1],
                  dense_shape=[3],
              )
          ],
          error_message='Sparse selection ratio must be between 0 and 1.',
      ),
  )
  def test_split_noise_multiplier_errors(
      self,
      noise_multiplier,
      sparse_selection_ratio,
      sparse_selection_contribution_counts,
      error_message,
  ):
    with self.assertRaisesRegex(ValueError, error_message):
      sparse_noise_utils.split_noise_multiplier(
          noise_multiplier,
          sparse_selection_ratio,
          sparse_selection_contribution_counts,
      )

  @parameterized.named_parameters(
      dict(
          testcase_name='max_index_0',
          max_index=0,
      ),
      dict(
          testcase_name='max_index_10',
          max_index=10,
      ),
  )
  def test_sample_false_positive_indices_one_prob(self, max_index):
    sampled_indices = (
        sparse_noise_utils.sample_false_positive_indices(max_index, 1.0)
        .numpy()
        .tolist()
    )
    expected_indices = list(range(max_index + 1))
    self.assertEqual(sampled_indices, expected_indices)

  @parameterized.named_parameters(
      dict(
          testcase_name='max_index_0',
          max_index=0,
      ),
      dict(
          testcase_name='max_index_10',
          max_index=10,
      ),
  )
  def test_sample_false_positive_indices_zero_prob(self, max_index):
    sampled_indices = (
        sparse_noise_utils.sample_false_positive_indices(max_index, 0.0)
        .numpy()
        .tolist()
    )
    self.assertEmpty(sampled_indices)

  @parameterized.named_parameters(
      dict(
          testcase_name='max_index_10_prob_50',
          prob=0.5,
          max_index=10,
      ),
      dict(
          testcase_name='max_index_20_prob_25',
          prob=0.25,
          max_index=20,
      ),
      dict(
          testcase_name='max_index_20_prob_75',
          prob=0.75,
          max_index=20,
      ),
  )
  def test_sample_false_positive_indices_random(self, max_index, prob):
    sampled_indices = sparse_noise_utils.sample_false_positive_indices(
        max_index, prob
    )
    sampled_indices = sampled_indices.numpy()

    self.assertLessEqual(np.max(sampled_indices), max_index)
    self.assertGreaterEqual(np.min(sampled_indices), 0)

    # The number of sampled indices is Binomial(max_index + 1, prob); reject
    # only wildly implausible draws to keep the test stable. Note the number
    # of trials is max_index + 1 (indices 0..max_index), so all indices being
    # sampled stays within the valid range of binomtest.
    self.assertGreater(
        stats.binomtest(
            k=len(sampled_indices), n=max_index + 1, p=prob
        ).pvalue,
        1e-10,
    )

    bins = np.arange(max_index + 1) + 1
    histogram, _ = np.histogram(sampled_indices, bins=bins)

    num_trials = 10000
    for _ in range(num_trials):
      sampled_indices = sparse_noise_utils.sample_false_positive_indices(
          max_index, prob
      ).numpy()
      histogram += np.histogram(sampled_indices, bins=bins)[0]

    # num_trials + 1 draws in total: the one above plus num_trials in the loop.
    min_pvalue = min(
        stats.binomtest(k=h.item(), n=num_trials + 1, p=prob).pvalue
        for h in histogram
    )
    self.assertGreater(min_pvalue, 1e-10)

  def test_sample_true_positive_indices_empty(self):
    contribution_counts = tf.SparseTensor(
        indices=np.zeros((0, 1), dtype=np.int64),
        values=[],
        dense_shape=[8],
    )
    noise_multiplier = 10.0
    threshold = 2
    sampled_indices = sparse_noise_utils.sample_true_positive_indices(
        contribution_counts, noise_multiplier, threshold
    )
    sampled_indices = list(sampled_indices.numpy())
    expected_indices = []
    self.assertEqual(sampled_indices, expected_indices)

  def test_sample_true_positive_indices_without_noise(self):
    contribution_counts = tf.SparseTensor(
        indices=[[0], [3], [5], [7]],
        values=[3.0, 1.0, 1.0, 2.0],
        dense_shape=[8],
    )
    noise_multiplier = 0.0
    threshold = 2
    sampled_indices = sparse_noise_utils.sample_true_positive_indices(
        contribution_counts, noise_multiplier, threshold
    )
    sampled_indices = list(sampled_indices.numpy())
    expected_indices = [0, 7]
    self.assertEqual(sampled_indices, expected_indices)

  def test_sample_true_positive_indices_with_noise(self):
    contribution_counts = tf.SparseTensor(
        indices=[[0], [3], [5], [7]],
        values=[30.0, 1.0, 1.0, 20.0],
        dense_shape=[8],
    )
    noise_multiplier = 1.0
    threshold = 10
    sampled_indices = sparse_noise_utils.sample_true_positive_indices(
        contribution_counts, noise_multiplier, threshold
    )
    sampled_indices = list(sampled_indices.numpy())
    expected_indices = [0, 7]
    self.assertEqual(sampled_indices, expected_indices)

  def test_batch_size_heuristic(self):
    max_index = 100
    prob = 0.5
    batch_size = sparse_noise_utils._sample_sparse_indices_batch_size_heuristic(
        max_index, prob
    )
    self.assertGreater(batch_size, 0)
    self.assertLess(batch_size, max_index + 1)


if __name__ == '__main__':
  tf.test.main()