diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py index 91b72da..210ade8 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py @@ -96,7 +96,8 @@ def make_optimizer_class(cls): """Process one microbatch (record) with privacy helper.""" microbatch_loss = tf.reduce_mean( input_tensor=tf.gather(microbatches_losses, [i])) - grads = gradient_tape.gradient(microbatch_loss, var_list) + with gradient_tape.stop_recording(): + grads = gradient_tape.gradient(microbatch_loss, var_list) sample_state = self._dp_sum_query.accumulate_record( sample_params, sample_state, grads) return sample_state