Fix apply_gradients in the Keras DP optimizer.
PiperOrigin-RevId: 417503887
parent 03014d0e99
commit 31f110698d
1 changed file with 2 additions and 3 deletions
@@ -351,7 +351,7 @@ def make_keras_optimizer_class(cls):
       })
       return config

-    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+    def apply_gradients(self, *args, **kwargs):
       """DP-SGD version of base class method."""
       assert self._was_dp_gradients_called, (
           'Neither _compute_gradients() or get_gradients() on the '
@@ -359,8 +359,7 @@ def make_keras_optimizer_class(cls):
           'training is not differentially private. It may be the case that '
           'you need to upgrade to TF 2.4 or higher to use this particular '
           'optimizer.')
-      return super(DPOptimizerClass,
-                   self).apply_gradients(grads_and_vars, global_step, name)
+      return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)

   return DPOptimizerClass
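For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the fix adopts: the apply_gradients() override forwards every positional and keyword argument to the base Keras optimizer instead of pinning the signature to (grads_and_vars, global_step, name). The DPOptimizerClass name and the assert mirror the diff; the surrounding wrapper scaffolding and the example wrapping of tf.keras.optimizers.SGD are illustrative assumptions, not the library's full implementation.

# Illustrative sketch only, not the library implementation.
import tensorflow as tf


def make_keras_optimizer_class(cls):
  """Wraps a Keras optimizer class `cls` in a DP-style subclass (stubbed)."""

  class DPOptimizerClass(cls):

    def __init__(self, *args, **kwargs):
      super(DPOptimizerClass, self).__init__(*args, **kwargs)
      # In the real optimizer this flag is set by _compute_gradients() or
      # get_gradients(); here it exists only so the assert below is runnable.
      self._was_dp_gradients_called = False

    def apply_gradients(self, *args, **kwargs):
      """DP-SGD version of base class method."""
      assert self._was_dp_gradients_called, (
          'Neither _compute_gradients() or get_gradients() on the '
          'differentially private optimizer was called.')
      # Forwarding *args/**kwargs keeps the override compatible with whatever
      # arguments the underlying Keras apply_gradients() accepts.
      return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)

  return DPOptimizerClass


# Example usage: wrap a standard Keras optimizer class.
DPKerasSGDSketch = make_keras_optimizer_class(tf.keras.optimizers.SGD)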