Add a parameter to the noise function that explicitly specifies the loss reduction type.

PiperOrigin-RevId: 583507445
William Kong 2023-11-17 15:54:20 -08:00 committed by A. Unique TensorFlower
parent 39c8a8c1af
commit 03db50ba94
5 changed files with 99 additions and 12 deletions
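For context, a minimal sketch of the two call patterns the new parameter allows (illustrative values only; exactly one of `loss_reduction` and `loss_model` may be set, per the new signature in the diff below):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

clipped_grads = [tf.ones([10, 1])]  # Placeholder clipped gradients.

# New path: state the reduction type explicitly; no model is needed.
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size=32,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',
)

# Old path: infer the reduction from a compiled model's loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss=tf.keras.losses.MeanSquaredError(reduction='sum'))
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size=32,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_model=model,
)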

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

@@ -38,6 +38,7 @@ py_test(
     name = "gradient_clipping_utils_test",
     srcs = ["gradient_clipping_utils_test.py"],
     python_version = "PY3",
+    shard_count = 8,
     srcs_version = "PY3",
     deps = [
         ":gradient_clipping_utils",

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

@@ -14,7 +14,7 @@
 """Utility functions that help in the computation of per-example gradient norms."""
 
 from collections.abc import Sequence, Set
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 from absl import logging
 import tensorflow as tf
@@ -145,11 +145,12 @@ def all_trainable_layers_are_registered(
 
 
 def add_aggregate_noise(
-    input_model: tf.keras.Model,
     clipped_grads: list[tf.Tensor],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
+    loss_reduction: Optional[Literal['mean', 'sum']] = None,
+    loss_model: Optional[tf.keras.Model] = None,
 ) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
@@ -157,25 +158,53 @@ def add_aggregate_noise(
     input model's loss function.
 
   Args:
-    input_model: The `tf.keras.Model` to obtain the layers from.
     clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size, used for normalizing the noise, when the loss
-      reduction is AUTO or SUM_OVER_BATCH_SIZE.
+    batch_size: The batch size. Used for normalizing the noise when
+      `loss_reduction` is 'mean'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
     noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: A string description of how the loss is reduced over
+      examples. Currently supports 'mean' and 'sum'. If `None`, then the
+      aggregation type is inferred from `loss_model.loss`.
+    loss_model: An optional `tf.keras.Model` from which to infer the loss
+      reduction strategy if `loss_reduction` is `None`.
 
   Returns:
     A list of tensors containing the clipped gradients, but with the right
     amount of Gaussian noise added to them (depending on the reduction
     strategy of the loss function).
+
+  Raises:
+    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
+      they are both not `None`.
   """
+  if loss_reduction is None and loss_model is None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were `None`.'
+    )
+  if loss_reduction is not None and loss_model is not None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were not `None`.'
+    )
+  if loss_reduction is None and loss_model is not None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = loss_model.loss.reduction
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
+      logging.info(
+          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
+      )
+
   scale = l2_norm_clip
-  if input_model.loss.reduction in [
-      tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
-      tf.keras.losses.Reduction.AUTO,
-  ]:
-    if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
-      logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
+  if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)
 
   def add_noise(g):
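As a worked check of the scaling branch above (not part of the diff; the numbers reuse parameter values from the test added further down), the per-coordinate noise standard deviation is `l2_norm_clip * noise_multiplier`, divided by `batch_size` only under 'mean' reduction:

l2_norm_clip, noise_multiplier, batch_size = 3.0, 2.0, 10
sum_std = l2_norm_clip * noise_multiplier  # 'sum' reduction: stddev = 6.0
mean_std = sum_std / batch_size            # 'mean' reduction: stddev = 0.6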

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py

@@ -15,6 +15,7 @@
 from typing import Any
 
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -134,6 +135,60 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_outputs, true_outputs)
 
 
+class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    clipped_grads = [
+        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    ]
+    noised_grads = gradient_clipping_utils.add_aggregate_noise(
+        clipped_grads,
+        batch_size,
+        l2_norm_clip,
+        noise_multiplier,
+        noise_fn_reduction,
+        linear_model,
+    )
+    # The only statistic that should vary is the noise standard deviation.
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    computed_std = np.std(noised_grads[0] - clipped_grads[0])
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+
 class GenerateOutputsUsingCoreKerasLayers(
     tf.test.TestCase, parameterized.TestCase
 ):
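A note on the tolerance in the test above: `noised_grads[0] - clipped_grads[0]` isolates the added Gaussian noise, and the sample standard deviation over `num_units = 100` i.i.d. coordinates has a relative standard error of roughly `1 / sqrt(2 * 100)`, about 7%, so the 10% relative bound in `assertNear` acts as a scale check rather than a tight statistical test.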

tensorflow_privacy/privacy/keras_models/BUILD

@@ -25,6 +25,7 @@ py_test(
     name = "dp_keras_model_test",
     srcs = ["dp_keras_model_test.py"],
     python_version = "PY3",
+    shard_count = 16,
     srcs_version = "PY3",
     deps = [
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",

tensorflow_privacy/privacy/keras_models/dp_keras_model.py

@@ -264,11 +264,12 @@ def make_dp_model_class(cls):
         output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
       if self._noise_multiplier > 0:
         grads = gradient_clipping_utils.add_aggregate_noise(
-            self,
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
             self._noise_multiplier,
+            loss_reduction=None,
+            loss_model=self,
         )
       else:
         grads = clipped_grads
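Note that this call site passes `loss_model=self`, so `DPModel` keeps the previous behavior of inferring the reduction from its compiled loss and existing users see no change. A hypothetical variant that pins the reduction explicitly would instead pass, e.g., `loss_reduction='sum'` and drop the `loss_model` argument, since the two are mutually exclusive.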