forked from 626_privacy/tensorflow_privacy
Sparsity Preserving DP-SGD in TF Privacy [5 of 5]
Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 666849100
This commit is contained in:
parent
93c7e54327
commit
b3963971e3
6 changed files with 86 additions and 16 deletions
|
@ -21,11 +21,11 @@ import tensorflow as tf
|
||||||
# Tensorflow aliases.
|
# Tensorflow aliases.
|
||||||
Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
|
Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
|
||||||
|
|
||||||
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
|
PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]
|
||||||
|
|
||||||
InputTensors = PackedTensors
|
InputTensors = PackedTensors
|
||||||
|
|
||||||
OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
|
OutputTensors = Union[Tensor, Iterable[Tensor]]
|
||||||
|
|
||||||
BatchSize = Union[int, tf.Tensor]
|
BatchSize = Union[int, tf.Tensor]
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@ py_library(
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
|
||||||
|
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
|
||||||
|
"//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,16 +13,37 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Keras Model for vectorized dpsgd with XLA acceleration."""
|
"""Keras Model for vectorized dpsgd with XLA acceleration."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
|
||||||
|
|
||||||
|
|
||||||
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
|
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SparsityPreservingDPSGDConfig:
|
||||||
|
"""Config for adding sparsity preserving noise to the gradients."""
|
||||||
|
|
||||||
|
# The ratio of how the noise is split between partition selection and gradient
|
||||||
|
# noise.
|
||||||
|
sparse_selection_ratio: float = 0.0
|
||||||
|
# The threshold to use for private partition selection.
|
||||||
|
sparse_selection_threshold: int = 100
|
||||||
|
# A `LayerRegistry` instance containing functions that help compute
|
||||||
|
# contribution counts for sparse layers. See
|
||||||
|
# `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
|
||||||
|
# more details.
|
||||||
|
sparse_selection_layer_registry: snlr.LayerRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
def make_dp_model_class(cls):
|
def make_dp_model_class(cls):
|
||||||
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
|
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
|
||||||
|
|
||||||
|
@ -104,6 +125,9 @@ def make_dp_model_class(cls):
|
||||||
num_microbatches=None,
|
num_microbatches=None,
|
||||||
use_xla=True,
|
use_xla=True,
|
||||||
layer_registry=None,
|
layer_registry=None,
|
||||||
|
sparsity_preserving_dpsgd_config: (
|
||||||
|
SparsityPreservingDPSGDConfig | None
|
||||||
|
) = None,
|
||||||
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
@ -118,6 +142,9 @@ def make_dp_model_class(cls):
|
||||||
help compute gradient norms quickly. See
|
help compute gradient norms quickly. See
|
||||||
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||||
more details.
|
more details.
|
||||||
|
sparsity_preserving_dpsgd_config: If provided, uses partition selection
|
||||||
|
and sparse noise for privatizing sparse gradients for layers in
|
||||||
|
`sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
|
||||||
*args: These will be passed on to the base class `__init__` method.
|
*args: These will be passed on to the base class `__init__` method.
|
||||||
**kwargs: These will be passed on to the base class `__init__` method.
|
**kwargs: These will be passed on to the base class `__init__` method.
|
||||||
"""
|
"""
|
||||||
|
@ -127,6 +154,8 @@ def make_dp_model_class(cls):
|
||||||
self._layer_registry = layer_registry
|
self._layer_registry = layer_registry
|
||||||
self._clipping_loss = None
|
self._clipping_loss = None
|
||||||
|
|
||||||
|
self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
|
||||||
|
|
||||||
# Given that `num_microbatches` was added as an argument after the fact,
|
# Given that `num_microbatches` was added as an argument after the fact,
|
||||||
# this check helps detect unintended calls to the earlier API.
|
# this check helps detect unintended calls to the earlier API.
|
||||||
# In particular, boolean values supplied to `use_xla` in the earlier API
|
# In particular, boolean values supplied to `use_xla` in the earlier API
|
||||||
|
@ -276,11 +305,16 @@ def make_dp_model_class(cls):
|
||||||
# microbatches is done here.
|
# microbatches is done here.
|
||||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||||
|
|
||||||
|
sparse_noise_layer_registry = None
|
||||||
|
if self._sparsity_preserving_dpsgd_config is not None:
|
||||||
|
sparse_noise_layer_registry = (
|
||||||
|
self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
|
||||||
|
)
|
||||||
registry_generator_fn = (
|
registry_generator_fn = (
|
||||||
gradient_clipping_utils.get_registry_generator_fn(
|
gradient_clipping_utils.get_registry_generator_fn(
|
||||||
tape=tape,
|
tape=tape,
|
||||||
layer_registry=self._layer_registry,
|
layer_registry=self._layer_registry,
|
||||||
sparse_noise_layer_registry=None,
|
sparse_noise_layer_registry=sparse_noise_layer_registry,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -310,14 +344,53 @@ def make_dp_model_class(cls):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
|
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
|
||||||
if self._noise_multiplier > 0:
|
noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
|
||||||
|
contribution_counts = None
|
||||||
|
if self._sparsity_preserving_dpsgd_config is not None:
|
||||||
|
logging.info('Using sparse noise.')
|
||||||
|
|
||||||
|
varname_to_contribution_counts_fns = (
|
||||||
|
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
|
||||||
|
registry_fn_outputs_list,
|
||||||
|
self.trainable_variables,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
contribution_counts = sparse_noise_utils.get_contribution_counts(
|
||||||
|
self.trainable_variables,
|
||||||
|
clipped_grads,
|
||||||
|
varname_to_contribution_counts_fns,
|
||||||
|
)
|
||||||
|
|
||||||
|
noise_multiplier_sparse, noise_multiplier = (
|
||||||
|
sparse_noise_utils.split_noise_multiplier(
|
||||||
|
noise_multiplier,
|
||||||
|
self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
|
||||||
|
contribution_counts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
'Split noise multiplier for gradient noise: %s and partition'
|
||||||
|
' selection: %s',
|
||||||
|
noise_multiplier,
|
||||||
|
noise_multiplier_sparse,
|
||||||
|
)
|
||||||
|
|
||||||
|
if noise_multiplier > 0:
|
||||||
|
sparse_noise_config = None
|
||||||
|
if self._sparsity_preserving_dpsgd_config is not None:
|
||||||
|
sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
|
||||||
|
sparse_noise_multiplier=noise_multiplier_sparse,
|
||||||
|
sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
|
||||||
|
sparse_contribution_counts=contribution_counts,
|
||||||
|
)
|
||||||
grads = noise_utils.add_aggregate_noise(
|
grads = noise_utils.add_aggregate_noise(
|
||||||
clipped_grads,
|
clipped_grads,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
self._l2_norm_clip,
|
self._l2_norm_clip,
|
||||||
self._noise_multiplier,
|
noise_multiplier,
|
||||||
loss_reduction=None,
|
loss_reduction=None,
|
||||||
loss_model=self,
|
loss_model=self,
|
||||||
|
sparse_noise_config=sparse_noise_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
grads = clipped_grads
|
grads = clipped_grads
|
||||||
|
|
|
@ -166,7 +166,7 @@ def sample_true_positive_indices(
|
||||||
tf.shape(contribution_count_values),
|
tf.shape(contribution_count_values),
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
stddev=noise_multiplier,
|
stddev=noise_multiplier,
|
||||||
dtype=tf.float32,
|
dtype=contribution_count_values.dtype,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
noised_contribution_counts_indices = contribution_counts.indices[
|
noised_contribution_counts_indices = contribution_counts.indices[
|
||||||
|
@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
|
||||||
"""
|
"""
|
||||||
filtered_grad_values = tf.gather(grad, indices)
|
filtered_grad_values = tf.gather(grad, indices)
|
||||||
sparse_noise_values = tf.random.normal(
|
sparse_noise_values = tf.random.normal(
|
||||||
filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
|
tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
|
||||||
)
|
)
|
||||||
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
|
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
|
||||||
return tf.IndexedSlices(
|
return tf.IndexedSlices(
|
||||||
|
@ -362,15 +362,10 @@ def get_contribution_counts(
|
||||||
if var.name not in varname_to_contribution_counts_fns:
|
if var.name not in varname_to_contribution_counts_fns:
|
||||||
contribution_counts_list.append(None)
|
contribution_counts_list.append(None)
|
||||||
continue
|
continue
|
||||||
contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
|
contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
|
||||||
if not contribution_counts_fns or not contribution_counts_fns[0]:
|
if not contribution_counts_fn:
|
||||||
contribution_counts_list.append(None)
|
contribution_counts_list.append(None)
|
||||||
continue
|
continue
|
||||||
if len(contribution_counts_fns) > 1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
'Sparse noise is not supported for shared weight variables.'
|
|
||||||
)
|
|
||||||
contribution_counts_fn = contribution_counts_fns[0]
|
|
||||||
contribution_counts = contribution_counts_fn(grad)
|
contribution_counts = contribution_counts_fn(grad)
|
||||||
contribution_counts_list.append(contribution_counts)
|
contribution_counts_list.append(contribution_counts)
|
||||||
|
|
||||||
|
|
|
@ -369,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
tf.ones((1, 2)),
|
tf.ones((1, 2)),
|
||||||
]
|
]
|
||||||
varname_to_contribution_counts_fns = {
|
varname_to_contribution_counts_fns = {
|
||||||
'var1:0': [lambda grad: 1.0],
|
'var1:0': lambda grad: 1.0,
|
||||||
'var2:0': None,
|
'var2:0': None,
|
||||||
}
|
}
|
||||||
contribution_counts = sparse_noise_utils.get_contribution_counts(
|
contribution_counts = sparse_noise_utils.get_contribution_counts(
|
||||||
|
|
|
@ -19,7 +19,7 @@ import tensorflow as tf
|
||||||
|
|
||||||
InputArgs = Sequence[Any]
|
InputArgs = Sequence[Any]
|
||||||
InputKwargs = Mapping[str, Any]
|
InputKwargs = Mapping[str, Any]
|
||||||
SparseGradient = tf.IndexedSlices
|
SparseGradient = tf.IndexedSlices | tf.SparseTensor
|
||||||
ContributionCountHistogram = tf.SparseTensor
|
ContributionCountHistogram = tf.SparseTensor
|
||||||
ContributionCountHistogramFn = Callable[
|
ContributionCountHistogramFn = Callable[
|
||||||
[SparseGradient], ContributionCountHistogram
|
[SparseGradient], ContributionCountHistogram
|
||||||
|
|
Loading…
Reference in a new issue