diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py
index b064c69..c4033a8 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py
@@ -21,11 +21,11 @@ import tensorflow as tf
 # Tensorflow aliases.
 Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
 
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
+PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]
 
 InputTensors = PackedTensors
 
-OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
+OutputTensors = Union[Tensor, Iterable[Tensor]]
 
 BatchSize = Union[int, tf.Tensor]
 
diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD
index ed3e9fc..a8883f6 100644
--- a/tensorflow_privacy/privacy/keras_models/BUILD
+++ b/tensorflow_privacy/privacy/keras_models/BUILD
@@ -18,6 +18,8 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
     ],
 )
 
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index 472f175..0879880 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -13,16 +13,37 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""
 
+import dataclasses
+
 from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
+
 
 _PRIVATIZED_LOSS_NAME = 'privatized_loss'
 
 
+@dataclasses.dataclass
+class SparsityPreservingDPSGDConfig:
+  """Config for adding sparsity preserving noise to the gradients."""
+
+  # The ratio of how the noise is split between partition selection and
+  # gradient noise.
+  sparse_selection_ratio: float = 0.0
+  # The threshold to use for private partition selection.
+  sparse_selection_threshold: int = 100
+  # A `LayerRegistry` instance containing functions that help compute
+  # contribution counts for sparse layers. See
+  # `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+  # more details.
+  sparse_selection_layer_registry: snlr.LayerRegistry | None = None
+
+
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
 
@@ -104,6 +125,9 @@ def make_dp_model_class(cls):
         num_microbatches=None,
         use_xla=True,
         layer_registry=None,
+        sparsity_preserving_dpsgd_config: (
+            SparsityPreservingDPSGDConfig | None
+        ) = None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs,
     ):
@@ -118,6 +142,9 @@ def make_dp_model_class(cls):
           help compute gradient norms quickly. See
           `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry`
           for more details.
+        sparsity_preserving_dpsgd_config: If provided, uses partition selection
+          and sparse noise for privatizing sparse gradients for layers in
+          `sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -127,6 +154,8 @@ def make_dp_model_class(cls):
       self._layer_registry = layer_registry
       self._clipping_loss = None
 
+      self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
+
       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
       # In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def make_dp_model_class(cls):
       # microbatches is done here.
       tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
 
+      sparse_noise_layer_registry = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        sparse_noise_layer_registry = (
+            self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
+        )
       registry_generator_fn = (
           gradient_clipping_utils.get_registry_generator_fn(
               tape=tape,
               layer_registry=self._layer_registry,
-              sparse_noise_layer_registry=None,
+              sparse_noise_layer_registry=sparse_noise_layer_registry,
              num_microbatches=num_microbatches,
           )
       )
@@ -310,14 +344,53 @@ def make_dp_model_class(cls):
           )
       )
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
-      if self._noise_multiplier > 0:
+      noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
+      contribution_counts = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        logging.info('Using sparse noise.')
+
+        varname_to_contribution_counts_fns = (
+            sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+                registry_fn_outputs_list,
+                self.trainable_variables,
+            )
+        )
+        contribution_counts = sparse_noise_utils.get_contribution_counts(
+            self.trainable_variables,
+            clipped_grads,
+            varname_to_contribution_counts_fns,
+        )
+
+        noise_multiplier_sparse, noise_multiplier = (
+            sparse_noise_utils.split_noise_multiplier(
+                noise_multiplier,
+                self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
+                contribution_counts,
+            )
+        )
+        logging.info(
+            'Split noise multiplier for gradient noise: %s and partition'
+            ' selection: %s',
+            noise_multiplier,
+            noise_multiplier_sparse,
+        )
+
+      if noise_multiplier > 0:
+        sparse_noise_config = None
+        if self._sparsity_preserving_dpsgd_config is not None:
+          sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+              sparse_noise_multiplier=noise_multiplier_sparse,
+              sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
+              sparse_contribution_counts=contribution_counts,
+          )
         grads = noise_utils.add_aggregate_noise(
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
-            self._noise_multiplier,
+            noise_multiplier,
             loss_reduction=None,
             loss_model=self,
+            sparse_noise_config=sparse_noise_config,
         )
       else:
         grads = clipped_grads
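Usage sketch (illustrative only, not part of the patch): with the new config, a Keras model containing an embedding layer could opt into sparsity-preserving DP-SGD roughly as below. The `make_default_layer_registry()` helpers are assumed to exist in both registry modules, mirroring the fast-gradient-clipping API; the layer stack and hyperparameters are placeholders.

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as fgc_registry
    from tensorflow_privacy.privacy.keras_models import dp_keras_model
    from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as sn_registry

    # Assumed registry helpers; the exact constructor names are not confirmed by this diff.
    sparse_config = dp_keras_model.SparsityPreservingDPSGDConfig(
        sparse_selection_ratio=0.1,      # how the noise is split between partition selection and gradient noise
        sparse_selection_threshold=100,
        sparse_selection_layer_registry=sn_registry.make_default_layer_registry(),
    )

    model = dp_keras_model.DPSequential(
        l2_norm_clip=1.0,
        noise_multiplier=1.1,
        layer_registry=fgc_registry.make_default_layer_registry(),
        sparsity_preserving_dpsgd_config=sparse_config,
        layers=[
            tf.keras.layers.Embedding(input_dim=10_000, output_dim=16),
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dense(1),
        ],
    )

When the config is supplied, train_step splits the noise multiplier between gradient noise and partition selection and threads a SparsityPreservingNoiseConfig through to noise_utils.add_aggregate_noise, as shown in the hunks above.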
diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py
index 9830150..2915f43 100644
--- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py
+++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py
@@ -166,7 +166,7 @@ def sample_true_positive_indices(
           tf.shape(contribution_count_values),
           mean=0.0,
           stddev=noise_multiplier,
-          dtype=tf.float32,
+          dtype=contribution_count_values.dtype,
       )
   )
   noised_contribution_counts_indices = contribution_counts.indices[
@@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
   """
   filtered_grad_values = tf.gather(grad, indices)
   sparse_noise_values = tf.random.normal(
-      filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
+      tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
   )
   filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
   return tf.IndexedSlices(
@@ -362,15 +362,10 @@ def get_contribution_counts(
     if var.name not in varname_to_contribution_counts_fns:
       contribution_counts_list.append(None)
       continue
-    contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
-    if not contribution_counts_fns or not contribution_counts_fns[0]:
+    contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
+    if not contribution_counts_fn:
       contribution_counts_list.append(None)
       continue
-    if len(contribution_counts_fns) > 1:
-      raise NotImplementedError(
-          'Sparse noise is not supported for shared weight variables.'
-      )
-    contribution_counts_fn = contribution_counts_fns[0]
     contribution_counts = contribution_counts_fn(grad)
     contribution_counts_list.append(contribution_counts)
 
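The two tf.shape(...) changes above matter once these ops are traced with a partially unknown static shape. A minimal standalone sketch (toy shapes and names, not library code) of the failure mode they avoid:

    import tensorflow as tf

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 4], dtype=tf.float32),  # gradient rows, count unknown at trace time
        tf.TensorSpec(shape=[None], dtype=tf.int64),       # selected row indices
    ])
    def noised_rows(grad, indices):
      rows = tf.gather(grad, indices)
      # rows.shape is (None, 4) at trace time, so tf.random.normal(rows.shape)
      # would fail; tf.shape(rows) supplies the runtime shape instead.
      noise = tf.random.normal(tf.shape(rows), mean=0.0, stddev=1.0)
      return rows + noise

The dtype change in sample_true_positive_indices follows the same spirit: the Gaussian noise now inherits the dtype of the contribution counts instead of assuming float32.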
diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py
index 35b26ad..b9a0229 100644
--- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py
+++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py
@@ -369,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         tf.ones((1, 2)),
     ]
     varname_to_contribution_counts_fns = {
-        'var1:0': [lambda grad: 1.0],
+        'var1:0': lambda grad: 1.0,
         'var2:0': None,
     }
     contribution_counts = sparse_noise_utils.get_contribution_counts(
diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py
index 48283b3..03f9550 100644
--- a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py
+++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py
@@ -19,7 +19,7 @@ import tensorflow as tf
 
 InputArgs = Sequence[Any]
 InputKwargs = Mapping[str, Any]
-SparseGradient = tf.IndexedSlices
+SparseGradient = tf.IndexedSlices | tf.SparseTensor
 ContributionCountHistogram = tf.SparseTensor
 ContributionCountHistogramFn = Callable[
     [SparseGradient], ContributionCountHistogram
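The widened SparseGradient alias means a ContributionCountHistogramFn may now be handed either a tf.IndexedSlices or a tf.SparseTensor gradient. A small illustrative sketch (the rows_touched helper is hypothetical, not a library function):

    import tensorflow as tf

    def rows_touched(grad):
      # Unique row ids touched by a sparse gradient, for either representation
      # admitted by the SparseGradient alias.
      if isinstance(grad, tf.IndexedSlices):
        return tf.unique(grad.indices).y
      return tf.unique(grad.indices[:, 0]).y

    indexed = tf.IndexedSlices(
        values=tf.ones([2, 4]),
        indices=tf.constant([0, 3], dtype=tf.int64),
        dense_shape=tf.constant([10, 4], dtype=tf.int64),
    )
    sparse = tf.SparseTensor(
        indices=[[0, 0], [3, 1]], values=[1.0, 2.0], dense_shape=[10, 4]
    )
    print(rows_touched(indexed).numpy())  # [0 3]
    print(rows_touched(sparse).numpy())   # [0 3]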