Add more documentation for gradient_accumulation_steps in keras optimizer.

PiperOrigin-RevId: 469310667
This commit is contained in:
Shuang Song 2022-08-22 16:16:12 -07:00 committed by A. Unique TensorFlower
parent 9e25eee68b
commit 9f4feade7d

View file

@ -105,11 +105,23 @@ def make_keras_optimizer_class(cls):
opt.minimize(loss, var_list=[var]) opt.minimize(loss, var_list=[var])
``` ```
Note that when using this feature effective batch size is Note that when using this feature,
`gradient_accumulation_steps * one_step_batch_size` where 1. effective batch size is `gradient_accumulation_steps * one_step_batch_size`
`one_step_batch_size` size of the batch which is passed to single step where `one_step_batch_size` is the size of the batch passed to single step
of the optimizer. Thus user may have to adjust learning rate, weight decay of the optimizer. Thus user may have to adjust learning rate, weight decay
and possibly other training hyperparameters accordingly. and possibly other training hyperparameters accordingly.
2. effective noise (the noise to be used for privacy computation) is
`noise_multiplier * sqrt(gradient_accumulation_steps)`, as the optimizer
adds noise of `self._noise_multiplier` to every step. Thus user may have
to adjust the `noise_multiplier` or the privacy computation.
Additionally, user may need to adjust the batch size in the data generator,
or the number of calls to the data generator, depending on the training
framework used. For example, when using Keras model.fit(...) with a
user-defined data generator, one may need to make the data generator return
`one_step_batch_size` examples each time, and scale the `steps_per_epoch`
by `gradient_accumulation_steps`. This is because the data generator is
called `steps_per_epoch` times per epoch, and one call only returns
`one_step_batch_size` (instead of `effective_batch_size`) examples now.
""".format( """.format(
base_class='tf.keras.optimizers.' + cls.__name__, base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__, short_base_class=cls.__name__,