Compare commits

...

10 commits

Author SHA1 Message Date
A. Unique TensorFlower
d965556ebb Disable MLIR bridge for the test points that MLIR bridge silently fails
PiperOrigin-RevId: 676660290
2024-09-19 19:51:40 -07:00
Galen Andrew
e8856835a6 Internal change
PiperOrigin-RevId: 671064874
2024-09-04 12:42:12 -07:00
A. Unique TensorFlower
e3fd3afdf8 Clarify documentation of labels_train/test usage wrt loss_train/test.
PiperOrigin-RevId: 670641843
2024-09-03 11:39:27 -07:00
William Kong
66d05a22a3 Fix a gradient clipping bug for layer normalization layers with microbatch axes.
The previous code passed the unstacked gradients (a list) instead of the stacked gradients (a tensor) to the microbatcher, which led to unexpected behavior. This change passes the right argument and changes the original unit test to catch this bug.

PiperOrigin-RevId: 669413064
2024-08-30 12:41:11 -07:00
A. Unique TensorFlower
b3963971e3 Sparsity Preserving DP-SGD in TF Privacy [5 of 5]
Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 666849100
2024-08-23 10:46:12 -07:00
A. Unique TensorFlower
93c7e54327 Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 664906202
2024-08-19 11:44:43 -07:00
A. Unique TensorFlower
38d80cae92 Automated Code Change
PiperOrigin-RevId: 662904771
2024-08-14 07:03:41 -07:00
A. Unique TensorFlower
bf6cf4dec9 Sparsity Preserving DP-SGD in TF Privacy
Add support for calculating contribution counts to registry function for sparsity preserving noise.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 662162597
2024-08-12 11:22:06 -07:00
A. Unique TensorFlower
e42b574465 Sparsity Preserving DP-SGD in TF Privacy
Add support for adding sparsity preserving noise in add_aggregate_noise

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 662148309
2024-08-12 10:45:28 -07:00
A. Unique TensorFlower
09c68750d7 Sparsity Preserving DP-SGD in TF Privacy
Refactor model_forward_backward_pass out of compute_gradients to allow for other optimizations such as sparsity preserving noise to integrate with it.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660924829
2024-08-08 11:52:02 -07:00
20 changed files with 855 additions and 203 deletions

View file

@ -37,7 +37,7 @@ setuptools.setup(
install_requires=[ install_requires=[
'absl-py>=1.0,==1.*', 'absl-py>=1.0,==1.*',
'dm-tree==0.1.8', 'dm-tree==0.1.8',
'dp-accounting==0.4.4', 'dp-accounting==0.4.4', # TODO(b/364653784)
'numpy~=1.21', 'numpy~=1.21',
'packaging~=22.0', 'packaging~=22.0',
'scikit-learn>=1.0,==1.*', 'scikit-learn>=1.0,==1.*',

View file

@ -43,8 +43,11 @@ py_library(
srcs = ["gradient_clipping_utils.py"], srcs = ["gradient_clipping_utils.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":common_manip_utils",
":layer_registry", ":layer_registry",
":type_aliases", ":type_aliases",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
], ],
) )
@ -54,7 +57,11 @@ py_test(
python_version = "PY3", python_version = "PY3",
shard_count = 8, shard_count = 8,
srcs_version = "PY3", srcs_version = "PY3",
deps = [":gradient_clipping_utils"], deps = [
":gradient_clipping_utils",
":layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
],
) )
py_library( py_library(
@ -83,6 +90,7 @@ py_library(
py_library( py_library(
name = "noise_utils", name = "noise_utils",
srcs = ["noise_utils.py"], srcs = ["noise_utils.py"],
deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
) )
py_test( py_test(
@ -94,6 +102,7 @@ py_test(
deps = [ deps = [
":clip_grads", ":clip_grads",
":common_test_utils", ":common_test_utils",
":gradient_clipping_utils",
":layer_registry", ":layer_registry",
":type_aliases", ":type_aliases",
], ],

View file

@ -22,7 +22,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
""" """
import collections import collections
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import Optional from typing import Optional
import tensorflow as tf import tensorflow as tf
@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def _infer_per_example_loss_fn(model: tf.keras.Model): def _compute_gradient_norms_internal(
"""Infer the per-example loss from model config.""" registry_fn_outputs_list: Sequence[
gradient_clipping_utils.RegistryGeneratorFunctionOutput
],
layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
trainable_vars: Optional[Sequence[tf.Variable]] = None,
) -> tf.Tensor:
"""Computes the per-example loss gradient norms for given data.
def _convert(loss_fn): Args:
loss_config = loss_fn.get_config() registry_fn_outputs_list: A sequence of RegistryGeneratorFunctionOutput
loss_config['reduction'] = tf.keras.losses.Reduction.NONE containing information required to compute the gradient norms and
return loss_fn.from_config(loss_config) contribution counts. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
layer_grad_vars: A mapping of layer id to a list of gradients for each
trainable variable in the layer. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
trainable_vars: The list of variables included in computing the gradient
norm. When a layer has multiple variables, we include all the variables if
any of the variables is in the list. If `trainable_vars` is None, all the
variables are included.
model_loss = model.loss Returns:
if isinstance(model_loss, tf.keras.losses.Loss): A scalar vector, whose i-th entry is the norm of the gradient of the i-th
return _convert(model_loss) weighted example loss (when num_microbatches is None) or the norm of the
elif isinstance(model_loss, dict): gradient of the i-th microbatch loss (defined as a mean over the microbatch).
# Note that we cannot call the public method `.get_compile_config()` because Note that when the loss is weighted (`weight_batch` is not None), weights
# it calls a numpy function, which is not supported inside a `tf.function` are applied prior to clipping.
# wrapped function.
compile_config = model._compile_config.config # pylint: disable=protected-access Raises:
if compile_config is None: ValueError: If `layer_grad_vars` is empty.
raise ValueError('Model must be compiled for loss function conversion') ValueError: If the number of gradients for a layer is not equal to the
# Does a weighted mean of the configured losses. Note that we cannot build number of squared norm functions for that layer.
# from the config of the compiled loss because (i) it builds a """
# `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s if trainable_vars is not None:
# during its construction, (ii) non-unique `tf.Variables` cannot be used # Create a set using `ref()` for fast set membership check. tf.Variable
# inside a `tf.function`, which is usually where this function is used. # itself is not hashable.
if 'loss_weights' not in compile_config: trainable_vars = set([v.ref() for v in trainable_vars])
raise ValueError(
'Models with multiple loss must have corresponding loss weights for' layer_sqr_norm_fns = collections.defaultdict(list)
' loss function conversion' # The case of shared weights:
# If a layer is called k times, it will appear k times in filtered_outputs,
# with the same id, but potentially with different v and f. The code below
# groups filtered_outputs by layer_id, so we can correctly compute gradient
# norms. The gradient norm of a layer that occurs k times is computed as
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights.
for registry_fn_output in registry_fn_outputs_list:
if trainable_vars is None or any(
w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
):
layer_sqr_norm_fns[registry_fn_output.layer_id].append(
registry_fn_output.layer_sqr_norm_fn
) )
weights = compile_config['loss_weights']
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
num_losses = len(weights)
def _per_example_loss_fn(y_true, y_pred, sample_weight=None): if not layer_grad_vars:
loss_values = [] raise ValueError('The gradient list cannot be empty.')
if model_loss.keys() - y_pred.keys(): sqr_norm_list = []
raise ValueError( for layer_id in layer_sqr_norm_fns.keys():
'y_pred must contain the same keys and the model losses, but ' fns = layer_sqr_norm_fns[layer_id]
'got %s and %s' % (y_pred.keys(), model_loss.keys()) grads = layer_grad_vars[layer_id]
) # Number of duplicates for this layer in `filtered_outputs`.
if model_loss.keys() - y_true.keys(): num_passes = len(fns)
raise ValueError( if len(fns) != len(grads):
'y_true must contain the same keys and the model losses, but ' raise ValueError(
'got %s and %s' % (y_true.keys(), model_loss.keys()) 'There must be as many gradients as squared norm functions.'
) )
if sample_weight is not None: # See go/dp-sgd-shared-weights for more details.
if model_loss.keys() - sample_weight.keys(): for fn, grad in zip(fns, grads):
raise ValueError( sqr_norm_list.append(num_passes * fn(grad))
'sample_weight must contain the same keys and the model losses,' sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
' but got %s and %s' % (y_true.keys(), model_loss.keys()) gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
) return gradient_norms
for k in y_true.keys():
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
sgl_value = (
weights[k]
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
/ num_losses
)
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
return tf.math.add_n(loss_values)
return _per_example_loss_fn
else:
raise ValueError(
'Unsupported type for loss function conversion: {}'.format(
type(model_loss)
)
)
def compute_gradient_norms( def compute_gradient_norms(
@ -110,7 +118,7 @@ def compute_gradient_norms(
per_example_loss_fn: Optional[type_aliases.LossFn] = None, per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None, num_microbatches: Optional[type_aliases.BatchSize] = None,
trainable_vars: Optional[Sequence[tf.Variable]] = None, trainable_vars: Optional[Sequence[tf.Variable]] = None,
): ) -> tf.Tensor:
"""Computes the per-example loss gradient norms for given data. """Computes the per-example loss gradient norms for given data.
Applies a variant of the approach given in Applies a variant of the approach given in
@ -154,90 +162,28 @@ def compute_gradient_norms(
""" """
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape, layer_registry, num_microbatches tape=tape,
layer_registry=layer_registry,
sparse_noise_layer_registry=None,
num_microbatches=num_microbatches,
) )
# First loop computes the model outputs, summed loss, and generator outputs. layer_grad_vars, generator_outputs_list = (
with tape: gradient_clipping_utils.model_forward_backward_pass(
model_outputs, generator_outputs_list = ( tape=tape,
gradient_clipping_utils.model_forward_pass( input_model=input_model,
input_model, x_batch, generator_fn=registry_generator_fn x_batch=x_batch,
) y_batch=y_batch,
) registry_generator_fn=registry_generator_fn,
weight_batch=weight_batch,
# Ignore the original loss function's reduction to get per-example loss. per_example_loss_fn=per_example_loss_fn,
if per_example_loss_fn is None: num_microbatches=num_microbatches,
per_example_loss_fn = _infer_per_example_loss_fn(input_model)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
) )
if len(losses.shape) != 1:
raise NotImplementedError(
'The unreduced (or per-example) loss needs to have a shape of length '
'one, but received an unreduced loss of shape length %s'
% len(losses.shape)
)
if num_microbatches is not None:
losses = tf.reduce_mean(
common_manip_utils.maybe_add_microbatch_axis(
losses, num_microbatches
),
axis=1,
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None]
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])
layer_vars = collections.defaultdict(list)
layer_sqr_norm_fns = collections.defaultdict(list)
# The case of shared weights:
# If a layer is called k times, it will appear k times in filtered_outputs,
# with the same id, but potentially with different v and f. The code below
# groups filtered_outputs by layer_id, so we can correctly compute gradient
# norms. The gradient norm of a layer that occurs k times is computed as
# $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
# occurrence. This is an over-estimate of the actual norm. For more details,
# see the explanation in go/dp-sgd-shared-weights.
for registry_fn_output in filtered_outputs:
if trainable_vars is None or any(
w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
):
layer_vars[registry_fn_output.layer_id].append(
registry_fn_output.layer_vars
)
layer_sqr_norm_fns[registry_fn_output.layer_id].append(
registry_fn_output.layer_sqr_norm_fn
)
# Second loop evaluates the squared L2 norm functions and appends the results.
layer_grad_vars = tape.gradient(
summed_loss,
layer_vars,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
) )
if not layer_grad_vars: return _compute_gradient_norms_internal(
raise ValueError('The gradient list cannot be empty.') registry_fn_outputs_list=generator_outputs_list,
sqr_norm_list = [] layer_grad_vars=layer_grad_vars,
for layer_id in layer_sqr_norm_fns.keys(): trainable_vars=trainable_vars,
fns = layer_sqr_norm_fns[layer_id] )
grads = layer_grad_vars[layer_id]
# Number of duplicates for this layer in `filtered_outputs`.
num_passes = len(fns)
if len(fns) != len(grads):
raise ValueError(
'There must be as many gradients as squared norm functions.'
)
# See go/dp-sgd-shared-weights for more details.
for fn, grad in zip(fns, grads):
sqr_norm_list.append(num_passes * fn(grad))
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
@ -267,14 +213,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
def compute_clipped_gradients_and_outputs( def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model, input_model: tf.keras.Model,
registry_fn_outputs_list: Sequence[
gradient_clipping_utils.RegistryGeneratorFunctionOutput
],
layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
l2_norm_clip: float, l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors, x_batch: type_aliases.InputTensors,
y_batch: type_aliases.OutputTensors, y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None, weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None, num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None, clipping_loss: Optional[type_aliases.LossFn] = None,
) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]: ) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs. """Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@ -287,15 +236,16 @@ def compute_clipped_gradients_and_outputs(
Args: Args:
input_model: The `tf.keras.Model` from which to obtain the layers from. input_model: The `tf.keras.Model` from which to obtain the layers from.
registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
containing information required to compute the gradient norms and
contribution counts. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
layer_grad_vars: A mapping of layer id to a list of gradients for each
trainable variable in the layer. Output from
`gradient_clipping_utils.model_forward_backward_pass()`.
l2_norm_clip: A `float` indicating the norm to which per-example gradients l2_norm_clip: A `float` indicating the norm to which per-example gradients
will be clipped. That is, all gradients of the per-example loss functions will be clipped. That is, all gradients of the per-example loss functions
will have norm at most `l2_norm_clip`. will have norm at most `l2_norm_clip`.
layer_registry: A `dict` of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axes of each tensor must be the batch dimension. first axes of each tensor must be the batch dimension.
y_batch: An `OutputTensor` representing a batch of output labels. The first y_batch: An `OutputTensor` representing a batch of output labels. The first
@ -330,13 +280,9 @@ def compute_clipped_gradients_and_outputs(
) )
if clipping_loss is None: if clipping_loss is None:
clipping_loss = input_model.compiled_loss clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms( gradient_norms = _compute_gradient_norms_internal(
input_model, registry_fn_outputs_list=registry_fn_outputs_list,
layer_registry, layer_grad_vars=layer_grad_vars,
x_batch,
y_batch,
weight_batch,
num_microbatches=num_microbatches,
trainable_vars=input_model.trainable_variables, trainable_vars=input_model.trainable_variables,
) )
clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms) clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
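
A short aside on the shared-weights comment in the new `_compute_gradient_norms_internal` above, which takes the norm contribution of a layer that occurs k times to be sqrt(k * sum_i c_i^2). Assuming g_i is the gradient contribution of the layer's i-th occurrence and the registry estimate satisfies c_i >= ||g_i|| (with equality for exact registry functions), the triangle inequality followed by Cauchy-Schwarz gives

    \left\|\sum_{i=1}^{k} g_i\right\| \le \sum_{i=1}^{k} \|g_i\| \le \sqrt{k \sum_{i=1}^{k} \|g_i\|^{2}} \le \sqrt{k \sum_{i=1}^{k} c_i^{2}},

so the per-layer sum of the `num_passes * fn(grad)` terms accumulated in `sqr_norm_list` over-estimates the true squared norm, and the resulting clipping stays conservative (privacy is preserved at some cost in utility).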

View file

@ -19,6 +19,7 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -122,6 +123,30 @@ class CustomLayerTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
def _run_model_forward_backward_pass(
model: tf.keras.Model,
x_batch: type_aliases.InputTensors,
y_batch: type_aliases.OutputTensors,
):
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=layer_registry.make_default_layer_registry(),
sparse_noise_layer_registry=None,
num_microbatches=None,
)
layer_grad_vars, registry_fn_outputs_list = (
gradient_clipping_utils.model_forward_backward_pass(
tape=tape,
input_model=model,
x_batch=x_batch,
y_batch=y_batch,
registry_generator_fn=registry_generator_fn,
)
)
return layer_grad_vars, registry_fn_outputs_list
class ComputeClippedGradsAndOutputsTest( class ComputeClippedGradsAndOutputsTest(
tf.test.TestCase, parameterized.TestCase tf.test.TestCase, parameterized.TestCase
): ):
@ -153,13 +178,17 @@ class ComputeClippedGradsAndOutputsTest(
y_batch = tf.reshape( y_batch = tf.reshape(
1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
) )
layer_grad_vars, registry_fn_outputs_list = (
_run_model_forward_backward_pass(self._model, x_batch, y_batch)
)
# Stop early for efficiency. # Stop early for efficiency.
if reduction == 'none': if reduction == 'none':
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
clip_grads.compute_clipped_gradients_and_outputs( clip_grads.compute_clipped_gradients_and_outputs(
self._model, self._model,
registry_fn_outputs_list,
layer_grad_vars,
l2_norm_clip, l2_norm_clip,
layer_registry.make_default_layer_registry(),
x_batch, x_batch,
y_batch, y_batch,
) )
@ -169,10 +198,12 @@ class ComputeClippedGradsAndOutputsTest(
y_pred = self._model(x_batch) y_pred = self._model(x_batch)
loss_value = loss_fn(y_pred, y_batch) loss_value = loss_fn(y_pred, y_batch)
true_grads = tape.gradient(loss_value, self._model.trainable_variables) true_grads = tape.gradient(loss_value, self._model.trainable_variables)
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
self._model, self._model,
registry_fn_outputs_list,
layer_grad_vars,
l2_norm_clip, l2_norm_clip,
layer_registry.make_default_layer_registry(),
x_batch, x_batch,
y_batch, y_batch,
) )

View file

@ -13,13 +13,17 @@
# limitations under the License. # limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms.""" """Utility functions that help in the computation of per-example gradient norms."""
import collections
from collections.abc import Callable, Sequence, Set from collections.abc import Callable, Sequence, Set
import dataclasses import dataclasses
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -27,6 +31,9 @@ class RegistryGeneratorFunctionOutput:
layer_id: str layer_id: str
layer_vars: Optional[Sequence[tf.Variable]] layer_vars: Optional[Sequence[tf.Variable]]
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction] layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
varname_to_count_contribution_fn: Optional[
dict[str, sn_type_aliases.ContributionCountHistogramFn]
]
layer_trainable_weights: Optional[Sequence[tf.Variable]] layer_trainable_weights: Optional[Sequence[tf.Variable]]
@ -44,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
def get_registry_generator_fn( def get_registry_generator_fn(
tape: tf.GradientTape, tape: tf.GradientTape,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
sparse_noise_layer_registry: snlr.LayerRegistry,
num_microbatches: Optional[type_aliases.BatchSize] = None, num_microbatches: Optional[type_aliases.BatchSize] = None,
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]: ) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
"""Creates the generator function for `model_forward_backward_pass()`. """Creates the generator function for `model_forward_backward_pass()`.
@ -56,6 +64,10 @@ def get_registry_generator_fn(
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are relevant squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable trainable
sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
that help compute contribution counts for sparse noise. See
`tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
more details.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple num_microbatches (in this case, the batch dimension needs to be a multiple
@ -81,6 +93,16 @@ def get_registry_generator_fn(
'be used for efficient gradient clipping.' 'be used for efficient gradient clipping.'
% layer_instance.__class__.__name__ % layer_instance.__class__.__name__
) )
varname_to_count_contribution_fn = None
if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
layer_instance
):
count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
layer_instance
)
varname_to_count_contribution_fn = count_contribution_registry_fn(
layer_instance, args, kwargs, num_microbatches
)
registry_fn = layer_registry.lookup(layer_instance) registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, kwargs, tape, num_microbatches layer_instance, args, kwargs, tape, num_microbatches
@ -89,6 +111,7 @@ def get_registry_generator_fn(
layer_id=str(id(layer_instance)), layer_id=str(id(layer_instance)),
layer_vars=layer_vars, layer_vars=layer_vars,
layer_sqr_norm_fn=layer_sqr_norm_fn, layer_sqr_norm_fn=layer_sqr_norm_fn,
varname_to_count_contribution_fn=varname_to_count_contribution_fn,
layer_trainable_weights=layer_instance.trainable_weights, layer_trainable_weights=layer_instance.trainable_weights,
) )
else: else:
@ -98,6 +121,149 @@ def get_registry_generator_fn(
return registry_generator_fn return registry_generator_fn
def _infer_per_example_loss_fn(model: tf.keras.Model):
"""Infer the per-example loss from model config."""
def _convert(loss_fn):
loss_config = loss_fn.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
return loss_fn.from_config(loss_config)
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return _convert(model_loss)
elif isinstance(model_loss, dict):
# Note that we cannot call the public method `.get_compile_config()` because
# it calls a numpy function, which is not supported inside a `tf.function`
# wrapped function.
compile_config = model._compile_config.config # pylint: disable=protected-access
if compile_config is None:
raise ValueError('Model must be compiled for loss function conversion')
# Does a weighted mean of the configured losses. Note that we cannot build
# from the config of the compiled loss because (i) it builds a
# `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
# during its construction, (ii) non-unique `tf.Variables` cannot be used
# inside a `tf.function`, which is usually where this function is used.
if 'loss_weights' not in compile_config:
raise ValueError(
'Models with multiple loss must have corresponding loss weights for'
' loss function conversion'
)
weights = compile_config['loss_weights']
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
num_losses = len(weights)
def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
loss_values = []
if model_loss.keys() - y_pred.keys():
raise ValueError(
'y_pred must contain the same keys and the model losses, but '
'got %s and %s' % (y_pred.keys(), model_loss.keys())
)
if model_loss.keys() - y_true.keys():
raise ValueError(
'y_true must contain the same keys and the model losses, but '
'got %s and %s' % (y_true.keys(), model_loss.keys())
)
if sample_weight is not None:
if model_loss.keys() - sample_weight.keys():
raise ValueError(
'sample_weight must contain the same keys and the model losses,'
' but got %s and %s' % (y_true.keys(), model_loss.keys())
)
for k in y_true.keys():
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
sgl_value = (
weights[k]
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
/ num_losses
)
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
return tf.math.add_n(loss_values)
return _per_example_loss_fn
else:
raise ValueError(
'Unsupported type for loss function conversion: {}'.format(
type(model_loss)
)
)
def model_forward_backward_pass(
tape: tf.GradientTape,
input_model: tf.keras.Model,
x_batch: type_aliases.InputTensors,
y_batch: type_aliases.OutputTensors,
registry_generator_fn: Optional[
Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]
],
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
trainable_vars: Optional[Sequence[tf.Variable]] = None,
) -> tuple[
dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput]
]:
"""Does a forward and backward pass of a model and returns useful intermediates."""
# First loop computes the model outputs, summed loss, and generator outputs.
with tape:
model_outputs, generator_outputs_list = model_forward_pass(
input_model, x_batch, generator_fn=registry_generator_fn
)
# Ignore the original loss function's reduction to get per-example loss.
if per_example_loss_fn is None:
per_example_loss_fn = _infer_per_example_loss_fn(input_model)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
"The unreduced (or per-example) loss's shape cannot be `None`"
)
if len(losses.shape) != 1:
raise NotImplementedError(
'The unreduced (or per-example) loss needs to have a shape of length '
'one, but received an unreduced loss of shape length %s'
% len(losses.shape)
)
if num_microbatches is not None:
losses = tf.reduce_mean(
common_manip_utils.maybe_add_microbatch_axis(
losses, num_microbatches
),
axis=1,
)
summed_loss = tf.reduce_sum(losses)
# Unwrap the generator outputs so that the next loop avoids duplicating
# backprop ops.
filtered_outputs = [t for t in generator_outputs_list if t is not None]
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])
layer_vars = collections.defaultdict(list)
for registry_fn_output in filtered_outputs:
if trainable_vars is None or any(
w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
):
layer_vars[registry_fn_output.layer_id].append(
registry_fn_output.layer_vars
)
layer_grad_vars = tape.gradient(
summed_loss,
layer_vars,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
if not layer_grad_vars:
raise ValueError('The gradient list cannot be empty.')
return layer_grad_vars, filtered_outputs
def model_forward_pass( def model_forward_pass(
input_model: tf.keras.Model, input_model: tf.keras.Model,
inputs: type_aliases.PackedTensors, inputs: type_aliases.PackedTensors,

View file

@ -17,6 +17,8 @@ from typing import Any
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
# ============================================================================== # ==============================================================================
@ -175,5 +177,92 @@ class GenerateOutputsUsingCoreKerasLayers(
) )
class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
def _get_sparse_layer_registry(self):
def count_contribution_fn(_):
return None
def registry_fn(*_):
return {'var': count_contribution_fn}
registry = snlr.LayerRegistry()
registry.insert(tf.keras.layers.Embedding, registry_fn)
return registry, count_contribution_fn
def _get_layer_registry(self):
var = tf.Variable(1.0)
output = tf.ones((1, 1))
def sqr_norm_fn(_):
return None
def registry_fn(*_):
return [var], output, sqr_norm_fn
registry = lr.LayerRegistry()
registry.insert(tf.keras.layers.Embedding, registry_fn)
registry.insert(tf.keras.layers.Dense, registry_fn)
return registry, var, output, sqr_norm_fn
def test_registry_generator_fn(self):
inputs = tf.constant([[0, 1]])
model = tf.keras.Sequential([
tf.keras.layers.Embedding(10, 1),
tf.keras.layers.Dense(1),
])
sparse_layer_registry, count_contribution_fn = (
self._get_sparse_layer_registry()
)
layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tf.GradientTape(),
layer_registry=layer_registry,
sparse_noise_layer_registry=sparse_layer_registry,
num_microbatches=None,
)
embedding_layer = model.layers[0]
out, embedding_registry_generator_fn_output = registry_generator_fn(
embedding_layer,
[inputs],
{},
)
expected_embedding_registry_generator_fn_output = (
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id=str(id(embedding_layer)),
layer_vars=[var],
layer_sqr_norm_fn=sqr_norm_fn,
varname_to_count_contribution_fn={'var': count_contribution_fn},
layer_trainable_weights=embedding_layer.trainable_weights,
)
)
self.assertEqual(
embedding_registry_generator_fn_output,
expected_embedding_registry_generator_fn_output,
)
self.assertEqual(out, output)
dense_layer = model.layers[1]
out, dense_registry_generator_fn_output = registry_generator_fn(
dense_layer,
[inputs],
{},
)
expected_dense_registry_generator_fn_output = (
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id=str(id(dense_layer)),
layer_vars=[var],
layer_sqr_norm_fn=sqr_norm_fn,
varname_to_count_contribution_fn=None,
layer_trainable_weights=dense_layer.trainable_weights,
)
)
self.assertEqual(
dense_registry_generator_fn_output,
expected_dense_registry_generator_fn_output,
)
self.assertEqual(out, output)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

View file

@ -14,10 +14,21 @@
"""Utility functions that help in adding noise to gradients.""" """Utility functions that help in adding noise to gradients."""
from collections.abc import Sequence from collections.abc import Sequence
import dataclasses
from typing import Literal, Optional from typing import Literal, Optional
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
@dataclasses.dataclass
class SparsityPreservingNoiseConfig:
"""Configuration for adding noise to gradients."""
sparse_noise_multiplier: float = 0.0
sparse_selection_threshold: int = 0
sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None
def _infer_loss_reduction_type(model: tf.keras.Model): def _infer_loss_reduction_type(model: tf.keras.Model):
@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
) )
def _add_dense_aggregate_noise(
grad: tf.Tensor,
noise_multiplier: float,
sensitivity: float,
) -> tf.Tensor:
"""Adds dense noise to a dense gradient."""
return grad + tf.random.normal(
tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
)
def _add_sparse_aggregate_noise(
grad: tf.IndexedSlices,
contribution_counts: tf.SparseTensor,
noise_multiplier: float,
noise_multiplier_sparse: float,
sensitivity: float,
sparse_selection_threshold: int,
) -> tf.IndexedSlices:
"""Adds sparse noise to a sparse gradient."""
return sparse_noise_utils.add_sparse_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=noise_multiplier_sparse,
l2_norm_clip=sensitivity,
threshold=sparse_selection_threshold,
)
def add_aggregate_noise( def add_aggregate_noise(
clipped_grads: list[tf.Tensor], clipped_grads: list[tf.Tensor | tf.IndexedSlices],
batch_size: tf.Tensor, batch_size: tf.Tensor,
l2_norm_clip: float, l2_norm_clip: float,
noise_multiplier: float, noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None, loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None, loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]: sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
) -> Sequence[tf.Tensor | tf.IndexedSlices]:
"""Adds noise to a collection of clipped gradients. """Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the The magnitude of the noise depends on the aggregation strategy of the
input model's loss function. input model's loss function.
Args: Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'. `loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient). l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@ -68,11 +111,14 @@ def add_aggregate_noise(
aggregation type must be inferred from `input_model.loss`. aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`. strategy from if `loss_reduction` is `None`.
sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
the configuration for adding sparse noise. If None, all noise added is
dense.
Returns: Returns:
A list of tensors containing the clipped gradients, but with the right A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction amount of Gaussian or sparse Gaussian noise added to them (depending on
strategy of the loss function). the reduction strategy of the loss function).
Raises: Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@ -103,13 +149,36 @@ def add_aggregate_noise(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.' 'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
) )
if sparse_noise_config is None:
sparse_contribution_counts = tf.nest.map_structure(
lambda x: None, clipped_grads
)
else:
sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts
scale = l2_norm_clip scale = l2_norm_clip
if loss_reduction == 'mean': if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32) scale /= tf.cast(batch_size, tf.float32)
def add_noise(g): def add_noise(grad, contribution_counts):
return g + tf.random.normal( if (
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale sparse_noise_config is not None
) and isinstance(grad, tf.IndexedSlices)
and contribution_counts is not None
):
return _add_sparse_aggregate_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
sensitivity=scale,
sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
)
else:
return _add_dense_aggregate_noise(
grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
)
return tf.nest.map_structure(add_noise, clipped_grads) return tf.nest.map_structure(
add_noise, clipped_grads, sparse_contribution_counts
)
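
A minimal usage sketch of the new dispatch, assuming one dense gradient and one sparse (`tf.IndexedSlices`) gradient with precomputed contribution counts; the shapes, counts, and thresholds below are illustrative only, and the sparse branch additionally relies on `sparse_noise_utils.add_sparse_noise`, whose internals are not shown in this diff:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# One dense gradient and one embedding-style sparse gradient.
dense_grad = tf.ones((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 2)),
    indices=tf.constant([0, 3]),
    dense_shape=tf.constant([4, 2]),
)
# Hypothetical per-row contribution counts for the sparse gradient only.
counts = tf.SparseTensor(indices=[[0], [3]], values=[5.0, 7.0], dense_shape=[4])

config = noise_utils.SparsityPreservingNoiseConfig(
    sparse_noise_multiplier=1.0,
    sparse_selection_threshold=2,
    sparse_contribution_counts=[None, counts],  # aligned with clipped_grads
)
noised = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=8,
    l2_norm_clip=1.0,
    noise_multiplier=1.0,
    loss_reduction='mean',
    sparse_noise_config=config,
)
# noised[0] is a dense tensor with Gaussian noise added; noised[1] stays an
# IndexedSlices, with sparsity preserving noise applied only to selected rows.
```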

View file

@ -70,3 +70,95 @@ class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
computed_std = np.std(noised_grads[0] - clipped_grads[0]) computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std) self.assertNear(computed_std, expected_std, 0.1 * expected_std)
@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
sparse_noise_multiplier=[1.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_sparse_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
sparse_noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make a simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
sparse_grad = tf.IndexedSlices(
values=tf.ones((3, 4)),
indices=tf.constant([0, 3, 5]),
dense_shape=tf.constant([8, 4]),
)
sparse_grad_contribution_counts = tf.SparseTensor(
indices=[[0], [3], [5]],
values=[10.0, 10.0, 20.0],
dense_shape=[8],
)
sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
sparse_noise_multiplier=sparse_noise_multiplier,
sparse_selection_threshold=8,
sparse_contribution_counts=[None, sparse_grad_contribution_counts],
)
sparse_noised_grad, dense_noised_grad = noise_utils.add_aggregate_noise(
clipped_grads=[dense_grad, sparse_grad],
batch_size=batch_size,
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
loss_model=linear_model,
sparse_noise_config=sparse_noise_config,
)
self.assertContainsSubset(
sparse_grad.indices.numpy().tolist(),
sparse_noised_grad.indices.numpy().tolist(),
)
sparse_noised_grad_dense = tf.scatter_nd(
tf.reshape(sparse_noised_grad.indices, (-1, 1)),
sparse_noised_grad.values,
shape=(8, 4),
).numpy()
sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
sparse_grad.indices.numpy()
]
sparse_grad_values = sparse_grad.values.numpy()
self.assertTrue(
np.all(
np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
)
)
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
# The only measure that varies is the standard deviation of the variation.
expected_std = l2_norm_clip * noise_multiplier * scale
sparse_computed_std = np.std(
sparse_noised_grad_valid_indices - sparse_grad_values
)
self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)
dense_computed_std = np.std(dense_noised_grad - dense_grad)
self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)

View file

@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(embedding_test.GradNormTest): class GradNormTpuTest(embedding_test.GradNormTest):
def setUp(self): def setUp(self):
tf.config.experimental.disable_mlir_bridge()
super(embedding_test.GradNormTest, self).setUp() super(embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy() self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0]) self.assertIn('TPU', self.strategy.extended.worker_devices[0])

View file

@ -80,8 +80,11 @@ def layer_normalization_computation(
stacked_grads = tf.stack(grads, axis=-1) stacked_grads = tf.stack(grads, axis=-1)
if num_microbatches is not None: if num_microbatches is not None:
stacked_grads = common_manip_utils.maybe_add_microbatch_axis( stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
grads, num_microbatches stacked_grads, num_microbatches
) )
# We will need to sum over the new microbatch size axis (axis=1) in order
# to account for microbatch aggregation.
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
reduction_axes = tf.range(1, tf.rank(stacked_grads)) reduction_axes = tf.range(1, tf.rank(stacked_grads))
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
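
A small sketch of the intended shape flow after this fix, assuming `maybe_add_microbatch_axis` reshapes a `[batch_size, ...]` tensor to `[num_microbatches, batch_size // num_microbatches, ...]` (the behavior implied by its other call sites in this change); the dummy values are illustrative:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

num_microbatches = 3
stacked_grads = tf.ones((6, 4))  # [batch_size, flattened layer-norm param dims]

# Reshape to [num_microbatches, batch_size // num_microbatches, ...].
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
    stacked_grads, num_microbatches
)
# Sum within each microbatch so the squared-norm reduction below acts on a
# per-microbatch gradient rather than a per-example one.
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)

reduction_axes = tf.range(1, tf.rank(stacked_grads))
sqr_norms = tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
print(sqr_norms.shape)  # (3,): one squared norm per microbatch, not per example
```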

View file

@ -134,7 +134,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
atol = 1e-1 if self.using_tpu else 1e-2 atol = 1e-1 if self.using_tpu else 1e-2
# Each batched input is a reshape of a `tf.range()` call. # Each batched input is a reshape of a `tf.range()` call.
batch_size = 2 batch_size = 6
example_size = np.prod(input_dims) example_size = np.prod(input_dims)
example_values = tf.range(batch_size * example_size, dtype=tf.float32) example_values = tf.range(batch_size * example_size, dtype=tf.float32)
x_batch = tf.reshape(example_values, [batch_size] + input_dims) x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@ -147,7 +147,9 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
common_test_utils.assert_replica_values_are_close(self, true_norms) common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0] computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0] true_norms = true_norms.values[0]
self.assertEqual(tf.shape(computed_norms)[0], batch_size) self.assertEqual(
tf.shape(computed_norms)[0], num_microbatches or batch_size
)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol) self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)

View file

@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
def setUp(self): def setUp(self):
tf.config.experimental.disable_mlir_bridge()
super(nlp_on_device_embedding_test.GradNormTest, self).setUp() super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy() self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0]) self.assertIn('TPU', self.strategy.extended.worker_devices[0])

View file

@ -19,11 +19,13 @@ import tensorflow as tf
# Tensorflow aliases. # Tensorflow aliases.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]] Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]
InputTensors = PackedTensors InputTensors = PackedTensors
OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]] OutputTensors = Union[Tensor, Iterable[Tensor]]
BatchSize = Union[int, tf.Tensor] BatchSize = Union[int, tf.Tensor]
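
One motivation for widening the `Tensor` alias: gradients are not always dense `tf.Tensor`s. The gradient of an embedding lookup, for example, arrives as a `tf.IndexedSlices`, which the sparse-noise path wants to keep un-densified. A tiny eager-mode illustration:

```python
import tensorflow as tf

emb = tf.keras.layers.Embedding(10, 4)
ids = tf.constant([[1, 3]])
with tf.GradientTape() as tape:
    out = tf.reduce_sum(emb(ids))
grad = tape.gradient(out, emb.trainable_variables[0])
print(isinstance(grad, tf.IndexedSlices))  # True: sparse rows, not a dense Tensor
```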

View file

@ -18,6 +18,8 @@ py_library(
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
], ],
) )
@ -35,6 +37,7 @@ py_test(
py_test( py_test(
name = "dp_keras_model_distributed_test", name = "dp_keras_model_distributed_test",
timeout = "long",
srcs = ["dp_keras_model_distributed_test.py"], srcs = ["dp_keras_model_distributed_test.py"],
tags = [ tags = [
"manual", "manual",

View file

@ -13,16 +13,37 @@
# limitations under the License. # limitations under the License.
"""Keras Model for vectorized dpsgd with XLA acceleration.""" """Keras Model for vectorized dpsgd with XLA acceleration."""
import dataclasses
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
_PRIVATIZED_LOSS_NAME = 'privatized_loss' _PRIVATIZED_LOSS_NAME = 'privatized_loss'
@dataclasses.dataclass
class SparsityPreservingDPSGDConfig:
"""Config for adding sparsity preserving noise to the gradients."""
# The ratio of how the noise is split between partition selection and gradient
# noise.
sparse_selection_ratio: float = 0.0
# The threshold to use for private partition selection.
sparse_selection_threshold: int = 100
# A `LayerRegistry` instance containing functions that help compute
# contribution counts for sparse layers. See
# `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
# more details.
sparse_selection_layer_registry: snlr.LayerRegistry | None = None
def make_dp_model_class(cls): def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it.""" """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@ -104,6 +125,9 @@ def make_dp_model_class(cls):
num_microbatches=None, num_microbatches=None,
use_xla=True, use_xla=True,
layer_registry=None, layer_registry=None,
sparsity_preserving_dpsgd_config: (
SparsityPreservingDPSGDConfig | None
) = None,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs, **kwargs,
): ):
@ -118,6 +142,9 @@ def make_dp_model_class(cls):
help compute gradient norms quickly. See help compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details. more details.
sparsity_preserving_dpsgd_config: If provided, uses partition selection
and sparse noise for privatizing sparse gradients for layers in
`sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
*args: These will be passed on to the base class `__init__` method. *args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method. **kwargs: These will be passed on to the base class `__init__` method.
""" """
@ -127,6 +154,8 @@ def make_dp_model_class(cls):
self._layer_registry = layer_registry self._layer_registry = layer_registry
self._clipping_loss = None self._clipping_loss = None
self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
# Given that `num_microbatches` was added as an argument after the fact, # Given that `num_microbatches` was added as an argument after the fact,
# this check helps detect unintended calls to the earlier API. # this check helps detect unintended calls to the earlier API.
# In particular, boolean values supplied to `use_xla` in the earlier API # In particular, boolean values supplied to `use_xla` in the earlier API
@ -274,27 +303,94 @@ def make_dp_model_class(cls):
# trick, and uses these norms to clip the per-example gradients. # trick, and uses these norms to clip the per-example gradients.
# NOTE: Reshaping of the input according to the effective number of # NOTE: Reshaping of the input according to the effective number of
# microbatches is done here. # microbatches is done here.
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
sparse_noise_layer_registry = None
if self._sparsity_preserving_dpsgd_config is not None:
sparse_noise_layer_registry = (
self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
)
registry_generator_fn = (
gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=self._layer_registry,
sparse_noise_layer_registry=sparse_noise_layer_registry,
num_microbatches=num_microbatches,
)
)
layer_grad_vars, registry_fn_outputs_list = (
gradient_clipping_utils.model_forward_backward_pass(
tape=tape,
input_model=self,
x_batch=x,
y_batch=y,
registry_generator_fn=registry_generator_fn,
weight_batch=weights,
num_microbatches=num_microbatches,
trainable_vars=self.trainable_variables,
)
)
clipped_grads, y_pred, clipping_loss = ( clipped_grads, y_pred, clipping_loss = (
clip_grads.compute_clipped_gradients_and_outputs( clip_grads.compute_clipped_gradients_and_outputs(
input_model=self, input_model=self,
registry_fn_outputs_list=registry_fn_outputs_list,
layer_grad_vars=layer_grad_vars,
x_batch=x, x_batch=x,
y_batch=y, y_batch=y,
weight_batch=weights, weight_batch=weights,
l2_norm_clip=self._l2_norm_clip, l2_norm_clip=self._l2_norm_clip,
layer_registry=self._layer_registry,
num_microbatches=self._num_microbatches, num_microbatches=self._num_microbatches,
clipping_loss=self._clipping_loss, clipping_loss=self._clipping_loss,
) )
) )
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
if self._noise_multiplier > 0: noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
contribution_counts = None
if self._sparsity_preserving_dpsgd_config is not None:
logging.info('Using sparse noise.')
varname_to_contribution_counts_fns = (
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
registry_fn_outputs_list,
self.trainable_variables,
)
)
contribution_counts = sparse_noise_utils.get_contribution_counts(
self.trainable_variables,
clipped_grads,
varname_to_contribution_counts_fns,
)
noise_multiplier_sparse, noise_multiplier = (
sparse_noise_utils.split_noise_multiplier(
noise_multiplier,
self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
contribution_counts,
)
)
logging.info(
'Split noise multiplier for gradient noise: %s and partition'
' selection: %s',
noise_multiplier,
noise_multiplier_sparse,
)
if noise_multiplier > 0:
sparse_noise_config = None
if self._sparsity_preserving_dpsgd_config is not None:
sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
sparse_noise_multiplier=noise_multiplier_sparse,
sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
sparse_contribution_counts=contribution_counts,
)
grads = noise_utils.add_aggregate_noise( grads = noise_utils.add_aggregate_noise(
clipped_grads, clipped_grads,
num_microbatches, num_microbatches,
self._l2_norm_clip, self._l2_norm_clip,
self._noise_multiplier, noise_multiplier,
loss_reduction=None, loss_reduction=None,
loss_model=self, loss_model=self,
sparse_noise_config=sparse_noise_config,
) )
else: else:
grads = clipped_grads grads = clipped_grads
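For readers skimming this hunk, it may help to see what the aggregate-noise step does conceptually. The sketch below is a simplified stand-in, not the library's implementation: it assumes sum-reduced, already-clipped gradients and ignores the loss-reduction handling and the sparsity-preserving path that `noise_utils.add_aggregate_noise` gains in this change.

```python
import tensorflow as tf


def add_gaussian_noise_sketch(summed_clipped_grads, l2_norm_clip,
                              noise_multiplier, num_microbatches):
  """Illustrative only: noise the summed clipped gradients, then average."""
  noised_grads = []
  for grad in summed_clipped_grads:
    # Noise stddev is calibrated to the clipping norm times the multiplier.
    noise = tf.random.normal(
        tf.shape(grad),
        stddev=l2_norm_clip * noise_multiplier,
        dtype=grad.dtype,
    )
    # Divide by the number of microbatches so the result behaves like a
    # mean-reduced gradient.
    noised_grads.append((grad + noise) / tf.cast(num_microbatches, grad.dtype))
  return noised_grads
```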


@ -261,7 +261,8 @@ class AttackInputData:
# Contains ground-truth classes. For single-label classification, classes are # Contains ground-truth classes. For single-label classification, classes are
# assumed to be integers starting from 0. For multi-label classification, # assumed to be integers starting from 0. For multi-label classification,
# label is assumed to be multi-hot, i.e., labels is a binary array of shape # label is assumed to be multi-hot, i.e., labels is a binary array of shape
# (num_examples, num_classes). # (num_examples, num_classes). Additionally used to compute the loss when
# loss_train/test is not provided. Leave empty for non-classification models.
labels_train: Optional[np.ndarray] = None labels_train: Optional[np.ndarray] = None
labels_test: Optional[np.ndarray] = None labels_test: Optional[np.ndarray] = None
@ -270,7 +271,7 @@ class AttackInputData:
sample_weight_test: Optional[np.ndarray] = None sample_weight_test: Optional[np.ndarray] = None
# Explicitly specified loss. If provided, this is used instead of deriving # Explicitly specified loss. If provided, this is used instead of deriving
# loss from logits and labels # loss from logits and labels.
loss_train: Optional[np.ndarray] = None loss_train: Optional[np.ndarray] = None
loss_test: Optional[np.ndarray] = None loss_test: Optional[np.ndarray] = None
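A usage sketch of the documented behavior (the import path and field names are taken from TF Privacy's membership inference attack tooling; treat them as assumptions if your version differs): when `loss_train`/`loss_test` are omitted, loss is derived from the logits together with `labels_train`/`labels_test`.

```python
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures

# Loss is derived from logits and labels because loss_train/test are omitted.
attack_input = data_structures.AttackInputData(
    logits_train=np.random.randn(100, 10),
    logits_test=np.random.randn(100, 10),
    labels_train=np.random.randint(0, 10, size=100),
    labels_test=np.random.randint(0, 10, size=100),
)

# With precomputed losses, labels are no longer needed to compute the loss.
attack_input_precomputed = data_structures.AttackInputData(
    loss_train=np.random.rand(100),
    loss_test=np.random.rand(100),
)
```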


@ -5,12 +5,19 @@ licenses(["notice"])
py_library( py_library(
name = "sparse_noise_utils", name = "sparse_noise_utils",
srcs = ["sparse_noise_utils.py"], srcs = ["sparse_noise_utils.py"],
deps = [
":type_aliases",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
],
) )
py_test( py_test(
name = "sparse_noise_utils_test", name = "sparse_noise_utils_test",
srcs = ["sparse_noise_utils_test.py"], srcs = ["sparse_noise_utils_test.py"],
deps = [":sparse_noise_utils"], deps = [
":sparse_noise_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
],
) )
py_library( py_library(


@ -16,10 +16,13 @@
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357. For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
""" """
import collections
from typing import Mapping, Optional, Sequence from typing import Mapping, Optional, Sequence
from scipy import stats from scipy import stats
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
import tensorflow_probability as tfp import tensorflow_probability as tfp
@ -163,7 +166,7 @@ def sample_true_positive_indices(
tf.shape(contribution_count_values), tf.shape(contribution_count_values),
mean=0.0, mean=0.0,
stddev=noise_multiplier, stddev=noise_multiplier,
dtype=tf.float32, dtype=contribution_count_values.dtype,
) )
) )
noised_contribution_counts_indices = contribution_counts.indices[ noised_contribution_counts_indices = contribution_counts.indices[
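The dtype change above matters because the contribution counts are not necessarily float32. As a rough, self-contained sketch of the noised-threshold selection this function performs (simplified; the real helper operates on the `tf.SparseTensor` of per-row contribution counts, and the threshold handling here is an assumption for illustration):

```python
import tensorflow as tf


def noised_index_selection_sketch(contribution_counts, noise_multiplier,
                                  threshold):
  """Keeps row ids whose noised contribution count exceeds the threshold."""
  values = contribution_counts.values
  noise = tf.random.normal(
      tf.shape(values), mean=0.0, stddev=noise_multiplier, dtype=values.dtype
  )
  noised_values = values + noise
  kept_positions = tf.where(noised_values > threshold)[:, 0]
  # Map the surviving positions back to the row ids stored in the sparse
  # tensor's indices.
  return tf.gather(contribution_counts.indices[:, 0], kept_positions)


# Rows 2, 7 and 11 of a length-16 histogram contributed 3, 1 and 5 times.
counts = tf.SparseTensor(
    indices=[[2], [7], [11]], values=[3.0, 1.0, 5.0], dense_shape=[16]
)
selected_rows = noised_index_selection_sketch(
    counts, noise_multiplier=1.0, threshold=2.0
)
```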
@ -278,7 +281,7 @@ def add_sparse_gradient_noise(
""" """
filtered_grad_values = tf.gather(grad, indices) filtered_grad_values = tf.gather(grad, indices)
sparse_noise_values = tf.random.normal( sparse_noise_values = tf.random.normal(
filtered_grad_values.shape, mean=0.0, stddev=noise_stddev tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
) )
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
return tf.IndexedSlices( return tf.IndexedSlices(
@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
) )
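The switch from `filtered_grad_values.shape` to `tf.shape(filtered_grad_values)` makes the noise shape usable when the number of gathered rows is only known at runtime, e.g. under `tf.function`. A minimal sketch of noising only selected rows of a gradient using the dynamic shape (an illustration, not the library's exact helper) might look like:

```python
import tensorflow as tf


def noise_selected_rows_sketch(grad, selected_indices, noise_stddev):
  """Returns an IndexedSlices holding only the selected, noised rows."""
  selected_rows = tf.gather(grad, selected_indices)
  # tf.shape (dynamic) rather than .shape (static), so this also works when
  # the number of selected rows is not known at graph-construction time.
  noise = tf.random.normal(
      tf.shape(selected_rows),
      mean=0.0,
      stddev=noise_stddev,
      dtype=selected_rows.dtype,
  )
  return tf.IndexedSlices(
      values=selected_rows + noise,
      indices=selected_indices,
      dense_shape=tf.shape(grad, out_type=tf.int64),
  )
```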
def extract_varname_to_contribution_counts_fns(
registry_fn_outputs_list: Sequence[
gradient_clipping_utils.RegistryGeneratorFunctionOutput
],
trainable_vars: Sequence[tf.Variable],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
"""Extracts a map of contribution count fns from generator outputs.
Args:
registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
instances returned by
`gradient_clipping_utils.model_forward_backward_pass`.
trainable_vars: A list of trainable variables.
Returns:
A `dict` mapping each variable name to its contribution count function.
"""
if trainable_vars is not None:
# Create a set using `ref()` for fast set membership check. tf.Variable
# itself is not hashable.
trainable_vars = set([v.ref() for v in trainable_vars])
varname_to_contribution_counts_fns = collections.defaultdict(list)
for registry_fn_output in registry_fn_outputs_list:
if trainable_vars is None or any(
w.ref() in trainable_vars
for w in registry_fn_output.layer_trainable_weights
):
if registry_fn_output.varname_to_count_contribution_fn is not None:
duplicate_varnames = set(
registry_fn_output.varname_to_count_contribution_fn.keys()
) & set(varname_to_contribution_counts_fns.keys())
if duplicate_varnames:
raise ValueError(
f'Duplicate varnames: {duplicate_varnames} found in contribution'
' counts functions.'
)
varname_to_contribution_counts_fns.update(
registry_fn_output.varname_to_count_contribution_fn
)
return varname_to_contribution_counts_fns
def get_contribution_counts( def get_contribution_counts(
trainable_vars: list[tf.Variable], trainable_vars: Sequence[tf.Variable],
grads: list[tf.Tensor], grads: Sequence[tf.Tensor],
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor], varname_to_contribution_counts_fns: Mapping[
) -> list[tf.Tensor | None]: str, type_aliases.ContributionCountHistogramFn
],
) -> Sequence[type_aliases.ContributionCountHistogram | None]:
"""Gets the contribution counts for each variable in the Model. """Gets the contribution counts for each variable in the Model.
Args: Args:
trainable_vars: A list of the trainable variables in the Model. trainable_vars: A list of trainable variables.
grads: A corresponding list of gradients for each trainable variable. grads: A corresponding list of gradients for each trainable variable.
varname_to_contribution_counts_fns: A mapping from variable name to a varname_to_contribution_counts_fns: A mapping from variable name to a
function to get the contribution counts for that variable. function to get the contribution counts for that variable.
@ -314,15 +362,10 @@ def get_contribution_counts(
if var.name not in varname_to_contribution_counts_fns: if var.name not in varname_to_contribution_counts_fns:
contribution_counts_list.append(None) contribution_counts_list.append(None)
continue continue
contribution_counts_fns = varname_to_contribution_counts_fns[var.name] contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
if not contribution_counts_fns or not contribution_counts_fns[0]: if not contribution_counts_fn:
contribution_counts_list.append(None) contribution_counts_list.append(None)
continue continue
if len(contribution_counts_fns) > 1:
raise NotImplementedError(
'Sparse noise is not supported for shared weight variables.'
)
contribution_counts_fn = contribution_counts_fns[0]
contribution_counts = contribution_counts_fn(grad) contribution_counts = contribution_counts_fn(grad)
contribution_counts_list.append(contribution_counts) contribution_counts_list.append(contribution_counts)
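With the signature change above, each variable maps to a single `ContributionCountHistogramFn` (a callable from a sparse gradient to a `tf.SparseTensor` histogram). A self-contained illustration of such a function for an embedding-style gradient is sketched below; it is only a shape-compatible stand-in, since the registry functions in the library may derive the counts differently:

```python
import tensorflow as tf


def embedding_contribution_counts(grad: tf.IndexedSlices) -> tf.SparseTensor:
  """Counts how many gradient slices touch each embedding row (illustrative)."""
  row_ids = tf.cast(grad.indices, tf.int64)
  unique_ids, _, counts = tf.unique_with_counts(row_ids)
  num_rows = grad.dense_shape[:1]  # 1-D tensor holding the vocabulary size.
  return tf.sparse.reorder(
      tf.SparseTensor(
          indices=tf.expand_dims(unique_ids, axis=-1),
          values=tf.cast(counts, tf.float32),
          dense_shape=num_rows,
      )
  )


# A gradient touching rows 3, 3 and 7 of a 10-row embedding table.
grad = tf.IndexedSlices(
    values=tf.ones((3, 4)),
    indices=tf.constant([3, 3, 7], dtype=tf.int64),
    dense_shape=tf.constant([10, 4], dtype=tf.int64),
)
counts = embedding_contribution_counts(grad)  # row 3 -> 2.0, row 7 -> 1.0
```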


@ -17,6 +17,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
@ -368,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
tf.ones((1, 2)), tf.ones((1, 2)),
] ]
varname_to_contribution_counts_fns = { varname_to_contribution_counts_fns = {
'var1:0': [lambda grad: 1.0], 'var1:0': lambda grad: 1.0,
'var2:0': None, 'var2:0': None,
} }
contribution_counts = sparse_noise_utils.get_contribution_counts( contribution_counts = sparse_noise_utils.get_contribution_counts(
@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy())) np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
) )
def test_extract_varname_to_contribution_counts_fns(self):
def fn1(_):
return 1.0
def fn2(_):
return 2.0
var1 = tf.Variable(tf.ones((1, 2)), name='var1')
var2 = tf.Variable(tf.ones((1, 2)), name='var2')
var3 = tf.Variable(tf.ones((1, 2)), name='var3')
registry_fn_outputs_list = [
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id='layer1',
layer_vars=[var1],
layer_sqr_norm_fn=None,
layer_trainable_weights=[var1],
varname_to_count_contribution_fn=None,
),
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id='layer2',
layer_vars=[var2],
layer_sqr_norm_fn=None,
layer_trainable_weights=[var2],
varname_to_count_contribution_fn={
'var2:0': [fn2],
},
),
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id='layer3',
layer_vars=[var3],
layer_sqr_norm_fn=None,
layer_trainable_weights=[var3],
varname_to_count_contribution_fn={
'var3:0': [fn1],
},
),
]
expected_varname_to_contribution_counts_fns = {
'var2:0': [fn2],
'var3:0': [fn1],
}
varname_to_contribution_counts_fns = (
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
registry_fn_outputs_list,
trainable_vars=None,
)
)
self.assertEqual(
varname_to_contribution_counts_fns,
expected_varname_to_contribution_counts_fns,
)
def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
def fn1(_):
return 1.0
def fn2(_):
return 2.0
var1 = tf.Variable(tf.ones((1, 2)), name='var1')
var2 = tf.Variable(tf.ones((1, 2)), name='var1')
registry_fn_outputs_list = [
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id='layer1',
layer_vars=[var1],
layer_sqr_norm_fn=None,
layer_trainable_weights=[var1],
varname_to_count_contribution_fn={
'var1:0': [fn1],
},
),
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id='layer2',
layer_vars=[var2],
layer_sqr_norm_fn=None,
layer_trainable_weights=[var2],
varname_to_count_contribution_fn={
'var1:0': [fn2],
},
),
]
with self.assertRaises(ValueError):
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
registry_fn_outputs_list,
trainable_vars=None,
)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()


@ -19,10 +19,10 @@ import tensorflow as tf
InputArgs = Sequence[Any] InputArgs = Sequence[Any]
InputKwargs = Mapping[str, Any] InputKwargs = Mapping[str, Any]
SparseGradient = tf.IndexedSlices SparseGradient = tf.IndexedSlices | tf.SparseTensor
ContributionCountHistogram = tf.SparseTensor ContributionCountHistogram = tf.SparseTensor
ContributionCountHistogramFn = Callable[ ContributionCountHistogramFn = Callable[
[SparseGradient], Mapping[str, ContributionCountHistogram] [SparseGradient], ContributionCountHistogram
] ]
NumMicrobatches = int | tf.Tensor NumMicrobatches = int | tf.Tensor
SparsityPreservingNoiseLayerRegistryFunction = Callable[ SparsityPreservingNoiseLayerRegistryFunction = Callable[