forked from 626_privacy/tensorflow_privacy
Fix of the apply_gradients in Keras DP optimizer.
PiperOrigin-RevId: 417503887
This commit is contained in:
parent
03014d0e99
commit
31f110698d
1 changed files with 2 additions and 3 deletions
|
@ -351,7 +351,7 @@ def make_keras_optimizer_class(cls):
|
|||
})
|
||||
return config
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
def apply_gradients(self, *args, **kwargs):
|
||||
"""DP-SGD version of base class method."""
|
||||
assert self._was_dp_gradients_called, (
|
||||
'Neither _compute_gradients() or get_gradients() on the '
|
||||
|
@ -359,8 +359,7 @@ def make_keras_optimizer_class(cls):
|
|||
'training is not differentially private. It may be the case that '
|
||||
'you need to upgrade to TF 2.4 or higher to use this particular '
|
||||
'optimizer.')
|
||||
return super(DPOptimizerClass,
|
||||
self).apply_gradients(grads_and_vars, global_step, name)
|
||||
return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
|
||||
|
||||
return DPOptimizerClass
|
||||
|
||||
|
|
Loading…
Reference in a new issue