forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 664906202
This commit is contained in:
parent
38d80cae92
commit
93c7e54327
4 changed files with 153 additions and 7 deletions
|
@ -5,12 +5,19 @@ licenses(["notice"])
|
||||||
py_library(
|
py_library(
|
||||||
name = "sparse_noise_utils",
|
name = "sparse_noise_utils",
|
||||||
srcs = ["sparse_noise_utils.py"],
|
srcs = ["sparse_noise_utils.py"],
|
||||||
|
deps = [
|
||||||
|
":type_aliases",
|
||||||
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "sparse_noise_utils_test",
|
name = "sparse_noise_utils_test",
|
||||||
srcs = ["sparse_noise_utils_test.py"],
|
srcs = ["sparse_noise_utils_test.py"],
|
||||||
deps = [":sparse_noise_utils"],
|
deps = [
|
||||||
|
":sparse_noise_utils",
|
||||||
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -16,10 +16,13 @@
|
||||||
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
|
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import collections
|
||||||
from typing import Mapping, Optional, Sequence
|
from typing import Mapping, Optional, Sequence
|
||||||
|
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
|
||||||
import tensorflow_probability as tfp
|
import tensorflow_probability as tfp
|
||||||
|
|
||||||
|
|
||||||
|
@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Optional[Sequence[tf.Variable]],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
  """Merges the per-layer contribution count fn maps into a single map.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: A list of trainable variables used to filter the outputs:
      an output is only considered if at least one of its
      `layer_trainable_weights` is in this list. If `None`, no filtering is
      applied and every output is considered.

  Returns:
    A `dict` from varname to contribution counts functions, containing the
    entries of every considered output whose
    `varname_to_count_contribution_fn` is not `None`.

  Raises:
    ValueError: If the same varname is provided by more than one output.
  """
  trainable_var_refs = None
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_var_refs = {v.ref() for v in trainable_vars}

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    # Skip outputs that own none of the requested trainable variables.
    if trainable_var_refs is not None and not any(
        w.ref() in trainable_var_refs
        for w in registry_fn_output.layer_trainable_weights
    ):
      continue

    layer_fns = registry_fn_output.varname_to_count_contribution_fn
    if layer_fns is None:
      continue

    # A varname may only be registered once; merging silently would drop the
    # functions registered by the earlier output.
    duplicate_varnames = set(layer_fns) & set(
        varname_to_contribution_counts_fns
    )
    if duplicate_varnames:
      raise ValueError(
          f'Duplicate varnames: {sorted(duplicate_varnames)} found in'
          ' contribution counts functions.'
      )
    varname_to_contribution_counts_fns.update(layer_fns)
  return varname_to_contribution_counts_fns
|
||||||
|
|
||||||
|
|
||||||
def get_contribution_counts(
|
def get_contribution_counts(
|
||||||
trainable_vars: list[tf.Variable],
|
trainable_vars: Sequence[tf.Variable],
|
||||||
grads: list[tf.Tensor],
|
grads: Sequence[tf.Tensor],
|
||||||
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
|
varname_to_contribution_counts_fns: Mapping[
|
||||||
) -> list[tf.Tensor | None]:
|
str, type_aliases.ContributionCountHistogramFn
|
||||||
|
],
|
||||||
|
) -> Sequence[type_aliases.ContributionCountHistogram | None]:
|
||||||
"""Gets the contribution counts for each variable in the Model.
|
"""Gets the contribution counts for each variable in the Model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainable_vars: A list of the trainable variables in the Model.
|
trainable_vars: A list of trainable variables.
|
||||||
grads: A corresponding list of gradients for each trainable variable.
|
grads: A corresponding list of gradients for each trainable variable.
|
||||||
varname_to_contribution_counts_fns: A mapping from variable name to a list
|
varname_to_contribution_counts_fns: A mapping from variable name to a list
|
||||||
of functions to get the contribution counts for that variable.
|
of functions to get the contribution counts for that variable.
|
||||||
|
|
|
@ -17,6 +17,7 @@ from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
|
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_extract_varname_to_contribution_counts_fns(self):
  """Checks that maps from several layers are merged, skipping `None` maps."""

  def constant_one(_):
    return 1.0

  def constant_two(_):
    return 2.0

  def make_output(layer_id, var, count_fn_map):
    # Builds a registry output for a layer that owns the single variable `var`.
    return gradient_clipping_utils.RegistryGeneratorFunctionOutput(
        layer_id=layer_id,
        layer_vars=[var],
        layer_sqr_norm_fn=None,
        layer_trainable_weights=[var],
        varname_to_count_contribution_fn=count_fn_map,
    )

  first_var = tf.Variable(tf.ones((1, 2)), name='var1')
  second_var = tf.Variable(tf.ones((1, 2)), name='var2')
  third_var = tf.Variable(tf.ones((1, 2)), name='var3')

  outputs = [
      # The first layer provides no contribution count functions at all.
      make_output('layer1', first_var, None),
      make_output('layer2', second_var, {'var2:0': [constant_two]}),
      make_output('layer3', third_var, {'var3:0': [constant_one]}),
  ]

  # `trainable_vars=None` is passed, so no output is filtered by variable.
  actual = sparse_noise_utils.extract_varname_to_contribution_counts_fns(
      outputs,
      trainable_vars=None,
  )

  expected = {
      'var2:0': [constant_two],
      'var3:0': [constant_one],
  }
  self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
  """Checks that two layers registering the same varname raise ValueError."""

  def constant_one(_):
    return 1.0

  def constant_two(_):
    return 2.0

  def make_output(layer_id, var, count_fn_map):
    # Builds a registry output for a layer that owns the single variable `var`.
    return gradient_clipping_utils.RegistryGeneratorFunctionOutput(
        layer_id=layer_id,
        layer_vars=[var],
        layer_sqr_norm_fn=None,
        layer_trainable_weights=[var],
        varname_to_count_contribution_fn=count_fn_map,
    )

  # Both variables are deliberately created with the same name.
  first_var = tf.Variable(tf.ones((1, 2)), name='var1')
  second_var = tf.Variable(tf.ones((1, 2)), name='var1')

  # Both layers register functions under the key 'var1:0'.
  outputs = [
      make_output('layer1', first_var, {'var1:0': [constant_one]}),
      make_output('layer2', second_var, {'var1:0': [constant_two]}),
  ]

  with self.assertRaises(ValueError):
    sparse_noise_utils.extract_varname_to_contribution_counts_fns(
        outputs,
        trainable_vars=None,
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -22,7 +22,7 @@ InputKwargs = Mapping[str, Any]
|
||||||
SparseGradient = tf.IndexedSlices
|
SparseGradient = tf.IndexedSlices
|
||||||
ContributionCountHistogram = tf.SparseTensor
|
ContributionCountHistogram = tf.SparseTensor
|
||||||
ContributionCountHistogramFn = Callable[
|
ContributionCountHistogramFn = Callable[
|
||||||
[SparseGradient], Mapping[str, ContributionCountHistogram]
|
[SparseGradient], ContributionCountHistogram
|
||||||
]
|
]
|
||||||
NumMicrobatches = int | tf.Tensor
|
NumMicrobatches = int | tf.Tensor
|
||||||
SparsityPreservingNoiseLayerRegistryFunction = Callable[
|
SparsityPreservingNoiseLayerRegistryFunction = Callable[
|
||||||
|
|
Loading…
Reference in a new issue