Sparsity Preserving DP-SGD in TF Privacy [5 of 5]

Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 666849100
A. Unique TensorFlower 2024-08-23 10:45:37 -07:00
parent 93c7e54327
commit b3963971e3
6 changed files with 86 additions and 16 deletions
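
In brief, sparsity preserving noise privatizes a sparse (e.g. embedding) gradient in two steps: it privately selects which rows look populated via partition selection on noised contribution counts, then adds Gaussian noise only to the selected rows, so untouched rows stay exactly zero. The snippet below is an illustrative sketch of that idea only; the function name, arguments, and shapes are assumptions, not this commit's API.

import tensorflow as tf

def sparse_privatize(grad_rows, row_counts, selection_stddev, threshold,
                     noise_stddev):
  """Illustrative sketch: keep rows whose noised contribution count clears
  `threshold`, then noise only those rows; unselected rows stay exactly zero,
  preserving the gradient's sparsity."""
  # Step 1: private partition selection on noised contribution counts.
  noised_counts = row_counts + tf.random.normal(
      tf.shape(row_counts), stddev=selection_stddev
  )
  kept = tf.squeeze(tf.where(noised_counts > threshold), axis=1)
  # Step 2: Gaussian noise on the selected rows only.
  kept_rows = tf.gather(grad_rows, kept)
  noised_rows = kept_rows + tf.random.normal(
      tf.shape(kept_rows), stddev=noise_stddev
  )
  return tf.IndexedSlices(
      noised_rows, kept, tf.shape(grad_rows, out_type=tf.int64)
  )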

View file

@@ -21,11 +21,11 @@ import tensorflow as tf

 # Tensorflow aliases.
 Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]

-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
+PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]

 InputTensors = PackedTensors

-OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
+OutputTensors = Union[Tensor, Iterable[Tensor]]

 BatchSize = Union[int, tf.Tensor]
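
The practical effect of widening these aliases, shown as a small snippet that is not part of the commit: composite tensors now type-check inside packed containers, not just as a bare `Tensor`.

from typing import Iterable, Mapping, Union
import tensorflow as tf

Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]

# Before this change only dense tf.Tensor values type-checked inside a
# container; composite tensors such as tf.RaggedTensor now do too.
batch: PackedTensors = {'tokens': tf.ragged.constant([[1, 2], [3]])}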

View file

@@ -18,6 +18,8 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
     ],
 )

View file

@@ -13,16 +13,37 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""

+import dataclasses
 from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils

 _PRIVATIZED_LOSS_NAME = 'privatized_loss'


+@dataclasses.dataclass
+class SparsityPreservingDPSGDConfig:
+  """Config for adding sparsity preserving noise to the gradients."""
+
+  # The ratio of how the noise is split between partition selection and
+  # gradient noise.
+  sparse_selection_ratio: float = 0.0
+  # The threshold to use for private partition selection.
+  sparse_selection_threshold: int = 100
+  # A `LayerRegistry` instance containing functions that help compute
+  # contribution counts for sparse layers. See
+  # `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+  # more details.
+  sparse_selection_layer_registry: snlr.LayerRegistry | None = None
+
+
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@@ -104,6 +125,9 @@ def make_dp_model_class(cls):
         num_microbatches=None,
         use_xla=True,
         layer_registry=None,
+        sparsity_preserving_dpsgd_config: (
+            SparsityPreservingDPSGDConfig | None
+        ) = None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs,
     ):
@@ -118,6 +142,9 @@ def make_dp_model_class(cls):
           help compute gradient norms quickly. See
           `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
           more details.
+        sparsity_preserving_dpsgd_config: If provided, uses partition selection
+          and sparse noise for privatizing sparse gradients for layers in
+          `sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -127,6 +154,8 @@ def make_dp_model_class(cls):
       self._layer_registry = layer_registry
       self._clipping_loss = None
+      self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
+
       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
       # In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def make_dp_model_class(cls):
       # microbatches is done here.
       tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)

+      sparse_noise_layer_registry = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        sparse_noise_layer_registry = (
+            self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
+        )
       registry_generator_fn = (
           gradient_clipping_utils.get_registry_generator_fn(
               tape=tape,
               layer_registry=self._layer_registry,
-              sparse_noise_layer_registry=None,
+              sparse_noise_layer_registry=sparse_noise_layer_registry,
               num_microbatches=num_microbatches,
           )
       )
@@ -310,14 +344,53 @@ def make_dp_model_class(cls):
           )
       )
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
-      if self._noise_multiplier > 0:
+
+      noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
+      contribution_counts = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        logging.info('Using sparse noise.')
+
+        varname_to_contribution_counts_fns = (
+            sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+                registry_fn_outputs_list,
+                self.trainable_variables,
+            )
+        )
+        contribution_counts = sparse_noise_utils.get_contribution_counts(
+            self.trainable_variables,
+            clipped_grads,
+            varname_to_contribution_counts_fns,
+        )
+
+        noise_multiplier_sparse, noise_multiplier = (
+            sparse_noise_utils.split_noise_multiplier(
+                noise_multiplier,
+                self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
+                contribution_counts,
+            )
+        )
+        logging.info(
+            'Split noise multiplier for gradient noise: %s and partition'
+            ' selection: %s',
+            noise_multiplier,
+            noise_multiplier_sparse,
+        )
+
+      if noise_multiplier > 0:
+        sparse_noise_config = None
+        if self._sparsity_preserving_dpsgd_config is not None:
+          sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+              sparse_noise_multiplier=noise_multiplier_sparse,
+              sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
+              sparse_contribution_counts=contribution_counts,
+          )
         grads = noise_utils.add_aggregate_noise(
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
-            self._noise_multiplier,
+            noise_multiplier,
             loss_reduction=None,
             loss_model=self,
+            sparse_noise_config=sparse_noise_config,
         )
       else:
         grads = clipped_grads
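
Putting the new argument together, a hedged usage sketch: `DPModel` is assumed to be the class produced by `make_dp_model_class(tf.keras.Model)`, and `make_default_layer_registry` is assumed to exist in both registry modules; only the config fields and the `sparsity_preserving_dpsgd_config` argument are taken directly from this diff.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr

# A small model with an embedding layer, the case sparse noise targets.
inputs = tf.keras.Input(shape=(None,), dtype=tf.int32)
embedded = tf.keras.layers.Embedding(1000, 8)(inputs)
outputs = tf.keras.layers.Dense(1)(tf.reduce_mean(embedded, axis=1))

sparse_config = SparsityPreservingDPSGDConfig(
    sparse_selection_ratio=0.1,     # portion of the noise spent on selection
    sparse_selection_threshold=100,
    # Assumed helper; any snlr.LayerRegistry covering the sparse layers works.
    sparse_selection_layer_registry=snlr.make_default_layer_registry(),
)
model = DPModel(  # assumed: DPModel = make_dp_model_class(tf.keras.Model)
    inputs=inputs,
    outputs=outputs,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layer_registry=layer_registry.make_default_layer_registry(),
    sparsity_preserving_dpsgd_config=sparse_config,
)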

View file

@@ -166,7 +166,7 @@ def sample_true_positive_indices(
           tf.shape(contribution_count_values),
           mean=0.0,
           stddev=noise_multiplier,
-          dtype=tf.float32,
+          dtype=contribution_count_values.dtype,
       )
   )
   noised_contribution_counts_indices = contribution_counts.indices[
@@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
   """
   filtered_grad_values = tf.gather(grad, indices)
   sparse_noise_values = tf.random.normal(
-      filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
+      tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
   )
   filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
   return tf.IndexedSlices(
@@ -362,15 +362,10 @@ def get_contribution_counts(
     if var.name not in varname_to_contribution_counts_fns:
       contribution_counts_list.append(None)
       continue
-    contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
-    if not contribution_counts_fns or not contribution_counts_fns[0]:
+    contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
+    if not contribution_counts_fn:
       contribution_counts_list.append(None)
       continue
-    if len(contribution_counts_fns) > 1:
-      raise NotImplementedError(
-          'Sparse noise is not supported for shared weight variables.'
-      )
-    contribution_counts_fn = contribution_counts_fns[0]
     contribution_counts = contribution_counts_fn(grad)
     contribution_counts_list.append(contribution_counts)
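
Two of the hunks above are graph-mode fixes: the selection noise now matches the dtype of the contribution counts, and `add_sparse_gradient_noise` uses the dynamic shape. The snippet below (illustrative, not from the commit) shows why the static `.shape` fails: after a `tf.gather` with data-dependent indices, the leading dimension is `None`, and `tf.random.normal` requires a fully defined shape.

import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec([None, 4], tf.float32),
    tf.TensorSpec([None], tf.int64),
])
def noised_gather(grad, indices):
  vals = tf.gather(grad, indices)  # static shape [None, 4]: size unknown
  # tf.random.normal(vals.shape, ...) would raise: the shape contains None.
  return vals + tf.random.normal(tf.shape(vals), stddev=1.0)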

View file

@@ -369,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         tf.ones((1, 2)),
     ]
     varname_to_contribution_counts_fns = {
-        'var1:0': [lambda grad: 1.0],
+        'var1:0': lambda grad: 1.0,
         'var2:0': None,
     }
     contribution_counts = sparse_noise_utils.get_contribution_counts(

View file

@@ -19,7 +19,7 @@ import tensorflow as tf
 InputArgs = Sequence[Any]
 InputKwargs = Mapping[str, Any]
-SparseGradient = tf.IndexedSlices
+SparseGradient = tf.IndexedSlices | tf.SparseTensor
 ContributionCountHistogram = tf.SparseTensor
 ContributionCountHistogramFn = Callable[
     [SparseGradient], ContributionCountHistogram
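
A hypothetical `ContributionCountHistogramFn`-shaped helper that handles both members of the widened `SparseGradient` alias (illustrative sketch; the real per-example contribution counting lives in the layer registry):

import tensorflow as tf

def row_touch_counts(grad, num_rows):
  """Counts how many sparse entries touch each row, for either alias member."""
  if isinstance(grad, tf.IndexedSlices):
    rows = grad.indices
  else:  # tf.SparseTensor: the first coordinate of each nonzero entry is its row
    rows = grad.indices[:, 0]
  counts = tf.math.bincount(
      tf.cast(rows, tf.int32), minlength=num_rows, dtype=tf.float32
  )
  return tf.sparse.from_dense(counts)  # a ContributionCountHistogram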