forked from 626_privacy/tensorflow_privacy
Allow using gradient tape for gradient calculation in graph mode.
PiperOrigin-RevId: 406217961
This commit is contained in:
parent
c530356ae9
commit
c5cb687507
1 changed files with 19 additions and 9 deletions
|
@ -98,6 +98,7 @@ def make_optimizer_class(cls):
|
|||
dp_sum_query,
|
||||
num_microbatches=None,
|
||||
unroll_microbatches=False,
|
||||
while_loop_parallel_iterations=10,
|
||||
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
||||
**kwargs):
|
||||
"""Initializes the DPOptimizerClass.
|
||||
|
@ -111,6 +112,10 @@ def make_optimizer_class(cls):
|
|||
unroll_microbatches: If true, processes microbatches within a Python
|
||||
loop instead of a `tf.while_loop`. Can be used if using a
|
||||
`tf.while_loop` raises an exception.
|
||||
while_loop_parallel_iterations: The number of iterations allowed to run
|
||||
in parallel. It must be a positive integer. Applicable only when
|
||||
unroll_microbatches is set to False. It gives users some control over
|
||||
memory consumption.
|
||||
*args: These will be passed on to the base class `__init__` method.
|
||||
**kwargs: These will be passed on to the base class `__init__` method.
|
||||
"""
|
||||
|
@ -122,6 +127,7 @@ def make_optimizer_class(cls):
|
|||
# Beware: When num_microbatches is large (>100), enabling this parameter
|
||||
# may cause an OOM error.
|
||||
self._unroll_microbatches = unroll_microbatches
|
||||
self._while_loop_parallel_iterations = while_loop_parallel_iterations
|
||||
self._was_compute_gradients_called = False
|
||||
|
||||
def compute_gradients(self,
|
||||
|
@ -177,10 +183,6 @@ def make_optimizer_class(cls):
|
|||
return grads_and_vars
|
||||
|
||||
else:
|
||||
# TF is running in graph mode. Check we did not receive a gradient tape.
|
||||
if gradient_tape:
|
||||
raise ValueError('When in graph mode, a tape should not be passed.')
|
||||
|
||||
# Note: it would be closer to the correct i.i.d. sampling of records if
|
||||
# we sampled each microbatch from the appropriate binomial distribution,
|
||||
# although that still wouldn't be quite correct because it would be
|
||||
|
@ -206,6 +208,11 @@ def make_optimizer_class(cls):
|
|||
# This case covers Keras optimizers from optimizers_v2.
|
||||
compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access
|
||||
|
||||
if gradient_tape:
|
||||
# This is intended to work for TF2 and may not work for TF1.
|
||||
with gradient_tape.stop_recording():
|
||||
grads_list = list(gradient_tape.gradient(mean_loss, var_list))
|
||||
else:
|
||||
grads, _ = zip(*compute_gradients_fn(
|
||||
mean_loss, var_list, gate_gradients, aggregation_method,
|
||||
colocate_gradients_with_ops, grad_loss))
|
||||
|
@ -233,7 +240,10 @@ def make_optimizer_class(cls):
|
|||
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
|
||||
idx = tf.constant(0)
|
||||
_, sample_state = tf.while_loop(
|
||||
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
|
||||
cond=cond_fn,
|
||||
body=body_fn,
|
||||
loop_vars=[idx, sample_state],
|
||||
parallel_iterations=self._while_loop_parallel_iterations)
|
||||
|
||||
grad_sums, self._global_state, _ = (
|
||||
self._dp_sum_query.get_noised_result(sample_state,
|
||||
|
|
Loading…
Reference in a new issue