Sparsity Preserving DP-SGD in TF Privacy

Add function to merge varname_to_contribution_count_fn maps from different layers.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 664906202
This commit is contained in:
A. Unique TensorFlower 2024-08-19 11:43:50 -07:00
parent 38d80cae92
commit 93c7e54327
4 changed files with 153 additions and 7 deletions

View file

@ -5,12 +5,19 @@ licenses(["notice"])
py_library(
name = "sparse_noise_utils",
srcs = ["sparse_noise_utils.py"],
deps = [
":type_aliases",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
],
)
py_test(
name = "sparse_noise_utils_test",
srcs = ["sparse_noise_utils_test.py"],
deps = [":sparse_noise_utils"],
deps = [
":sparse_noise_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
],
)
py_library(

View file

@ -16,10 +16,13 @@
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""
import collections
from typing import Mapping, Optional, Sequence
from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
import tensorflow_probability as tfp
@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
)
def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Optional[Sequence[tf.Variable]],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
  """Merges the per-layer contribution count fn maps into a single map.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: A list of trainable variables. When given, the map of a
      layer is merged only if at least one of the layer's trainable weights is
      in this list. When `None`, no layer is filtered out.

  Returns:
    A `dict` from varname to contribution counts functions.

  Raises:
    ValueError: If the same varname appears in the maps of more than one
      layer.
  """
  trainable_var_refs = None
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_var_refs = {v.ref() for v in trainable_vars}

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    # Skip layers that own none of the requested trainable variables.
    if trainable_var_refs is not None and not any(
        w.ref() in trainable_var_refs
        for w in registry_fn_output.layer_trainable_weights
    ):
      continue

    layer_fns = registry_fn_output.varname_to_count_contribution_fn
    if layer_fns is None:
      # This layer does not provide any contribution counts functions.
      continue

    # A varname may be claimed by a single layer only; silently overwriting an
    # earlier entry would drop that layer's contribution counts functions.
    duplicate_varnames = set(layer_fns) & set(
        varname_to_contribution_counts_fns
    )
    if duplicate_varnames:
      raise ValueError(
          f'Duplicate varnames: {sorted(duplicate_varnames)} found in'
          ' contribution counts functions.'
      )
    varname_to_contribution_counts_fns.update(layer_fns)
  return varname_to_contribution_counts_fns
def get_contribution_counts(
trainable_vars: list[tf.Variable],
grads: list[tf.Tensor],
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
) -> list[tf.Tensor | None]:
trainable_vars: Sequence[tf.Variable],
grads: Sequence[tf.Tensor],
varname_to_contribution_counts_fns: Mapping[
str, type_aliases.ContributionCountHistogramFn
],
) -> Sequence[type_aliases.ContributionCountHistogram | None]:
"""Gets the contribution counts for each variable in the Model.
Args:
trainable_vars: A list of the trainable variables in the Model.
trainable_vars: A list of trainable variables.
grads: A corresponding list of gradients for each trainable variable.
varname_to_contribution_counts_fns: A mapping from variable name to a list
of functions to get the contribution counts for that variable.

View file

@ -17,6 +17,7 @@ from absl.testing import parameterized
import numpy as np
from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
)
def test_extract_varname_to_contribution_counts_fns(self):
  """Checks that the maps of all layers are merged into a single map."""

  def count_fn_a(_):
    return 1.0

  def count_fn_b(_):
    return 2.0

  # One (layer id, variable name, contribution count fn map) entry per layer.
  # The first layer provides no map at all and must simply be ignored.
  layer_specs = (
      ('layer1', 'var1', None),
      ('layer2', 'var2', {'var2:0': [count_fn_b]}),
      ('layer3', 'var3', {'var3:0': [count_fn_a]}),
  )
  registry_fn_outputs_list = []
  for layer_id, var_name, fn_map in layer_specs:
    layer_var = tf.Variable(tf.ones((1, 2)), name=var_name)
    registry_fn_outputs_list.append(
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id=layer_id,
            layer_vars=[layer_var],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[layer_var],
            varname_to_count_contribution_fn=fn_map,
        )
    )

  actual = sparse_noise_utils.extract_varname_to_contribution_counts_fns(
      registry_fn_outputs_list,
      trainable_vars=None,
  )

  expected = {
      'var2:0': [count_fn_b],
      'var3:0': [count_fn_a],
  }
  self.assertEqual(actual, expected)
def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
  """Checks that a varname provided by two layers raises a `ValueError`."""

  def first_count_fn(_):
    return 1.0

  def second_count_fn(_):
    return 2.0

  # Both layers register their functions under the same key 'var1:0', which
  # is the conflict this test exercises.
  registry_fn_outputs_list = []
  for layer_id, count_fn in (
      ('layer1', first_count_fn),
      ('layer2', second_count_fn),
  ):
    layer_var = tf.Variable(tf.ones((1, 2)), name='var1')
    registry_fn_outputs_list.append(
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id=layer_id,
            layer_vars=[layer_var],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[layer_var],
            varname_to_count_contribution_fn={'var1:0': [count_fn]},
        )
    )

  with self.assertRaises(ValueError):
    sparse_noise_utils.extract_varname_to_contribution_counts_fns(
        registry_fn_outputs_list,
        trainable_vars=None,
    )
if __name__ == '__main__':
tf.test.main()

View file

@ -22,7 +22,7 @@ InputKwargs = Mapping[str, Any]
SparseGradient = tf.IndexedSlices
ContributionCountHistogram = tf.SparseTensor
ContributionCountHistogramFn = Callable[
[SparseGradient], Mapping[str, ContributionCountHistogram]
[SparseGradient], ContributionCountHistogram
]
NumMicrobatches = int | tf.Tensor
SparsityPreservingNoiseLayerRegistryFunction = Callable[