Allow using gradient tape for gradient calculation in graph mode.

PiperOrigin-RevId: 406217961
This commit is contained in:
A. Unique TensorFlower 2021-10-28 14:26:07 -07:00
parent c530356ae9
commit c5cb687507

View file

@ -98,6 +98,7 @@ def make_optimizer_class(cls):
dp_sum_query, dp_sum_query,
num_microbatches=None, num_microbatches=None,
unroll_microbatches=False, unroll_microbatches=False,
while_loop_parallel_iterations=10,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs): **kwargs):
"""Initializes the DPOptimizerClass. """Initializes the DPOptimizerClass.
@ -111,6 +112,10 @@ def make_optimizer_class(cls):
unroll_microbatches: If true, processes microbatches within a Python unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception. `tf.while_loop` raises an exception.
while_loop_parallel_iterations: The number of iterations allowed to run
in parallel. It must be a positive integer. Applicable only when
unroll_microbatches is set to False. It gives users some control over
memory consumption.
*args: These will be passed on to the base class `__init__` method. *args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method. **kwargs: These will be passed on to the base class `__init__` method.
""" """
@ -122,6 +127,7 @@ def make_optimizer_class(cls):
# Beware: When num_microbatches is large (>100), enabling this parameter # Beware: When num_microbatches is large (>100), enabling this parameter
# may cause an OOM error. # may cause an OOM error.
self._unroll_microbatches = unroll_microbatches self._unroll_microbatches = unroll_microbatches
self._while_loop_parallel_iterations = while_loop_parallel_iterations
self._was_compute_gradients_called = False self._was_compute_gradients_called = False
def compute_gradients(self, def compute_gradients(self,
@ -177,10 +183,6 @@ def make_optimizer_class(cls):
return grads_and_vars return grads_and_vars
else: else:
# TF is running in graph mode. Check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
# Note: it would be closer to the correct i.i.d. sampling of records if # Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution, # we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be # although that still wouldn't be quite correct because it would be
@ -206,6 +208,11 @@ def make_optimizer_class(cls):
# This case covers Keras optimizers from optimizers_v2. # This case covers Keras optimizers from optimizers_v2.
compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access
if gradient_tape:
# This is intended to work for TF2 and may not work for TF1.
with gradient_tape.stop_recording():
grads_list = list(gradient_tape.gradient(mean_loss, var_list))
else:
grads, _ = zip(*compute_gradients_fn( grads, _ = zip(*compute_gradients_fn(
mean_loss, var_list, gate_gradients, aggregation_method, mean_loss, var_list, gate_gradients, aggregation_method,
colocate_gradients_with_ops, grad_loss)) colocate_gradients_with_ops, grad_loss))
@ -233,7 +240,10 @@ def make_optimizer_class(cls):
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
idx = tf.constant(0) idx = tf.constant(0)
_, sample_state = tf.while_loop( _, sample_state = tf.while_loop(
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state]) cond=cond_fn,
body=body_fn,
loop_vars=[idx, sample_state],
parallel_iterations=self._while_loop_parallel_iterations)
grad_sums, self._global_state, _ = ( grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state, self._dp_sum_query.get_noised_result(sample_state,