forked from 626_privacy/tensorflow_privacy
Add more documentation for gradient_accumulation_steps
in keras optimizer.
PiperOrigin-RevId: 469310667
This commit is contained in:
parent
9e25eee68b
commit
9f4feade7d
1 changed files with 17 additions and 5 deletions
|
@ -105,11 +105,23 @@ def make_keras_optimizer_class(cls):
|
|||
opt.minimize(loss, var_list=[var])
|
||||
```
|
||||
|
||||
Note that when using this feature effective batch size is
|
||||
`gradient_accumulation_steps * one_step_batch_size` where
|
||||
`one_step_batch_size` size of the batch which is passed to single step
|
||||
of the optimizer. Thus user may have to adjust learning rate, weight decay
|
||||
and possibly other training hyperparameters accordingly.
|
||||
Note that when using this feature,
|
||||
1. effective batch size is `gradient_accumulation_steps * one_step_batch_size`
|
||||
where `one_step_batch_size` is the size of the batch passed to single step
|
||||
of the optimizer. Thus user may have to adjust learning rate, weight decay
|
||||
and possibly other training hyperparameters accordingly.
|
||||
2. effective noise (the noise to be used for privacy computation) is
|
||||
`noise_multiplier * sqrt(gradient_accumulation_steps)`, as the optimizer
|
||||
adds noise of `self._noise_multiplier` to every step. Thus user may have
|
||||
to adjust the `noise_multiplier` or the privacy computation.
|
||||
Additionally, user may need to adjust the batch size in the data generator,
|
||||
or the number of calls to the data generator, depending on the training
|
||||
framework used. For example, when using Keras model.fit(...) with a
|
||||
user-defined data generator, one may need to make the data generator return
|
||||
`one_step_batch_size` examples each time, and scale the `steps_per_epoch`
|
||||
by `gradient_accumulation_steps`. This is because the data generator is
|
||||
called `steps_per_epoch` times per epoch, and one call only returns
|
||||
`one_step_batch_size` (instead of `effective_batch_size`) examples now.
|
||||
""".format(
|
||||
base_class='tf.keras.optimizers.' + cls.__name__,
|
||||
short_base_class=cls.__name__,
|
||||
|
|
Loading…
Reference in a new issue