Apply fix to apply_gradients method in vectorized DP Keras optimizer that affected gradient aggregation in multi-replica training.
PiperOrigin-RevId: 417506496
This commit is contained in:
parent
31f110698d
commit
347b99d412
1 changed files with 2 additions and 4 deletions
|
@ -230,17 +230,15 @@ def make_vectorized_keras_optimizer_class(cls):
|
|||
clipped_grads)
|
||||
return final_grads
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
def apply_gradients(self, *args, **kwargs):
|
||||
"""DP-SGD version of base class method."""
|
||||
|
||||
assert self._was_dp_gradients_called, (
|
||||
'Neither _compute_gradients() or get_gradients() on the '
|
||||
'differentially private optimizer was called. This means the '
|
||||
'training is not differentially private. It may be the case that '
|
||||
'you need to upgrade to TF 2.4 or higher to use this particular '
|
||||
'optimizer.')
|
||||
return super(DPOptimizerClass,
|
||||
self).apply_gradients(grads_and_vars, global_step, name)
|
||||
return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
|
||||
|
||||
return DPOptimizerClass
|
||||
|
||||
|
|
Loading…
Reference in a new issue