Do not record on the gradient_tape during gradient calculation.
PiperOrigin-RevId: 308772699
This commit is contained in:
parent
5bc76d2e13
commit
9259ccb3d8
1 changed file with 2 additions and 1 deletion
@@ -96,7 +96,8 @@ def make_optimizer_class(cls):
       """Process one microbatch (record) with privacy helper."""
       microbatch_loss = tf.reduce_mean(
           input_tensor=tf.gather(microbatches_losses, [i]))
-      grads = gradient_tape.gradient(microbatch_loss, var_list)
+      with gradient_tape.stop_recording():
+        grads = gradient_tape.gradient(microbatch_loss, var_list)
       sample_state = self._dp_sum_query.accumulate_record(
           sample_params, sample_state, grads)
       return sample_state
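For context, a minimal sketch of the pattern this change adopts: pausing a persistent tf.GradientTape via stop_recording() before calling tape.gradient(), so that computing the per-microbatch gradient is not itself traced onto the tape (which wastes memory). The variable x, the stacked losses, and the two-record loop below are illustrative stand-ins, not the actual tf-privacy optimizer code.

import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    # Stand-in for the per-microbatch losses in the real optimizer.
    microbatches_losses = tf.stack([x * x, 2.0 * x])
    per_record_grads = []
    for i in range(2):
        microbatch_loss = tf.reduce_mean(
            input_tensor=tf.gather(microbatches_losses, [i]))
        # Pause the tape so this gradient computation is not recorded,
        # mirroring the change in the diff above.
        with tape.stop_recording():
            per_record_grads.append(tape.gradient(microbatch_loss, [x]))

print([g[0].numpy() for g in per_record_grads])  # [6.0, 2.0]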