diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
index 5aed2da..71321e4 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_vectorized.py
@@ -230,17 +230,15 @@ def make_vectorized_keras_optimizer_class(cls):
           clipped_grads)
 
       return final_grads
 
-    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+    def apply_gradients(self, *args, **kwargs):
       """DP-SGD version of base class method."""
-
       assert self._was_dp_gradients_called, (
          'Neither _compute_gradients() or get_gradients() on the '
          'differentially private optimizer was called. This means the '
          'training is not differentially private. It may be the case that '
          'you need to upgrade to TF 2.4 or higher to use this particular '
          'optimizer.')
-      return super(DPOptimizerClass,
-                   self).apply_gradients(grads_and_vars, global_step, name)
+      return super(DPOptimizerClass, self).apply_gradients(*args, **kwargs)
 
   return DPOptimizerClass
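
Note on the change (not part of the patch): in TF 2.x, `tf.keras.optimizers.Optimizer.apply_gradients()` takes `grads_and_vars` plus keyword arguments such as `name` and, in newer releases, `experimental_aggregate_gradients`; it has no `global_step` parameter. The old hard-coded `(grads_and_vars, global_step, name)` override therefore could not forward the base class's full argument set, while `*args, **kwargs` keeps the DP override signature-compatible with whatever Keras optimizer class gets wrapped. The sketch below shows the guarded path the assertion protects; the toy model and data are illustrative assumptions, not from the patch.

```python
# Minimal usage sketch, assuming tensorflow-privacy is installed.
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import (
    VectorizedDPKerasSGDOptimizer,
)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = VectorizedDPKerasSGDOptimizer(
    l2_norm_clip=1.0,        # per-microbatch gradient clipping norm
    noise_multiplier=1.1,    # Gaussian noise scale relative to the clip
    num_microbatches=8,      # must evenly divide the batch size
    learning_rate=0.1,
)
# Per-example (unreduced) losses are required so each microbatch can be
# clipped independently before noise is added.
loss = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
model.compile(optimizer=optimizer, loss=loss)

x = tf.random.normal((32, 4))  # illustrative toy data
y = tf.random.normal((32, 1))
# fit() routes gradient computation through the optimizer's
# _compute_gradients(), which privatizes the gradients and sets
# _was_dp_gradients_called, so the assertion in apply_gradients()
# passes and the forwarded arguments reach the Keras base class.
model.fit(x, y, batch_size=8, epochs=1)
```

Calling `optimizer.apply_gradients()` directly, without `_compute_gradients()` or `get_gradients()` having run first, trips the assertion, since the gradients being applied would not have been clipped and noised.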