Sparsity Preserving DP-SGD in TF Privacy
Add support for calculating contribution counts in the registry generator function for sparsity-preserving noise. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 662162597
This commit is contained in:
parent
e42b574465
commit
bf6cf4dec9
6 changed files with 120 additions and 1 deletions
|
@ -46,6 +46,8 @@ py_library(
|
||||||
":common_manip_utils",
|
":common_manip_utils",
|
||||||
":layer_registry",
|
":layer_registry",
|
||||||
":type_aliases",
|
":type_aliases",
|
||||||
|
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
|
||||||
|
"//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,7 +57,11 @@ py_test(
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 8,
|
shard_count = 8,
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":gradient_clipping_utils"],
|
deps = [
|
||||||
|
":gradient_clipping_utils",
|
||||||
|
":layer_registry",
|
||||||
|
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -164,6 +164,7 @@ def compute_gradient_norms(
|
||||||
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
|
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
|
||||||
tape=tape,
|
tape=tape,
|
||||||
layer_registry=layer_registry,
|
layer_registry=layer_registry,
|
||||||
|
sparse_noise_layer_registry=None,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
layer_grad_vars, generator_outputs_list = (
|
layer_grad_vars, generator_outputs_list = (
|
||||||
|
|
|
@ -132,6 +132,7 @@ def _run_model_forward_backward_pass(
|
||||||
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
|
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
|
||||||
tape=tape,
|
tape=tape,
|
||||||
layer_registry=layer_registry.make_default_layer_registry(),
|
layer_registry=layer_registry.make_default_layer_registry(),
|
||||||
|
sparse_noise_layer_registry=None,
|
||||||
num_microbatches=None,
|
num_microbatches=None,
|
||||||
)
|
)
|
||||||
layer_grad_vars, registry_fn_outputs_list = (
|
layer_grad_vars, registry_fn_outputs_list = (
|
||||||
|
|
|
@ -22,6 +22,8 @@ import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
@ -29,6 +31,9 @@ class RegistryGeneratorFunctionOutput:
|
||||||
layer_id: str
|
layer_id: str
|
||||||
layer_vars: Optional[Sequence[tf.Variable]]
|
layer_vars: Optional[Sequence[tf.Variable]]
|
||||||
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
|
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
|
||||||
|
varname_to_count_contribution_fn: Optional[
|
||||||
|
dict[str, sn_type_aliases.ContributionCountHistogramFn]
|
||||||
|
]
|
||||||
layer_trainable_weights: Optional[Sequence[tf.Variable]]
|
layer_trainable_weights: Optional[Sequence[tf.Variable]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
|
||||||
def get_registry_generator_fn(
|
def get_registry_generator_fn(
|
||||||
tape: tf.GradientTape,
|
tape: tf.GradientTape,
|
||||||
layer_registry: lr.LayerRegistry,
|
layer_registry: lr.LayerRegistry,
|
||||||
|
sparse_noise_layer_registry: snlr.LayerRegistry,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||||
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
|
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
|
||||||
"""Creates the generator function for `model_forward_backward_pass()`.
|
"""Creates the generator function for `model_forward_backward_pass()`.
|
||||||
|
@ -58,6 +64,10 @@ def get_registry_generator_fn(
|
||||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
||||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||||
trainable
|
trainable
|
||||||
|
sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
|
||||||
|
that help compute contribution counts for sparse noise. See
|
||||||
|
`tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
|
||||||
|
more details.
|
||||||
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
|
||||||
microbatches. If not None, indicates that the loss is grouped into
|
microbatches. If not None, indicates that the loss is grouped into
|
||||||
num_microbatches (in this case, the batch dimension needs to be a multiple
|
num_microbatches (in this case, the batch dimension needs to be a multiple
|
||||||
|
@ -83,6 +93,16 @@ def get_registry_generator_fn(
|
||||||
'be used for efficient gradient clipping.'
|
'be used for efficient gradient clipping.'
|
||||||
% layer_instance.__class__.__name__
|
% layer_instance.__class__.__name__
|
||||||
)
|
)
|
||||||
|
varname_to_count_contribution_fn = None
|
||||||
|
if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
|
||||||
|
layer_instance
|
||||||
|
):
|
||||||
|
count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
|
||||||
|
layer_instance
|
||||||
|
)
|
||||||
|
varname_to_count_contribution_fn = count_contribution_registry_fn(
|
||||||
|
layer_instance, args, kwargs, num_microbatches
|
||||||
|
)
|
||||||
registry_fn = layer_registry.lookup(layer_instance)
|
registry_fn = layer_registry.lookup(layer_instance)
|
||||||
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
|
||||||
layer_instance, args, kwargs, tape, num_microbatches
|
layer_instance, args, kwargs, tape, num_microbatches
|
||||||
|
@ -91,6 +111,7 @@ def get_registry_generator_fn(
|
||||||
layer_id=str(id(layer_instance)),
|
layer_id=str(id(layer_instance)),
|
||||||
layer_vars=layer_vars,
|
layer_vars=layer_vars,
|
||||||
layer_sqr_norm_fn=layer_sqr_norm_fn,
|
layer_sqr_norm_fn=layer_sqr_norm_fn,
|
||||||
|
varname_to_count_contribution_fn=varname_to_count_contribution_fn,
|
||||||
layer_trainable_weights=layer_instance.trainable_weights,
|
layer_trainable_weights=layer_instance.trainable_weights,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -17,6 +17,8 @@ from typing import Any
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
@ -175,5 +177,92 @@ class GenerateOutputsUsingCoreKerasLayers(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
  """Checks `get_registry_generator_fn` when a sparse-noise registry is given.

  Both registries used here are populated with stub registry functions, so the
  test only exercises how the generator function looks layers up and packages
  the results into a `RegistryGeneratorFunctionOutput`.
  """

  def _get_sparse_layer_registry(self):
    """Builds a sparse-noise registry that only knows about `Embedding`.

    Returns:
      A `(registry, count_contribution_fn)` tuple, where `count_contribution_fn`
      is the stub that the registered function maps to the key `'var'`.
    """

    def count_contribution_fn(_):
      return None

    def registry_fn(*_):
      return {'var': count_contribution_fn}

    sparse_registry = snlr.LayerRegistry()
    sparse_registry.insert(tf.keras.layers.Embedding, registry_fn)
    return sparse_registry, count_contribution_fn

  def _get_layer_registry(self):
    """Builds a clipping registry with one stub for `Embedding` and `Dense`.

    Returns:
      A `(registry, var, output, sqr_norm_fn)` tuple holding the registry and
      the three objects the stub registry function hands back on every call.
    """
    stub_var = tf.Variable(1.0)
    stub_output = tf.ones((1, 1))

    def sqr_norm_fn(_):
      return None

    def registry_fn(*_):
      return [stub_var], stub_output, sqr_norm_fn

    clipping_registry = lr.LayerRegistry()
    # The same stub serves both layer types used by the test model.
    for layer_cls in (tf.keras.layers.Embedding, tf.keras.layers.Dense):
      clipping_registry.insert(layer_cls, registry_fn)
    return clipping_registry, stub_var, stub_output, sqr_norm_fn

  def test_registry_generator_fn(self):
    """Contribution-count fns are attached only for sparse-registered layers."""
    inputs = tf.constant([[0, 1]])
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(10, 1),
        tf.keras.layers.Dense(1),
    ])

    sparse_layer_registry, count_contribution_fn = (
        self._get_sparse_layer_registry()
    )
    layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
        tape=tf.GradientTape(),
        layer_registry=layer_registry,
        sparse_noise_layer_registry=sparse_layer_registry,
        num_microbatches=None,
    )

    # One entry per model layer, in order: the embedding layer is present in
    # the sparse registry, the dense layer is not.
    expected_count_fns_per_layer = (
        {'var': count_contribution_fn},
        None,
    )
    for layer, expected_count_fns in zip(
        model.layers, expected_count_fns_per_layer
    ):
      actual_out, actual_generator_output = registry_generator_fn(
          layer,
          [inputs],
          {},
      )
      expected_generator_output = (
          gradient_clipping_utils.RegistryGeneratorFunctionOutput(
              layer_id=str(id(layer)),
              layer_vars=[var],
              layer_sqr_norm_fn=sqr_norm_fn,
              varname_to_count_contribution_fn=expected_count_fns,
              layer_trainable_weights=layer.trainable_weights,
          )
      )
      self.assertEqual(actual_generator_output, expected_generator_output)
      self.assertEqual(actual_out, output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -280,6 +280,7 @@ def make_dp_model_class(cls):
|
||||||
gradient_clipping_utils.get_registry_generator_fn(
|
gradient_clipping_utils.get_registry_generator_fn(
|
||||||
tape=tape,
|
tape=tape,
|
||||||
layer_registry=self._layer_registry,
|
layer_registry=self._layer_registry,
|
||||||
|
sparse_noise_layer_registry=None,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue