Apply fix to apply_gradients method in vectorized DP Keras optimizer that affected gradient aggregation in multi-replica training.

PiperOrigin-RevId: 417506496
This commit is contained in:
Steve Chien 2021-12-20 17:10:04 -08:00 committed by A. Unique TensorFlower
parent 31f110698d
commit 347b99d412

View file

@ -230,17 +230,15 @@ def make_vectorized_keras_optimizer_class(cls):
clipped_grads) clipped_grads)
return final_grads return final_grads
def apply_gradients(self, grads_and_vars, global_step=None, name=None): def apply_gradients(self, *args, **kwargs):
"""DP-SGD version of base class method.""" """DP-SGD version of base class method."""
assert self._was_dp_gradients_called, ( assert self._was_dp_gradients_called, (
'Neither _compute_gradients() or get_gradients() on the ' 'Neither _compute_gradients() or get_gradients() on the '
'differentially private optimizer was called. This means the ' 'differentially private optimizer was called. This means the '
'training is not differentially private. It may be the case that ' 'training is not differentially private. It may be the case that '
'you need to upgrade to TF 2.4 or higher to use this particular ' 'you need to upgrade to TF 2.4 or higher to use this particular '
'optimizer.') 'optimizer.')
return super(DPOptimizerClass, return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
self).apply_gradients(grads_and_vars, global_step, name)
return DPOptimizerClass return DPOptimizerClass