Do not record gradient_tape on gradient calculation.
PiperOrigin-RevId: 308772699
This commit is contained in:
parent
5bc76d2e13
commit
9259ccb3d8
1 changed files with 2 additions and 1 deletions
|
@ -96,7 +96,8 @@ def make_optimizer_class(cls):
|
||||||
"""Process one microbatch (record) with privacy helper."""
|
"""Process one microbatch (record) with privacy helper."""
|
||||||
microbatch_loss = tf.reduce_mean(
|
microbatch_loss = tf.reduce_mean(
|
||||||
input_tensor=tf.gather(microbatches_losses, [i]))
|
input_tensor=tf.gather(microbatches_losses, [i]))
|
||||||
grads = gradient_tape.gradient(microbatch_loss, var_list)
|
with gradient_tape.stop_recording():
|
||||||
|
grads = gradient_tape.gradient(microbatch_loss, var_list)
|
||||||
sample_state = self._dp_sum_query.accumulate_record(
|
sample_state = self._dp_sum_query.accumulate_record(
|
||||||
sample_params, sample_state, grads)
|
sample_params, sample_state, grads)
|
||||||
return sample_state
|
return sample_state
|
||||||
|
|
Loading…
Reference in a new issue