missing reduce_mean

PiperOrigin-RevId: 235858614
This commit is contained in:
Nicolas Papernot 2019-02-26 22:56:10 -08:00 committed by A. Unique TensorFlower
parent df6f065925
commit 0c691085e1

View file

@ -75,7 +75,7 @@ def make_optimizer_class(cls):
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
microbatch_loss = tf.gather(microbatches_losses, [i])
microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i]))
grads = gradient_tape.gradient(microbatch_loss, var_list)
sample_state = self._dp_average_query.accumulate_record(sample_params,
sample_state,