Sparsity Preserving DP-SGD in TF Privacy [5 of 5]

Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 666849100
A. Unique TensorFlower 2024-08-23 10:45:37 -07:00
parent 93c7e54327
commit b3963971e3
6 changed files with 86 additions and 16 deletions
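
For orientation before the per-file diffs: the mechanism this series adds replaces "noise every row of the gradient" with "privately select the contributing rows, then noise only those". The following is a schematic sketch of that idea, not the library code; every name in it is illustrative.

import tensorflow as tf

def sparsity_preserving_noise_sketch(
    row_sums: tf.Tensor,              # [vocab_size, dim] summed clipped per-example gradients
    contribution_counts: tf.Tensor,   # [vocab_size] float counts of examples touching each row
    selection_noise_stddev: float,    # noise spent on private partition selection
    gradient_noise_stddev: float,     # noise spent on the selected gradient rows
    threshold: float,
) -> tf.IndexedSlices:
  """Noises only the rows that pass private partition selection."""
  # Step 1: private partition selection -- noise the contribution counts and
  # keep only the rows whose noised count clears the threshold.
  noised_counts = contribution_counts + tf.random.normal(
      tf.shape(contribution_counts), stddev=selection_noise_stddev
  )
  selected = tf.reshape(tf.where(noised_counts > threshold), [-1])

  # Step 2: add Gaussian noise only to the selected rows. Unselected rows stay
  # exactly zero, so the privatized gradient remains sparse.
  selected_rows = tf.gather(row_sums, selected)
  noised_rows = selected_rows + tf.random.normal(
      tf.shape(selected_rows), stddev=gradient_noise_stddev
  )
  return tf.IndexedSlices(
      values=noised_rows,
      indices=selected,
      dense_shape=tf.shape(row_sums, out_type=tf.int64),
  )

Dense DP-SGD would add Gaussian noise to every row of an embedding-table gradient, densifying it; thresholding the noised contribution counts keeps the gradient sparse, at the cost of spending part of the noise budget on the selection step (hence the `split_noise_multiplier` call in the model diff below).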

View file

@@ -21,11 +21,11 @@ import tensorflow as tf
 # Tensorflow aliases.
 Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
+PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]
 InputTensors = PackedTensors
-OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
+OutputTensors = Union[Tensor, Iterable[Tensor]]
 BatchSize = Union[int, tf.Tensor]
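
The widened aliases let packed gradient structures carry sparse values. Illustrative only: a gradient mapping that mixes a dense kernel gradient with the `tf.IndexedSlices` an embedding lookup produces is now a valid `PackedTensors`.

import tensorflow as tf

dense_grad = tf.ones([4, 3])
sparse_grad = tf.IndexedSlices(
    values=tf.ones([2, 3]),
    indices=tf.constant([0, 2], dtype=tf.int64),
    dense_shape=tf.constant([4, 3], dtype=tf.int64),
)
# Mapping[str, Tensor], where Tensor now admits tf.IndexedSlices.
grads = {'dense/kernel:0': dense_grad, 'embedding/embeddings:0': sparse_grad}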

View file

@@ -18,6 +18,8 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
     ],
 )

View file

@@ -13,16 +13,37 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""
 
+import dataclasses
+
+from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
 
 _PRIVATIZED_LOSS_NAME = 'privatized_loss'
 
+
+@dataclasses.dataclass
+class SparsityPreservingDPSGDConfig:
+  """Config for adding sparsity preserving noise to the gradients."""
+
+  # The ratio of how the noise is split between partition selection and
+  # gradient noise.
+  sparse_selection_ratio: float = 0.0
+  # The threshold to use for private partition selection.
+  sparse_selection_threshold: int = 100
+  # A `LayerRegistry` instance containing functions that help compute
+  # contribution counts for sparse layers. See
+  # `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+  # more details.
+  sparse_selection_layer_registry: snlr.LayerRegistry | None = None
+
+
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@@ -104,6 +125,9 @@ def make_dp_model_class(cls):
         num_microbatches=None,
         use_xla=True,
         layer_registry=None,
+        sparsity_preserving_dpsgd_config: (
+            SparsityPreservingDPSGDConfig | None
+        ) = None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs,
     ):
@@ -118,6 +142,9 @@ def make_dp_model_class(cls):
           help compute gradient norms quickly. See
           `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
           more details.
+        sparsity_preserving_dpsgd_config: If provided, uses partition selection
+          and sparse noise for privatizing sparse gradients for layers in
+          `sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -127,6 +154,8 @@ def make_dp_model_class(cls):
       self._layer_registry = layer_registry
       self._clipping_loss = None
+      self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
+
       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
       # In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def make_dp_model_class(cls):
       # microbatches is done here.
       tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
 
+      sparse_noise_layer_registry = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        sparse_noise_layer_registry = (
+            self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
+        )
       registry_generator_fn = (
           gradient_clipping_utils.get_registry_generator_fn(
               tape=tape,
               layer_registry=self._layer_registry,
-              sparse_noise_layer_registry=None,
+              sparse_noise_layer_registry=sparse_noise_layer_registry,
               num_microbatches=num_microbatches,
           )
       )
@@ -310,14 +344,53 @@ def make_dp_model_class(cls):
               )
           )
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
 
-      if self._noise_multiplier > 0:
+      noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
+      contribution_counts = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        logging.info('Using sparse noise.')
+        varname_to_contribution_counts_fns = (
+            sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+                registry_fn_outputs_list,
+                self.trainable_variables,
+            )
+        )
+        contribution_counts = sparse_noise_utils.get_contribution_counts(
+            self.trainable_variables,
+            clipped_grads,
+            varname_to_contribution_counts_fns,
+        )
+        noise_multiplier_sparse, noise_multiplier = (
+            sparse_noise_utils.split_noise_multiplier(
+                noise_multiplier,
+                self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
+                contribution_counts,
+            )
+        )
+        logging.info(
+            'Split noise multiplier for gradient noise: %s and partition'
+            ' selection: %s',
+            noise_multiplier,
+            noise_multiplier_sparse,
+        )
+
+      if noise_multiplier > 0:
+        sparse_noise_config = None
+        if self._sparsity_preserving_dpsgd_config is not None:
+          sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+              sparse_noise_multiplier=noise_multiplier_sparse,
+              sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
+              sparse_contribution_counts=contribution_counts,
+          )
         grads = noise_utils.add_aggregate_noise(
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
-            self._noise_multiplier,
+            noise_multiplier,
             loss_reduction=None,
             loss_model=self,
+            sparse_noise_config=sparse_noise_config,
         )
       else:
         grads = clipped_grads
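
Taken together, a hypothetical end-to-end wiring of the new argument looks like this (a sketch, not from this commit: the model layout is made up, and a real setup would also populate the sparse-noise layer registry with contribution-count functions for the embedding layers):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr

# Registry of sparse layers whose contribution counts can be computed;
# registering the embedding layer handlers is elided here.
sparse_registry = snlr.LayerRegistry()

config = dp_keras_model.SparsityPreservingDPSGDConfig(
    sparse_selection_ratio=0.5,       # half the noise budget for partition selection
    sparse_selection_threshold=100,   # rows whose noised count is below this are dropped
    sparse_selection_layer_registry=sparse_registry,
)

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layer_registry=layer_registry.make_default_layer_registry(),
    sparsity_preserving_dpsgd_config=config,
    layers=[
        tf.keras.layers.Embedding(input_dim=10_000, output_dim=16),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy())  # then fit as usual

`sparse_selection_ratio` controls how `split_noise_multiplier` divides the overall noise multiplier between partition selection and gradient noise, and `sparse_selection_threshold` is the cutoff applied to the noised contribution counts.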

View file

@@ -166,7 +166,7 @@ def sample_true_positive_indices(
           tf.shape(contribution_count_values),
           mean=0.0,
           stddev=noise_multiplier,
-          dtype=tf.float32,
+          dtype=contribution_count_values.dtype,
       )
   )
   noised_contribution_counts_indices = contribution_counts.indices[
@@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
   """
   filtered_grad_values = tf.gather(grad, indices)
   sparse_noise_values = tf.random.normal(
-      filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
+      tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
   )
   filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
   return tf.IndexedSlices(
@@ -362,15 +362,10 @@ def get_contribution_counts(
     if var.name not in varname_to_contribution_counts_fns:
       contribution_counts_list.append(None)
       continue
-    contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
-    if not contribution_counts_fns or not contribution_counts_fns[0]:
+    contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
+    if not contribution_counts_fn:
       contribution_counts_list.append(None)
       continue
-    if len(contribution_counts_fns) > 1:
-      raise NotImplementedError(
-          'Sparse noise is not supported for shared weight variables.'
-      )
-    contribution_counts_fn = contribution_counts_fns[0]
     contribution_counts = contribution_counts_fn(grad)
     contribution_counts_list.append(contribution_counts)
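
The two one-line fixes above share a theme: this code now runs where static shapes are only partially known and gradient dtypes are not guaranteed to be float32. A minimal illustration (not library code) of why `tf.shape(...)` is needed instead of `.shape`:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def add_noise(x):
  # x.shape is (None, 4) during tracing, so tf.random.normal(x.shape, ...)
  # would raise; tf.shape(x) is a runtime tensor and handles any batch size.
  return x + tf.random.normal(tf.shape(x), stddev=1.0)

print(add_noise(tf.ones([3, 4])).shape)  # (3, 4)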

View file

@@ -369,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         tf.ones((1, 2)),
     ]
     varname_to_contribution_counts_fns = {
-        'var1:0': [lambda grad: 1.0],
+        'var1:0': lambda grad: 1.0,
         'var2:0': None,
     }
     contribution_counts = sparse_noise_utils.get_contribution_counts(

View file

@@ -19,7 +19,7 @@ import tensorflow as tf
 InputArgs = Sequence[Any]
 InputKwargs = Mapping[str, Any]
-SparseGradient = tf.IndexedSlices
+SparseGradient = tf.IndexedSlices | tf.SparseTensor
 ContributionCountHistogram = tf.SparseTensor
 ContributionCountHistogramFn = Callable[
     [SparseGradient], ContributionCountHistogram
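
To make the alias concrete: a `ContributionCountHistogramFn` maps a sparse gradient to a `tf.SparseTensor` histogram of per-row contribution counts. A hypothetical instance for an embedding gradient, under the simplifying assumption that each gradient slice corresponds to one contribution:

import tensorflow as tf

def embedding_contribution_counts(grad: tf.IndexedSlices) -> tf.SparseTensor:
  # Count how many gradient slices touched each embedding row and return the
  # result as a SparseTensor histogram over the vocabulary.
  row_ids, _, counts = tf.unique_with_counts(tf.cast(grad.indices, tf.int64))
  histogram = tf.SparseTensor(
      indices=tf.expand_dims(row_ids, axis=1),
      values=tf.cast(counts, tf.float32),
      dense_shape=[tf.cast(grad.dense_shape[0], tf.int64)],
  )
  return tf.sparse.reorder(histogram)  # SparseTensor ops expect sorted indices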