Allow using gradient tape for gradient calculation in graph mode.
PiperOrigin-RevId: 406217961
parent c530356ae9
commit c5cb687507
1 changed file with 19 additions and 9 deletions
@@ -98,6 +98,7 @@ def make_optimizer_class(cls):
         dp_sum_query,
         num_microbatches=None,
         unroll_microbatches=False,
+        while_loop_parallel_iterations=10,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
       """Initializes the DPOptimizerClass.
@@ -111,6 +112,10 @@ def make_optimizer_class(cls):
         unroll_microbatches: If true, processes microbatches within a Python
           loop instead of a `tf.while_loop`. Can be used if using a
           `tf.while_loop` raises an exception.
+        while_loop_parallel_iterations: The number of iterations allowed to run
+          in parallel. It must be a positive integer. Applicable only when
+          unroll_microbatches is set to False. It gives users some control over
+          memory consumption.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
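For reference, a minimal usage sketch (not part of this diff) of how the new while_loop_parallel_iterations argument could be passed when wrapping a standard optimizer with make_optimizer_class; the import paths, query parameters, and learning rate below are illustrative assumptions.

# Usage sketch (assumed, for illustration only): wrap a v1 optimizer with
# make_optimizer_class and pass the new while_loop_parallel_iterations knob.
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer

DPGradientDescent = dp_optimizer.make_optimizer_class(
    tf.compat.v1.train.GradientDescentOptimizer)

optimizer = DPGradientDescent(
    dp_sum_query=gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0),
    num_microbatches=4,
    unroll_microbatches=False,
    # Lower values trade loop parallelism for lower peak memory; per this diff
    # the default is 10.
    while_loop_parallel_iterations=2,
    learning_rate=0.1)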
@@ -122,6 +127,7 @@ def make_optimizer_class(cls):
       # Beware: When num_microbatches is large (>100), enabling this parameter
       # may cause an OOM error.
       self._unroll_microbatches = unroll_microbatches
+      self._while_loop_parallel_iterations = while_loop_parallel_iterations
       self._was_compute_gradients_called = False

     def compute_gradients(self,
@@ -177,10 +183,6 @@ def make_optimizer_class(cls):
         return grads_and_vars

       else:
-        # TF is running in graph mode. Check we did not receive a gradient tape.
-        if gradient_tape:
-          raise ValueError('When in graph mode, a tape should not be passed.')
-
         # Note: it would be closer to the correct i.i.d. sampling of records if
         # we sampled each microbatch from the appropriate binomial distribution,
         # although that still wouldn't be quite correct because it would be
@@ -206,10 +208,15 @@ def make_optimizer_class(cls):
             # This case covers Keras optimizers from optimizers_v2.
             compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access

-          grads, _ = zip(*compute_gradients_fn(
-              mean_loss, var_list, gate_gradients, aggregation_method,
-              colocate_gradients_with_ops, grad_loss))
-          grads_list = list(grads)
+          if gradient_tape:
+            # This is intended to work for TF2 and may not work for TF1.
+            with gradient_tape.stop_recording():
+              grads_list = list(gradient_tape.gradient(mean_loss, var_list))
+          else:
+            grads, _ = zip(*compute_gradients_fn(
+                mean_loss, var_list, gate_gradients, aggregation_method,
+                colocate_gradients_with_ops, grad_loss))
+            grads_list = list(grads)

           sample_state = self._dp_sum_query.accumulate_record(
               sample_params, sample_state, grads_list)
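The new branch above is what lets a tape drive the per-microbatch gradient computation while TF is tracing in graph mode. A rough sketch of the call pattern this enables follows; the model, per-example loss, and training-step shape are placeholder assumptions, not prescribed by this diff. Because gradient_tape.gradient is taken once per microbatch, the tape should be persistent.

# Sketch (assumed): using the DP optimizer's compute_gradients with a tape
# inside a tf.function, i.e. while TF runs in graph mode.
import tensorflow as tf

@tf.function  # traced in graph mode, where a tape used to raise ValueError
def train_step(optimizer, model, x, y):
  with tf.GradientTape(persistent=True) as tape:
    # Vector (per-example) loss so the optimizer can split it into
    # microbatches; persistent=True because the optimizer takes one
    # gradient per microbatch.
    per_example_loss = tf.reduce_mean(tf.square(model(x) - y), axis=1)
  grads_and_vars = optimizer.compute_gradients(
      per_example_loss, model.trainable_variables, gradient_tape=tape)
  optimizer.apply_gradients(grads_and_vars)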
@@ -233,7 +240,10 @@ def make_optimizer_class(cls):
           body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
           idx = tf.constant(0)
           _, sample_state = tf.while_loop(
-              cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
+              cond=cond_fn,
+              body=body_fn,
+              loop_vars=[idx, sample_state],
+              parallel_iterations=self._while_loop_parallel_iterations)

         grad_sums, self._global_state, _ = (
             self._dp_sum_query.get_noised_result(sample_state,
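To illustrate the knob threaded through above: parallel_iterations on tf.while_loop caps how many loop iterations the runtime may execute concurrently, which is the memory/speed trade-off the new constructor argument exposes. A standalone sketch, unrelated to the DP query machinery, follows.

# Standalone illustration (assumption: any tf.while_loop, not this optimizer):
# a smaller parallel_iterations bounds concurrent iterations and peak memory.
import tensorflow as tf

def sum_rows(values, parallel_iterations=10):
  # Adds one row per iteration, loosely mirroring one microbatch per iteration.
  i0 = tf.constant(0)
  total0 = tf.zeros_like(values[0])
  cond_fn = lambda i, total: i < tf.shape(values)[0]
  body_fn = lambda i, total: (tf.add(i, 1), total + values[i])
  _, total = tf.while_loop(
      cond=cond_fn,
      body=body_fn,
      loop_vars=(i0, total0),
      parallel_iterations=parallel_iterations)
  return total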