forked from 626_privacy/tensorflow_privacy
Add a parameter to the noise function that explicitly specifies the loss reduction type.
PiperOrigin-RevId: 583507445
parent 39c8a8c1af
commit 03db50ba94
5 changed files with 99 additions and 12 deletions
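In short: `add_aggregate_noise` used to infer the reduction from `input_model.loss`; after this change the caller either names the reduction explicitly via `loss_reduction` ('mean' or 'sum') or passes a `loss_model` to infer it from. The reduction matters because a mean-reduced loss divides gradients by the batch size, so the Gaussian noise must shrink by the same factor. A minimal sketch of the resulting noise scale (my illustration, not code from the commit; `noise_stddev` is a hypothetical helper):

```python
import numpy as np


def noise_stddev(l2_norm_clip, noise_multiplier, batch_size, loss_reduction):
  """Per-coordinate std of the Gaussian noise added to the summed clipped grads."""
  scale = l2_norm_clip
  if loss_reduction == 'mean':
    # A mean-reduced loss divides gradients by the batch size, so the noise
    # is divided by the same factor.
    scale /= batch_size
  return scale * noise_multiplier


# Example: clip = 3.0, multiplier = 2.0, batch of 10 examples.
assert np.isclose(noise_stddev(3.0, 2.0, 10, 'sum'), 6.0)
assert np.isclose(noise_stddev(3.0, 2.0, 10, 'mean'), 0.6)
```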
tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

```diff
@@ -38,6 +38,7 @@ py_test(
     name = "gradient_clipping_utils_test",
     srcs = ["gradient_clipping_utils_test.py"],
     python_version = "PY3",
+    shard_count = 8,
     srcs_version = "PY3",
     deps = [
         ":gradient_clipping_utils",
```
tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

```diff
@@ -14,7 +14,7 @@
 """Utility functions that help in the computation of per-example gradient norms."""
 
 from collections.abc import Sequence, Set
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 from absl import logging
 import tensorflow as tf
```
```diff
@@ -145,11 +145,12 @@ def all_trainable_layers_are_registered(
 
 
 def add_aggregate_noise(
-    input_model: tf.keras.Model,
     clipped_grads: list[tf.Tensor],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
+    loss_reduction: Optional[Literal['mean', 'sum']] = None,
+    loss_model: Optional[tf.keras.Model] = None,
 ) -> Sequence[tf.Tensor]:
   """Adds noise to a collection of clipped gradients.
 
```
```diff
@@ -157,25 +158,53 @@
   input model's loss function.
 
   Args:
-    input_model: The `tf.keras.Model` to obtain the layers from.
     clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size, used for normalizing the noise, when the loss
-      reduction is AUTO or SUM_OVER_BATCH_SIZE.
+    batch_size: The batch size. Used for normalizing the noise when
+      `loss_reduction` is 'sum'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
     noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: A string description of how the loss is reduced over
+      examples. Currently supports 'mean' and 'sum'. If `None`, then the
+      aggregation type must be inferred from `input_model.loss`.
+    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
+      strategy from if `loss_reduction` is `None`.
 
   Returns:
     A list of tensors containing the clipped gradients, but with the right
     amount of Gaussian noise added to them (depending on the reduction
     strategy of the loss function).
+
+  Raises:
+    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
+      they are both not `None`.
   """
+  if loss_reduction is None and loss_model is None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were `None`.'
+    )
+  if loss_reduction is not None and loss_model is not None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were not `None`.'
+    )
+
+  if loss_reduction is None and loss_model is not None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = loss_model.loss.reduction
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
+      logging.info(
+          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
+      )
+
   scale = l2_norm_clip
-  if input_model.loss.reduction in [
-      tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
-      tf.keras.losses.Reduction.AUTO,
-  ]:
-    if input_model.loss.reduction == tf.keras.losses.Reduction.AUTO:
-      logging.info('Assuming that the loss reduction is `SUM_OVER_BATCH_SIZE`.')
+  if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)
 
   def add_noise(g):
```
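For orientation, a hedged usage sketch of the updated function, assuming the signature shown in the hunk above (exactly one of `loss_reduction` and `loss_model` may be given):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

clipped_grads = [tf.ones([4, 1])]

# Option 1: name the reduction explicitly.
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size=32,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',
)

# Option 2: infer the reduction from a compiled model's loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss=tf.keras.losses.MeanSquaredError(reduction='sum'))
noised = gradient_clipping_utils.add_aggregate_noise(
    clipped_grads,
    batch_size=32,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_model=model,
)

# Passing both or neither of `loss_reduction` and `loss_model` raises a ValueError.
```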
tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py

```diff
@@ -15,6 +15,7 @@
 from typing import Any
 
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 
```
```diff
@@ -134,6 +135,60 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_outputs, true_outputs)
 
 
+class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    clipped_grads = [
+        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    ]
+    noised_grads = gradient_clipping_utils.add_aggregate_noise(
+        clipped_grads,
+        batch_size,
+        l2_norm_clip,
+        noise_multiplier,
+        noise_fn_reduction,
+        linear_model,
+    )
+    # The only measure that varies is the standard deviation of the noise.
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    computed_std = np.std(noised_grads[0] - clipped_grads[0])
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+
 class GenerateOutputsUsingCoreKerasLayers(
     tf.test.TestCase, parameterized.TestCase
 ):
```
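The assertion at the end of the new test compares the empirical standard deviation of the added noise against the closed-form value. A worked instance for one parameter combination from the product above (my arithmetic, not part of the test):

```python
# l2_norm_clip = 3.0, noise_multiplier = 2.0, batch_size = 10
# reduction resolves to 'mean': expected_std = 3.0 * 2.0 / 10 = 0.6
# reduction resolves to 'sum':  expected_std = 3.0 * 2.0      = 6.0
```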
tensorflow_privacy/privacy/keras_models/BUILD

```diff
@@ -25,6 +25,7 @@ py_test(
     name = "dp_keras_model_test",
     srcs = ["dp_keras_model_test.py"],
     python_version = "PY3",
+    shard_count = 16,
     srcs_version = "PY3",
     deps = [
         "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
```
tensorflow_privacy/privacy/keras_models/dp_keras_model.py

```diff
@@ -264,11 +264,12 @@ def make_dp_model_class(cls):
         output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
       if self._noise_multiplier > 0:
         grads = gradient_clipping_utils.add_aggregate_noise(
-            self,
            clipped_grads,
            num_microbatches,
            self._l2_norm_clip,
            self._noise_multiplier,
+            loss_reduction=None,
+            loss_model=self,
        )
      else:
        grads = clipped_grads
```
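In the DP Keras model the call keeps its previous behaviour: the model passes itself as `loss_model`, so the reduction is still read from the compiled loss rather than named explicitly. A hedged illustration of how a user controls that reduction (assuming the usual `DPSequential` wrapper exported by `dp_keras_model`; constructor details may differ):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.keras_models import dp_keras_model

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layers=[tf.keras.layers.Dense(1)],
)
# A 'sum'-reduced loss means the added noise is not divided by the batch size;
# the default SUM_OVER_BATCH_SIZE / AUTO reductions are treated as 'mean'.
model.compile(
    optimizer='sgd',
    loss=tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM),
)
```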