Sparsity Preserving DP-SGD in TF Privacy

Add function to merge varname_to_contribution_count_fn maps from different layers.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 664906202
Author: A. Unique TensorFlower
Date: 2024-08-19 11:43:50 -07:00
Commit: 93c7e54327 (parent 38d80cae92)

4 changed files with 153 additions and 7 deletions
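
At a high level, the new helper folds the per-layer varname_to_count_contribution_fn maps into a single map, rejecting duplicate variable names. A minimal standalone sketch of that merge rule (illustrative only; merge_fn_maps and its inputs are hypothetical stand-ins, not part of this commit):

# Illustrative stand-in for the merge rule; not part of this commit.
def merge_fn_maps(maps):
  merged = {}
  for fn_map in maps:
    duplicates = set(fn_map) & set(merged)
    if duplicates:
      # The real helper below raises ValueError on duplicate varnames too.
      raise ValueError(f'Duplicate varnames: {duplicates}')
    merged.update(fn_map)
  return merged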

tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD

@@ -5,12 +5,19 @@ licenses(["notice"])
 py_library(
     name = "sparse_noise_utils",
     srcs = ["sparse_noise_utils.py"],
+    deps = [
+        ":type_aliases",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+    ],
 )

 py_test(
     name = "sparse_noise_utils_test",
     srcs = ["sparse_noise_utils_test.py"],
-    deps = [":sparse_noise_utils"],
+    deps = [
+        ":sparse_noise_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+    ],
 )

 py_library(

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py

@@ -16,10 +16,13 @@
 For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
 """

+import collections
 from typing import Mapping, Optional, Sequence

 from scipy import stats
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
 import tensorflow_probability as tfp
@@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
   )


+def extract_varname_to_contribution_counts_fns(
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    trainable_vars: Sequence[tf.Variable],
+) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
+  """Extracts a map of contribution count fns from generator outputs.
+
+  Args:
+    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
+      instances returned by
+      `gradient_clipping_utils.model_forward_backward_pass`.
+    trainable_vars: A list of trainable variables.
+
+  Returns:
+    A `dict` from varname to contribution counts functions.
+  """
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+
+  varname_to_contribution_counts_fns = collections.defaultdict(list)
+  for registry_fn_output in registry_fn_outputs_list:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      if registry_fn_output.varname_to_count_contribution_fn is not None:
+        duplicate_varnames = set(
+            registry_fn_output.varname_to_count_contribution_fn.keys()
+        ) & set(varname_to_contribution_counts_fns.keys())
+        if duplicate_varnames:
+          raise ValueError(
+              f'Duplicate varnames: {duplicate_varnames} found in'
+              ' contribution counts functions.'
+          )
+        varname_to_contribution_counts_fns.update(
+            registry_fn_output.varname_to_count_contribution_fn
+        )
+  return varname_to_contribution_counts_fns
+
+
 def get_contribution_counts(
-    trainable_vars: list[tf.Variable],
-    grads: list[tf.Tensor],
-    varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
-) -> list[tf.Tensor | None]:
+    trainable_vars: Sequence[tf.Variable],
+    grads: Sequence[tf.Tensor],
+    varname_to_contribution_counts_fns: Mapping[
+        str, type_aliases.ContributionCountHistogramFn
+    ],
+) -> Sequence[type_aliases.ContributionCountHistogram | None]:
   """Gets the contribution counts for each variable in the Model.

   Args:
-    trainable_vars: A list of the trainable variables in the Model.
+    trainable_vars: A list of trainable variables.
     grads: A corresponding list of gradients for each trainable variable.
     varname_to_contribution_counts_fns: A mapping from variable name to a list
       of functions to get the contribution counts for that variable.
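
For orientation, a hedged usage sketch of the new helper. The `RegistryGeneratorFunctionOutput` fields mirror the test below; the variable name, layer id, and toy count fn are invented for illustration, and in real training the outputs list would come from `gradient_clipping_utils.model_forward_backward_pass` as the docstring notes.

# Usage sketch under the API shown in this diff; the 'embedding' names and
# the toy count fn are illustrative, not from the commit.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils

var = tf.Variable(tf.ones((4, 2)), name='embedding')

def count_fn(sparse_grad):
  # Toy per-row contribution histogram for a sparse (IndexedSlices) gradient.
  counts = tf.math.bincount(
      tf.cast(sparse_grad.indices, tf.int32), minlength=4, dtype=tf.float32)
  return tf.sparse.from_dense(counts)

output = gradient_clipping_utils.RegistryGeneratorFunctionOutput(
    layer_id='embedding_layer',
    layer_vars=[var],
    layer_sqr_norm_fn=None,
    layer_trainable_weights=[var],
    varname_to_count_contribution_fn={'embedding:0': [count_fn]},
)

fns = sparse_noise_utils.extract_varname_to_contribution_counts_fns(
    [output], trainable_vars=[var])
assert 'embedding:0' in fns  # merged map is keyed by variable name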

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py

@@ -17,6 +17,7 @@ from absl.testing import parameterized
 import numpy as np
 from scipy import stats
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
@@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
     )

+  def test_extract_varname_to_contribution_counts_fns(self):
+    def fn1(_):
+      return 1.0
+
+    def fn2(_):
+      return 2.0
+
+    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
+    var2 = tf.Variable(tf.ones((1, 2)), name='var2')
+    var3 = tf.Variable(tf.ones((1, 2)), name='var3')
+
+    registry_fn_outputs_list = [
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer1',
+            layer_vars=[var1],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var1],
+            varname_to_count_contribution_fn=None,
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer2',
+            layer_vars=[var2],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var2],
+            varname_to_count_contribution_fn={
+                'var2:0': [fn2],
+            },
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer3',
+            layer_vars=[var3],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var3],
+            varname_to_count_contribution_fn={
+                'var3:0': [fn1],
+            },
+        ),
+    ]
+    expected_varname_to_contribution_counts_fns = {
+        'var2:0': [fn2],
+        'var3:0': [fn1],
+    }
+    varname_to_contribution_counts_fns = (
+        sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+            registry_fn_outputs_list,
+            trainable_vars=None,
+        )
+    )
+    self.assertEqual(
+        varname_to_contribution_counts_fns,
+        expected_varname_to_contribution_counts_fns,
+    )
+
+  def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
+    def fn1(_):
+      return 1.0
+
+    def fn2(_):
+      return 2.0
+
+    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
+    var2 = tf.Variable(tf.ones((1, 2)), name='var1')
+
+    registry_fn_outputs_list = [
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer1',
+            layer_vars=[var1],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var1],
+            varname_to_count_contribution_fn={
+                'var1:0': [fn1],
+            },
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer2',
+            layer_vars=[var2],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var2],
+            varname_to_count_contribution_fn={
+                'var1:0': [fn2],
+            },
+        ),
+    ]
+
+    with self.assertRaises(ValueError):
+      sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+          registry_fn_outputs_list,
+          trainable_vars=None,
+      )
+

 if __name__ == '__main__':
   tf.test.main()

tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py

@@ -22,7 +22,7 @@ InputKwargs = Mapping[str, Any]
 SparseGradient = tf.IndexedSlices
 ContributionCountHistogram = tf.SparseTensor
 ContributionCountHistogramFn = Callable[
-    [SparseGradient], Mapping[str, ContributionCountHistogram]
+    [SparseGradient], ContributionCountHistogram
 ]
 NumMicrobatches = int | tf.Tensor
 SparsityPreservingNoiseLayerRegistryFunction = Callable[
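
After this change, a ContributionCountHistogramFn maps a single SparseGradient directly to a ContributionCountHistogram (a tf.SparseTensor), rather than to a Mapping[str, ...]. An illustrative function matching the updated alias (an eager-mode sketch, not from this commit):

import tensorflow as tf

def row_count_histogram(grad: tf.IndexedSlices) -> tf.SparseTensor:
  # Count how many gradient slices touch each row of the variable;
  # illustrative only.
  counts = tf.math.bincount(
      tf.cast(grad.indices, tf.int32),
      minlength=int(grad.dense_shape[0]),
      dtype=tf.float32,
  )
  return tf.sparse.from_dense(counts)

g = tf.IndexedSlices(
    values=tf.ones((2, 3)), indices=tf.constant([0, 2]),
    dense_shape=tf.constant([4, 3]))
hist = row_count_histogram(g)  # SparseTensor representing counts [1, 0, 1, 0]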