Fix of the apply_gradients in Keras DP optimizer.

PiperOrigin-RevId: 417503887
This commit is contained in:
A. Unique TensorFlower 2021-12-20 16:52:10 -08:00
parent 03014d0e99
commit 31f110698d

View file

@ -351,7 +351,7 @@ def make_keras_optimizer_class(cls):
}) })
return config return config
def apply_gradients(self, grads_and_vars, global_step=None, name=None): def apply_gradients(self, *args, **kwargs):
"""DP-SGD version of base class method.""" """DP-SGD version of base class method."""
assert self._was_dp_gradients_called, ( assert self._was_dp_gradients_called, (
'Neither _compute_gradients() or get_gradients() on the ' 'Neither _compute_gradients() or get_gradients() on the '
@ -359,8 +359,7 @@ def make_keras_optimizer_class(cls):
'training is not differentially private. It may be the case that ' 'training is not differentially private. It may be the case that '
'you need to upgrade to TF 2.4 or higher to use this particular ' 'you need to upgrade to TF 2.4 or higher to use this particular '
'optimizer.') 'optimizer.')
return super(DPOptimizerClass, return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
self).apply_gradients(grads_and_vars, global_step, name)
return DPOptimizerClass return DPOptimizerClass