Skips noise addition when noise_multiplier is 0. Fixes a typo (eff_microbatch_size -> eff_num_microbatches).

PiperOrigin-RevId: 521912964
commit de9836883d (parent ee1abe6930)
Author: Shuang Song
Date: 2023-04-04 17:47:55 -07:00
Committer: A. Unique TensorFlower


@@ -165,13 +165,14 @@ def make_dp_model_class(cls):

     def _reduce_per_example_grads(self, stacked_grads):
       summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
-      noise_stddev = self._l2_norm_clip * self._noise_multiplier
-      noise = tf.random.normal(
-          tf.shape(input=summed_grads), stddev=noise_stddev
-      )
-      noised_grads = summed_grads + noise
-      return noised_grads / tf.cast(
-          tf.shape(stacked_grads)[0], noised_grads.dtype
+      if self._noise_multiplier > 0:
+        noise_stddev = self._l2_norm_clip * self._noise_multiplier
+        noise = tf.random.normal(
+            tf.shape(input=summed_grads), stddev=noise_stddev
+        )
+        summed_grads = summed_grads + noise
+      return summed_grads / tf.cast(
+          tf.shape(stacked_grads)[0], summed_grads.dtype
       )

     def _compute_per_example_grads(self, data):
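For readers skimming the diff, here is the new reduction as a standalone sketch (a hypothetical free function, with l2_norm_clip and noise_multiplier passed explicitly rather than read off self): when noise_multiplier is 0 the guard skips the tf.random.normal call entirely, so no random op enters the graph and the result is exactly the averaged sum.

import tensorflow as tf

def reduce_per_example_grads(stacked_grads, l2_norm_clip, noise_multiplier):
  # Sum per-example gradients over the stacking axis.
  summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
  if noise_multiplier > 0:
    # Gaussian noise calibrated to the clipping norm; skipped entirely
    # when noise_multiplier == 0 instead of sampling with stddev 0.
    noise_stddev = l2_norm_clip * noise_multiplier
    summed_grads = summed_grads + tf.random.normal(
        tf.shape(input=summed_grads), stddev=noise_stddev
    )
  # Average over the number of per-example gradients.
  return summed_grads / tf.cast(tf.shape(stacked_grads)[0], summed_grads.dtype)

grads = tf.ones([4, 3])  # 4 per-example gradients for 3 parameters
print(reduce_per_example_grads(grads, l2_norm_clip=1.0, noise_multiplier=0.0))
# -> [1. 1. 1.], the exact mean: no noise was sampled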
@@ -207,7 +208,7 @@ def make_dp_model_class(cls):
       output_metrics = {}
       x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
       batch_size = tf.shape(y)[0]
-      eff_microbatch_size = self._num_microbatches or batch_size
+      eff_num_microbatches = self._num_microbatches or batch_size
       privatized_loss_name = 'privatized_loss'

       # Branch based on gradient clipping algorithm.
@@ -233,7 +234,7 @@ def make_dp_model_class(cls):
         grads = gradient_clipping_utils.add_aggregate_noise(
             self,
             clipped_grads,
-            eff_microbatch_size,
+            eff_num_microbatches,
             self._l2_norm_clip,
             self._noise_multiplier,
         )
@@ -243,7 +244,7 @@ def make_dp_model_class(cls):
         # Computes per-example clipped gradients directly. This is called
         # if at least one of the layers cannot use the "fast" gradient clipping
         # algorithm.
-        reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size)
+        reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
         microbatched_data = tf.nest.map_structure(reshape_fn, data)
         microbatched_losses, clipped_grads = tf.vectorized_map(
             self._compute_per_example_grads,
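lr.add_microbatch_axis prepares each input tensor so that tf.vectorized_map can compute one clipped gradient per microbatch. Below is a rough, assumed equivalent (the real helper may handle shapes differently), valid when the batch size divides evenly by eff_num_microbatches:

import tensorflow as tf

def add_microbatch_axis_sketch(x, num_microbatches):
  # Assumed behavior of lr.add_microbatch_axis: reshape
  # [batch_size, ...] -> [num_microbatches, batch_size // num_microbatches, ...]
  # so tf.vectorized_map iterates over the leading microbatch axis.
  shape = tf.shape(x)
  return tf.reshape(x, tf.concat([[num_microbatches, -1], shape[1:]], axis=0))

x = tf.reshape(tf.range(12.0), [6, 2])         # batch of 6 examples
print(add_microbatch_axis_sketch(x, 3).shape)  # (3, 2, 2)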