Skips noise addition when noise_multiplier is 0. Fix a typo.

PiperOrigin-RevId: 521912964
This commit is contained in:
Shuang Song 2023-04-04 17:47:55 -07:00 committed by A. Unique TensorFlower
parent ee1abe6930
commit de9836883d

View file

@ -165,13 +165,14 @@ def make_dp_model_class(cls):
def _reduce_per_example_grads(self, stacked_grads): def _reduce_per_example_grads(self, stacked_grads):
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
noise_stddev = self._l2_norm_clip * self._noise_multiplier if self._noise_multiplier > 0:
noise = tf.random.normal( noise_stddev = self._l2_norm_clip * self._noise_multiplier
tf.shape(input=summed_grads), stddev=noise_stddev noise = tf.random.normal(
) tf.shape(input=summed_grads), stddev=noise_stddev
noised_grads = summed_grads + noise )
return noised_grads / tf.cast( summed_grads = summed_grads + noise
tf.shape(stacked_grads)[0], noised_grads.dtype return summed_grads / tf.cast(
tf.shape(stacked_grads)[0], summed_grads.dtype
) )
def _compute_per_example_grads(self, data): def _compute_per_example_grads(self, data):
@ -207,7 +208,7 @@ def make_dp_model_class(cls):
output_metrics = {} output_metrics = {}
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data) x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
batch_size = tf.shape(y)[0] batch_size = tf.shape(y)[0]
eff_microbatch_size = self._num_microbatches or batch_size eff_num_microbatches = self._num_microbatches or batch_size
privatized_loss_name = 'privatized_loss' privatized_loss_name = 'privatized_loss'
# Branch based on gradient clipping algorithm. # Branch based on gradient clipping algorithm.
@ -233,7 +234,7 @@ def make_dp_model_class(cls):
grads = gradient_clipping_utils.add_aggregate_noise( grads = gradient_clipping_utils.add_aggregate_noise(
self, self,
clipped_grads, clipped_grads,
eff_microbatch_size, eff_num_microbatches,
self._l2_norm_clip, self._l2_norm_clip,
self._noise_multiplier, self._noise_multiplier,
) )
@ -243,7 +244,7 @@ def make_dp_model_class(cls):
# Computes per-example clipped gradients directly. This is called # Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping # if at least one of the layers cannot use the "fast" gradient clipping
# algorithm. # algorithm.
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size) reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
microbatched_data = tf.nest.map_structure(reshape_fn, data) microbatched_data = tf.nest.map_structure(reshape_fn, data)
microbatched_losses, clipped_grads = tf.vectorized_map( microbatched_losses, clipped_grads = tf.vectorized_map(
self._compute_per_example_grads, self._compute_per_example_grads,