From c5cb68750777f355be248122c612a2fc23a82022 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 28 Oct 2021 14:26:07 -0700
Subject: [PATCH] Allow using gradient tape for gradient calculation in graph mode.

PiperOrigin-RevId: 406217961
---
 .../privacy/optimizers/dp_optimizer.py | 28 +++++++++++++------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
index 3b80092..1d9c8cc 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
@@ -98,6 +98,7 @@ def make_optimizer_class(cls):
         dp_sum_query,
         num_microbatches=None,
         unroll_microbatches=False,
+        while_loop_parallel_iterations=10,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
       """Initializes the DPOptimizerClass.
@@ -111,6 +112,10 @@ def make_optimizer_class(cls):
         unroll_microbatches: If true, processes microbatches within a Python
           loop instead of a `tf.while_loop`. Can be used if using a
           `tf.while_loop` raises an exception.
+        while_loop_parallel_iterations: The number of iterations allowed to run
+          in parallel. It must be a positive integer. Applicable only when
+          unroll_microbatches is set to False. It gives users some control over
+          memory consumption.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@@ -122,6 +127,7 @@ def make_optimizer_class(cls):
       # Beware: When num_microbatches is large (>100), enabling this parameter
       # may cause an OOM error.
       self._unroll_microbatches = unroll_microbatches
+      self._while_loop_parallel_iterations = while_loop_parallel_iterations
       self._was_compute_gradients_called = False

     def compute_gradients(self,
@@ -177,10 +183,6 @@ def make_optimizer_class(cls):
         return grads_and_vars

       else:
-        # TF is running in graph mode. Check we did not receive a gradient tape.
-        if gradient_tape:
-          raise ValueError('When in graph mode, a tape should not be passed.')
-
         # Note: it would be closer to the correct i.i.d. sampling of records if
         # we sampled each microbatch from the appropriate binomial distribution,
         # although that still wouldn't be quite correct because it would be
@@ -206,10 +208,15 @@ def make_optimizer_class(cls):
             # This case covers Keras optimizers from optimizers_v2.
             compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access

-          grads, _ = zip(*compute_gradients_fn(
-              mean_loss, var_list, gate_gradients, aggregation_method,
-              colocate_gradients_with_ops, grad_loss))
-          grads_list = list(grads)
+          if gradient_tape:
+            # This is intended to work for TF2 and may not work for TF1.
+            with gradient_tape.stop_recording():
+              grads_list = list(gradient_tape.gradient(mean_loss, var_list))
+          else:
+            grads, _ = zip(*compute_gradients_fn(
+                mean_loss, var_list, gate_gradients, aggregation_method,
+                colocate_gradients_with_ops, grad_loss))
+            grads_list = list(grads)

           sample_state = self._dp_sum_query.accumulate_record(
               sample_params, sample_state, grads_list)
@@ -233,7 +240,10 @@ def make_optimizer_class(cls):
           body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
           idx = tf.constant(0)
           _, sample_state = tf.while_loop(
-              cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
+              cond=cond_fn,
+              body=body_fn,
+              loop_vars=[idx, sample_state],
+              parallel_iterations=self._while_loop_parallel_iterations)

         grad_sums, self._global_state, _ = (
             self._dp_sum_query.get_noised_result(sample_state,
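
Usage sketch (illustrative, not part of the patch): a minimal example of the new graph-mode tape path. The optimizer name, constructor arguments, model, and loss below are assumptions about the surrounding tensorflow_privacy API rather than code from this change; the key point is that the tape is kept open while compute_gradients is called, so the ops that slice the per-example loss into microbatches are recorded on it.

# Hypothetical sketch of the graph-mode gradient-tape path enabled above.
# DPGradientDescentGaussianOptimizer and its arguments are assumptions,
# not part of this patch.
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

opt = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=4,  # batch size must be divisible by this
    learning_rate=0.1)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)


@tf.function  # the body is traced, so TF runs in graph mode
def train_step(model, x, y):
  # persistent=True is a precaution in case the tape is queried more than
  # once (once per microbatch).
  with tf.GradientTape(persistent=True) as tape:
    loss = loss_fn(y, model(x))  # vector of per-example losses
    # Before this patch, passing a tape here raised a ValueError in graph
    # mode; with it, gradients come from tape.gradient() inside
    # stop_recording().
    grads_and_vars = opt.compute_gradients(
        loss, model.trainable_variables, gradient_tape=tape)
  return opt.apply_gradients(grads_and_vars)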
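
Construction sketch (illustrative, not part of the patch): how the new while_loop_parallel_iterations argument might be set when building a DP optimizer through make_optimizer_class, which this file defines. The Gaussian sum query parameters and learning rate are placeholder values; lowering parallel_iterations trades microbatch parallelism for lower peak memory in the tf.while_loop over microbatches.

# Hypothetical construction sketch showing the new knob; query parameters
# and learning rate are illustrative values, not from this patch.
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer

# Wrap a vanilla optimizer class with the patched make_optimizer_class.
DPGradientDescentOptimizer = dp_optimizer.make_optimizer_class(
    tf.compat.v1.train.GradientDescentOptimizer)

opt = DPGradientDescentOptimizer(
    dp_sum_query=gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.1),
    num_microbatches=64,
    unroll_microbatches=False,         # keep the tf.while_loop path
    while_loop_parallel_iterations=4,  # fewer concurrent microbatches, lower peak memory
    learning_rate=0.15)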