Sparsity Preserving DP-SGD in TF Privacy

Refactor utilities for adding noise into separate utility file.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660527638
This commit is contained in:
A. Unique TensorFlower 2024-08-07 13:57:56 -07:00
parent fc6f1dc5d1
commit d3f527e775
7 changed files with 203 additions and 158 deletions

View file

@ -54,10 +54,7 @@ py_test(
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":gradient_clipping_utils",
":type_aliases",
],
deps = [":gradient_clipping_utils"],
)
py_library(
@ -83,6 +80,11 @@ py_library(
],
)
py_library(
name = "noise_utils",
srcs = ["noise_utils.py"],
)
py_test(
name = "clip_grads_test",
srcs = ["clip_grads_test.py"],
@ -96,3 +98,9 @@ py_test(
":type_aliases",
],
)
py_test(
name = "noise_utils_test",
srcs = ["noise_utils_test.py"],
deps = [":noise_utils"],
)

View file

@ -14,9 +14,8 @@
"""Utility functions that help in the computation of per-example gradient norms."""
from collections.abc import Sequence, Set
from typing import Any, Literal, Optional
from typing import Any, Optional
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@ -144,101 +143,6 @@ def all_trainable_layers_are_registered(
return True
def _infer_loss_reduction_type(model: tf.keras.Model):
"""Infers what type of loss reduction is being performed."""
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return model_loss.reduction
elif isinstance(model.loss, dict):
reductions = set()
compiled_loss = model.compiled_loss
if compiled_loss is None:
raise ValueError('Model must be compiled for adding noise')
new_config_list = compiled_loss.get_config()['losses']
for loss_config in new_config_list:
reductions.add(loss_config['config']['reduction'])
if len(reductions) > 1:
raise ValueError(
'Reductions in models with multiple losses must all be the same'
)
return reductions.pop()
else:
raise ValueError(
'Unsupported type for adding noise: {}'.format(type(model_loss))
)
def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An string description of how the loss is reduced over
examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
"""
if loss_reduction is None and loss_model is None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were `None`.'
)
if loss_reduction is not None and loss_model is not None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were not `None`.'
)
if loss_reduction is None and loss_model is not None:
implicit_mean_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)
scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)
def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)
return tf.nest.map_structure(add_noise, clipped_grads)
def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic

View file

@ -15,7 +15,6 @@
from typing import Any
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@ -135,60 +134,6 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(computed_outputs, true_outputs)
class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make an simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
clipped_grads = [
tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
]
noised_grads = gradient_clipping_utils.add_aggregate_noise(
clipped_grads,
batch_size,
l2_norm_clip,
noise_multiplier,
noise_fn_reduction,
linear_model,
)
# The only measure that varies is the standard deviation of the variation.
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std)
class GenerateOutputsUsingCoreKerasLayers(
tf.test.TestCase, parameterized.TestCase
):

View file

@ -0,0 +1,115 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions that help in adding noise to gradients."""
from collections.abc import Sequence
from typing import Literal, Optional
from absl import logging
import tensorflow as tf
def _infer_loss_reduction_type(model: tf.keras.Model):
"""Infers what type of loss reduction is being performed."""
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return model_loss.reduction
elif isinstance(model.loss, dict):
reductions = set()
compiled_loss = model.compiled_loss
if compiled_loss is None:
raise ValueError('Model must be compiled for adding noise')
new_config_list = compiled_loss.get_config()['losses']
for loss_config in new_config_list:
reductions.add(loss_config['config']['reduction'])
if len(reductions) > 1:
raise ValueError(
'Reductions in models with multiple losses must all be the same'
)
return reductions.pop()
else:
raise ValueError(
'Unsupported type for adding noise: {}'.format(type(model_loss))
)
def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An string description of how the loss is reduced over
examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
"""
if loss_reduction is None and loss_model is None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were `None`.'
)
if loss_reduction is not None and loss_model is not None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were not `None`.'
)
if loss_reduction is None and loss_model is not None:
implicit_mean_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)
scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)
def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)
return tf.nest.map_structure(add_noise, clipped_grads)

View file

@ -0,0 +1,72 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make an simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
clipped_grads = [
tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
]
noised_grads = noise_utils.add_aggregate_noise(
clipped_grads,
batch_size,
l2_norm_clip,
noise_multiplier,
noise_fn_reduction,
linear_model,
)
# The only measure that varies is the standard deviation of the variation.
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std)

View file

@ -17,7 +17,7 @@ py_library(
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
],
)

View file

@ -18,6 +18,7 @@ import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
@ -287,7 +288,7 @@ def make_dp_model_class(cls):
)
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
if self._noise_multiplier > 0:
grads = gradient_clipping_utils.add_aggregate_noise(
grads = noise_utils.add_aggregate_noise(
clipped_grads,
num_microbatches,
self._l2_norm_clip,