forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 664906202
This commit is contained in:
parent
38d80cae92
commit
93c7e54327
4 changed files with 153 additions and 7 deletions
|
@ -5,12 +5,19 @@ licenses(["notice"])
|
|||
# Library implementing sparsity-preserving noise utilities; depends on the
# gradient-clipping registry types it consumes.
py_library(
    name = "sparse_noise_utils",
    srcs = ["sparse_noise_utils.py"],
    deps = [
        ":type_aliases",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
    ],
)

# Unit tests for sparse_noise_utils; needs gradient_clipping_utils to build
# RegistryGeneratorFunctionOutput fixtures.
py_test(
    name = "sparse_noise_utils_test",
    srcs = ["sparse_noise_utils_test.py"],
    deps = [
        ":sparse_noise_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
    ],
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -16,10 +16,13 @@
|
|||
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
|
||||
"""
|
||||
|
||||
import collections
|
||||
from typing import Mapping, Optional, Sequence
|
||||
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
|
@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
|
|||
)
|
||||
|
||||
|
||||
def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Sequence[tf.Variable],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
  """Extracts a map of contribution count fns from generator outputs.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: A list of trainable variables. If None, outputs from all
      layers are included regardless of which weights they touch.

  Returns:
    A `dict` from varname to contribution counts functions.

  Raises:
    ValueError: If two layers produce contribution count fns for the same
      varname.
  """
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set(v.ref() for v in trainable_vars)

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    # Keep this output only if filtering is disabled (trainable_vars is None)
    # or the layer owns at least one of the requested trainable weights.
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      if registry_fn_output.varname_to_count_contribution_fn is not None:
        duplicate_varnames = set(
            registry_fn_output.varname_to_count_contribution_fn.keys()
        ) & set(varname_to_contribution_counts_fns.keys())
        if duplicate_varnames:
          # Bug fix: the message was a plain string literal, so
          # '{duplicate_varnames}' was emitted verbatim; the f-prefix makes
          # the offending varnames actually appear in the error.
          raise ValueError(
              f'Duplicate varnames: {duplicate_varnames} found in contribution'
              ' counts functions.'
          )
        varname_to_contribution_counts_fns.update(
            registry_fn_output.varname_to_count_contribution_fn
        )
  return varname_to_contribution_counts_fns
|
||||
|
||||
|
||||
def get_contribution_counts(
|
||||
trainable_vars: list[tf.Variable],
|
||||
grads: list[tf.Tensor],
|
||||
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
|
||||
) -> list[tf.Tensor | None]:
|
||||
trainable_vars: Sequence[tf.Variable],
|
||||
grads: Sequence[tf.Tensor],
|
||||
varname_to_contribution_counts_fns: Mapping[
|
||||
str, type_aliases.ContributionCountHistogramFn
|
||||
],
|
||||
) -> Sequence[type_aliases.ContributionCountHistogram | None]:
|
||||
"""Gets the contribution counts for each variable in the Model.
|
||||
|
||||
Args:
|
||||
trainable_vars: A list of the trainable variables in the Model.
|
||||
trainable_vars: A list of trainable variables.
|
||||
grads: A corresponding list of gradients for each trainable variable.
|
||||
varname_to_contribution_counts_fns: A mapping from variable name to a list
|
||||
of functions to get the contribution counts for that variable.
|
||||
|
|
|
@ -17,6 +17,7 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
|
||||
|
||||
|
||||
|
@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
|
||||
)
|
||||
|
||||
def test_extract_varname_to_contribution_counts_fns(self):
|
||||
def fn1(_):
|
||||
return 1.0
|
||||
|
||||
def fn2(_):
|
||||
return 2.0
|
||||
|
||||
var1 = tf.Variable(tf.ones((1, 2)), name='var1')
|
||||
var2 = tf.Variable(tf.ones((1, 2)), name='var2')
|
||||
var3 = tf.Variable(tf.ones((1, 2)), name='var3')
|
||||
|
||||
registry_fn_outputs_list = [
|
||||
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
|
||||
layer_id='layer1',
|
||||
layer_vars=[var1],
|
||||
layer_sqr_norm_fn=None,
|
||||
layer_trainable_weights=[var1],
|
||||
varname_to_count_contribution_fn=None,
|
||||
),
|
||||
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
|
||||
layer_id='layer2',
|
||||
layer_vars=[var2],
|
||||
layer_sqr_norm_fn=None,
|
||||
layer_trainable_weights=[var2],
|
||||
varname_to_count_contribution_fn={
|
||||
'var2:0': [fn2],
|
||||
},
|
||||
),
|
||||
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
|
||||
layer_id='layer3',
|
||||
layer_vars=[var3],
|
||||
layer_sqr_norm_fn=None,
|
||||
layer_trainable_weights=[var3],
|
||||
varname_to_count_contribution_fn={
|
||||
'var3:0': [fn1],
|
||||
},
|
||||
),
|
||||
]
|
||||
expected_varname_to_contribution_counts_fns = {
|
||||
'var2:0': [fn2],
|
||||
'var3:0': [fn1],
|
||||
}
|
||||
varname_to_contribution_counts_fns = (
|
||||
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
|
||||
registry_fn_outputs_list,
|
||||
trainable_vars=None,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
varname_to_contribution_counts_fns,
|
||||
expected_varname_to_contribution_counts_fns,
|
||||
)
|
||||
|
||||
def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
|
||||
def fn1(_):
|
||||
return 1.0
|
||||
|
||||
def fn2(_):
|
||||
return 2.0
|
||||
|
||||
var1 = tf.Variable(tf.ones((1, 2)), name='var1')
|
||||
var2 = tf.Variable(tf.ones((1, 2)), name='var1')
|
||||
|
||||
registry_fn_outputs_list = [
|
||||
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
|
||||
layer_id='layer1',
|
||||
layer_vars=[var1],
|
||||
layer_sqr_norm_fn=None,
|
||||
layer_trainable_weights=[var1],
|
||||
varname_to_count_contribution_fn={
|
||||
'var1:0': [fn1],
|
||||
},
|
||||
),
|
||||
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
|
||||
layer_id='layer2',
|
||||
layer_vars=[var2],
|
||||
layer_sqr_norm_fn=None,
|
||||
layer_trainable_weights=[var2],
|
||||
varname_to_count_contribution_fn={
|
||||
'var1:0': [fn2],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
|
||||
registry_fn_outputs_list,
|
||||
trainable_vars=None,
|
||||
)
|
||||
|
||||
|
||||
# Standard TF test entry point: discovers and runs the test cases above.
if __name__ == '__main__':
  tf.test.main()
|
||||
|
|
|
@ -22,7 +22,7 @@ InputKwargs = Mapping[str, Any]
|
|||
# Gradient of an embedding-style variable: sparse row updates.
SparseGradient = tf.IndexedSlices
# Per-variable histogram of how many examples contributed to each row.
ContributionCountHistogram = tf.SparseTensor
# Maps a sparse gradient directly to its contribution-count histogram.
ContributionCountHistogramFn = Callable[
    [SparseGradient], ContributionCountHistogram
]
NumMicrobatches = int | tf.Tensor
|
||||
SparsityPreservingNoiseLayerRegistryFunction = Callable[
|
||||
|
|
Loading…
Reference in a new issue