forked from 626_privacy/tensorflow_privacy
Apply fix to apply_gradients method in vectorized DP Keras optimizer that affected gradient aggregation in multi-replica training.
PiperOrigin-RevId: 417506496
This commit is contained in:
parent
31f110698d
commit
347b99d412
1 changed files with 2 additions and 4 deletions
|
@ -230,17 +230,15 @@ def make_vectorized_keras_optimizer_class(cls):
|
||||||
clipped_grads)
|
clipped_grads)
|
||||||
return final_grads
|
return final_grads
|
||||||
|
|
||||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
def apply_gradients(self, *args, **kwargs):
|
||||||
"""DP-SGD version of base class method."""
|
"""DP-SGD version of base class method."""
|
||||||
|
|
||||||
assert self._was_dp_gradients_called, (
|
assert self._was_dp_gradients_called, (
|
||||||
'Neither _compute_gradients() or get_gradients() on the '
|
'Neither _compute_gradients() or get_gradients() on the '
|
||||||
'differentially private optimizer was called. This means the '
|
'differentially private optimizer was called. This means the '
|
||||||
'training is not differentially private. It may be the case that '
|
'training is not differentially private. It may be the case that '
|
||||||
'you need to upgrade to TF 2.4 or higher to use this particular '
|
'you need to upgrade to TF 2.4 or higher to use this particular '
|
||||||
'optimizer.')
|
'optimizer.')
|
||||||
return super(DPOptimizerClass,
|
return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
|
||||||
self).apply_gradients(grads_and_vars, global_step, name)
|
|
||||||
|
|
||||||
return DPOptimizerClass
|
return DPOptimizerClass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue