diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 82112fa..b5ce15d 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -75,7 +75,7 @@ def make_optimizer_class(cls): def process_microbatch(i, sample_state): """Process one microbatch (record) with privacy helper.""" - microbatch_loss = tf.gather(microbatches_losses, [i]) + microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i])) grads = gradient_tape.gradient(microbatch_loss, var_list) sample_state = self._dp_average_query.accumulate_record(sample_params, sample_state,