Do not record gradient_tape on gradient calculation.

PiperOrigin-RevId: 308772699
This commit is contained in:
A. Unique TensorFlower 2020-04-27 23:56:58 -07:00
parent 5bc76d2e13
commit 9259ccb3d8

View file

@ -96,7 +96,8 @@ def make_optimizer_class(cls):
"""Process one microbatch (record) with privacy helper."""
microbatch_loss = tf.reduce_mean(
input_tensor=tf.gather(microbatches_losses, [i]))
grads = gradient_tape.gradient(microbatch_loss, var_list)
with gradient_tape.stop_recording():
grads = gradient_tape.gradient(microbatch_loss, var_list)
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, grads)
return sample_state