Sparsity Preserving DP-SGD in TF Privacy
Refactor the utilities for adding noise into a separate utility file. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660527638
parent fc6f1dc5d1
commit d3f527e775
7 changed files with 203 additions and 158 deletions
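For downstream callers the change is only an import-path move; a minimal before/after sketch (the `clipped_grads` and `batch_size` variables here are illustrative placeholders, not taken from this commit):

    # Before this commit:
    # from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
    # noised = gradient_clipping_utils.add_aggregate_noise(...)

    # After this commit, the same function lives in noise_utils:
    from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

    noised = noise_utils.add_aggregate_noise(
        clipped_grads=clipped_grads,    # list of already-clipped gradient tensors
        batch_size=batch_size,
        l2_norm_clip=1.0,
        noise_multiplier=1.1,
        loss_reduction='mean',          # or pass loss_model=model and leave this None
    )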
@@ -54,10 +54,7 @@ py_test(
     python_version = "PY3",
     shard_count = 8,
     srcs_version = "PY3",
-    deps = [
-        ":gradient_clipping_utils",
-        ":type_aliases",
-    ],
+    deps = [":gradient_clipping_utils"],
 )
 
 py_library(
@@ -83,6 +80,11 @@ py_library(
     ],
 )
 
+py_library(
+    name = "noise_utils",
+    srcs = ["noise_utils.py"],
+)
+
 py_test(
     name = "clip_grads_test",
     srcs = ["clip_grads_test.py"],
@@ -96,3 +98,9 @@ py_test(
         ":type_aliases",
     ],
 )
+
+py_test(
+    name = "noise_utils_test",
+    srcs = ["noise_utils_test.py"],
+    deps = [":noise_utils"],
+)
@@ -14,9 +14,8 @@
 """Utility functions that help in the computation of per-example gradient norms."""
 
 from collections.abc import Sequence, Set
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 
-from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -144,101 +143,6 @@ def all_trainable_layers_are_registered(
   return True
 
 
-def _infer_loss_reduction_type(model: tf.keras.Model):
-  """Infers what type of loss reduction is being performed."""
-  model_loss = model.loss
-  if isinstance(model_loss, tf.keras.losses.Loss):
-    return model_loss.reduction
-  elif isinstance(model.loss, dict):
-    reductions = set()
-    compiled_loss = model.compiled_loss
-    if compiled_loss is None:
-      raise ValueError('Model must be compiled for adding noise')
-    new_config_list = compiled_loss.get_config()['losses']
-    for loss_config in new_config_list:
-      reductions.add(loss_config['config']['reduction'])
-    if len(reductions) > 1:
-      raise ValueError(
-          'Reductions in models with multiple losses must all be the same'
-      )
-    return reductions.pop()
-  else:
-    raise ValueError(
-        'Unsupported type for adding noise: {}'.format(type(model_loss))
-    )
-
-
-def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
-    batch_size: tf.Tensor,
-    l2_norm_clip: float,
-    noise_multiplier: float,
-    loss_reduction: Optional[Literal['mean', 'sum']] = None,
-    loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
-  """Adds noise to a collection of clipped gradients.
-
-  The magnitude of the noise depends on the aggregation strategy of the
-  input model's loss function.
-
-  Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size. Used for normalizing the noise when
-      `loss_reduction` is 'sum'.
-    l2_norm_clip: Clipping norm (max L2 norm of each gradient).
-    noise_multiplier: Ratio of the standard deviation to the clipping norm.
-    loss_reduction: A string description of how the loss is reduced over
-      examples. Currently supports 'mean' and 'sum'. If `None`, then the
-      aggregation type must be inferred from `input_model.loss`.
-    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
-      strategy from if `loss_reduction` is `None`.
-
-  Returns:
-    A list of tensors containing the clipped gradients, but with the right
-    amount of Gaussian noise added to them (depending on the reduction
-    strategy of the loss function).
-
-  Raises:
-    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
-      they are both not `None`.
-  """
-  if loss_reduction is None and loss_model is None:
-    raise ValueError(
-        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
-        ' Instead, both arguments were `None`.'
-    )
-  if loss_reduction is not None and loss_model is not None:
-    raise ValueError(
-        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
-        ' Instead, both arguments were not `None`.'
-    )
-
-  if loss_reduction is None and loss_model is not None:
-    implicit_mean_reductions = [
-        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
-        tf.keras.losses.Reduction.AUTO,
-    ]
-    model_reduction = _infer_loss_reduction_type(loss_model)
-    loss_reduction = (
-        'mean' if model_reduction in implicit_mean_reductions else 'sum'
-    )
-    if model_reduction == tf.keras.losses.Reduction.AUTO:
-      logging.info(
-          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
-      )
-
-  scale = l2_norm_clip
-  if loss_reduction == 'mean':
-    scale /= tf.cast(batch_size, tf.float32)
-
-  def add_noise(g):
-    return g + tf.random.normal(
-        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
-    )
-
-  return tf.nest.map_structure(add_noise, clipped_grads)
-
-
 def generate_model_outputs_using_core_keras_layers(
     input_model: tf.keras.Model,
     custom_layer_set: Optional[Set[type]] = None,  # pylint: disable=g-bare-generic
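The `_infer_loss_reduction_type` and `add_aggregate_noise` functions removed above reappear verbatim in the new noise_utils.py module further below.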
@@ -15,7 +15,6 @@
 from typing import Any
 
 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 
@@ -135,60 +134,6 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_outputs, true_outputs)
 
 
-class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.product(
-      l2_norm_clip=[3.0, 5.0],
-      noise_multiplier=[2.0, 4.0],
-      batch_size=[1, 2, 10],
-      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
-      noise_fn_reduction=[None, 'mean', 'sum'],
-  )
-  def test_noise_is_computed_correctly(
-      self,
-      l2_norm_clip,
-      noise_multiplier,
-      batch_size,
-      model_fn_reduction,
-      noise_fn_reduction,
-  ):
-    # Skip invalid combinations.
-    if model_fn_reduction is None and noise_fn_reduction is None:
-      return
-    if model_fn_reduction is not None and noise_fn_reduction is not None:
-      return
-    # Make a simple model container for storing the loss.
-    if model_fn_reduction is not None:
-      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
-      linear_model.compile(
-          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
-      )
-    else:
-      linear_model = None
-    # The main computation is done on a deterministic dummy vector.
-    num_units = 100
-    clipped_grads = [
-        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
-    ]
-    noised_grads = gradient_clipping_utils.add_aggregate_noise(
-        clipped_grads,
-        batch_size,
-        l2_norm_clip,
-        noise_multiplier,
-        noise_fn_reduction,
-        linear_model,
-    )
-    # The only measure that varies is the standard deviation of the variation.
-    scale = (
-        1.0
-        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
-        else 1.0 / batch_size
-    )
-    computed_std = np.std(noised_grads[0] - clipped_grads[0])
-    expected_std = l2_norm_clip * noise_multiplier * scale
-    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
-
-
 class GenerateOutputsUsingCoreKerasLayers(
     tf.test.TestCase, parameterized.TestCase
 ):
115 tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py Normal file
@@ -0,0 +1,115 @@
+# Copyright 2024, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions that help in adding noise to gradients."""
+
+from collections.abc import Sequence
+from typing import Literal, Optional
+
+from absl import logging
+import tensorflow as tf
+
+
+def _infer_loss_reduction_type(model: tf.keras.Model):
+  """Infers what type of loss reduction is being performed."""
+  model_loss = model.loss
+  if isinstance(model_loss, tf.keras.losses.Loss):
+    return model_loss.reduction
+  elif isinstance(model.loss, dict):
+    reductions = set()
+    compiled_loss = model.compiled_loss
+    if compiled_loss is None:
+      raise ValueError('Model must be compiled for adding noise')
+    new_config_list = compiled_loss.get_config()['losses']
+    for loss_config in new_config_list:
+      reductions.add(loss_config['config']['reduction'])
+    if len(reductions) > 1:
+      raise ValueError(
+          'Reductions in models with multiple losses must all be the same'
+      )
+    return reductions.pop()
+  else:
+    raise ValueError(
+        'Unsupported type for adding noise: {}'.format(type(model_loss))
+    )
+
+
+def add_aggregate_noise(
+    clipped_grads: list[tf.Tensor],
+    batch_size: tf.Tensor,
+    l2_norm_clip: float,
+    noise_multiplier: float,
+    loss_reduction: Optional[Literal['mean', 'sum']] = None,
+    loss_model: Optional[tf.keras.Model] = None,
+) -> Sequence[tf.Tensor]:
+  """Adds noise to a collection of clipped gradients.
+
+  The magnitude of the noise depends on the aggregation strategy of the
+  input model's loss function.
+
+  Args:
+    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
+    batch_size: The batch size. Used for normalizing the noise when
+      `loss_reduction` is 'sum'.
+    l2_norm_clip: Clipping norm (max L2 norm of each gradient).
+    noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: A string description of how the loss is reduced over
+      examples. Currently supports 'mean' and 'sum'. If `None`, then the
+      aggregation type must be inferred from `input_model.loss`.
+    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
+      strategy from if `loss_reduction` is `None`.
+
+  Returns:
+    A list of tensors containing the clipped gradients, but with the right
+    amount of Gaussian noise added to them (depending on the reduction
+    strategy of the loss function).
+
+  Raises:
+    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
+      they are both not `None`.
+  """
+  if loss_reduction is None and loss_model is None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were `None`.'
+    )
+  if loss_reduction is not None and loss_model is not None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were not `None`.'
+    )
+
+  if loss_reduction is None and loss_model is not None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = _infer_loss_reduction_type(loss_model)
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
+      logging.info(
+          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
+      )
+
+  scale = l2_norm_clip
+  if loss_reduction == 'mean':
+    scale /= tf.cast(batch_size, tf.float32)
+
+  def add_noise(g):
+    return g + tf.random.normal(
+        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
+    )
+
+  return tf.nest.map_structure(add_noise, clipped_grads)
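A minimal standalone usage sketch of the new module (not part of this commit; the gradient values and constants are illustrative). With loss_reduction='mean', the added Gaussian noise has standard deviation noise_multiplier * l2_norm_clip / batch_size:

    import numpy as np
    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

    # Toy "clipped gradients": one tensor with 10,000 entries so the empirical
    # standard deviation of the added noise is stable.
    clipped_grads = [tf.zeros([10000], dtype=tf.float32)]
    noised = noise_utils.add_aggregate_noise(
        clipped_grads,
        batch_size=8,
        l2_norm_clip=3.0,
        noise_multiplier=2.0,
        loss_reduction='mean',
    )
    # Expected noise std: 2.0 * 3.0 / 8 = 0.75.
    print(np.std(noised[0] - clipped_grads[0]))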
72 tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py Normal file
@@ -0,0 +1,72 @@
+# Copyright 2024, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+
+
+class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    clipped_grads = [
+        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    ]
+    noised_grads = noise_utils.add_aggregate_noise(
+        clipped_grads,
+        batch_size,
+        l2_norm_clip,
+        noise_multiplier,
+        noise_fn_reduction,
+        linear_model,
+    )
+    # The only measure that varies is the standard deviation of the variation.
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    computed_std = np.std(noised_grads[0] - clipped_grads[0])
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
@@ -17,7 +17,7 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
-        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
     ],
 )
 
@@ -18,6 +18,7 @@ import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
 
 _PRIVATIZED_LOSS_NAME = 'privatized_loss'
 
@@ -287,7 +288,7 @@ def make_dp_model_class(cls):
          )
          output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
          if self._noise_multiplier > 0:
-           grads = gradient_clipping_utils.add_aggregate_noise(
+           grads = noise_utils.add_aggregate_noise(
                clipped_grads,
                num_microbatches,
                self._l2_norm_clip,