forked from 626_privacy/tensorflow_privacy
missing reduce_mean
PiperOrigin-RevId: 235858614
This commit is contained in:
parent
df6f065925
commit
0c691085e1
1 changed files with 1 additions and 1 deletions
|
@ -75,7 +75,7 @@ def make_optimizer_class(cls):
|
|||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
microbatch_loss = tf.gather(microbatches_losses, [i])
|
||||
microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i]))
|
||||
grads = gradient_tape.gradient(microbatch_loss, var_list)
|
||||
sample_state = self._dp_average_query.accumulate_record(sample_params,
|
||||
sample_state,
|
||||
|
|
Loading…
Reference in a new issue