forked from 626_privacy/tensorflow_privacy
For DP Keras optimizers, add assertion that one of the DP-modified gradients methods has been called before apply_gradients(). In particular, this helps catch cases where the user has not yet upgraded to TF 2.4.
PiperOrigin-RevId: 333620379
This commit is contained in:
parent
7c53757250
commit
837e014107
2 changed files with 54 additions and 1 deletions
|
@ -61,10 +61,12 @@ def make_keras_optimizer_class(cls):
|
|||
self._dp_sum_query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
||||
self._global_state = None
|
||||
self._was_dp_gradients_called = False
|
||||
|
||||
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
||||
"""DP version of superclass method."""
|
||||
|
||||
self._was_dp_gradients_called = True
|
||||
# Compute loss.
|
||||
if not callable(loss) and tape is None:
|
||||
raise ValueError('`tape` is required when a `Tensor` loss is passed.')
|
||||
|
@ -120,6 +122,7 @@ def make_keras_optimizer_class(cls):
|
|||
def get_gradients(self, loss, params):
|
||||
"""DP version of superclass method."""
|
||||
|
||||
self._was_dp_gradients_called = True
|
||||
if self._global_state is None:
|
||||
self._global_state = self._dp_sum_query.initial_global_state()
|
||||
|
||||
|
@ -156,6 +159,16 @@ def make_keras_optimizer_class(cls):
|
|||
|
||||
return final_grads
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||
assert self._was_dp_gradients_called, (
|
||||
'Neither _compute_gradients() or get_gradients() on the '
|
||||
'differentially private optimizer was called. This means the '
|
||||
'training is not differentially private. It may be the case that '
|
||||
'you need to upgrade to TF 2.4 or higher to use this particular '
|
||||
'optimizer.')
|
||||
return super(DPOptimizerClass,
|
||||
self).apply_gradients(grads_and_vars, global_step, name)
|
||||
|
||||
return DPOptimizerClass
|
||||
|
||||
|
||||
|
|
|
@ -135,6 +135,29 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertNear(
|
||||
np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||
('DPAdagrad', dp_optimizer_keras.DPKerasAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer))
|
||||
def testAssertOnNoCallOfComputeGradients(self, cls):
|
||||
"""Tests that assertion fails when DP gradients are not computed."""
|
||||
opt = cls(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
grads_and_vars = tf.Variable([0.0])
|
||||
opt.apply_gradients(grads_and_vars)
|
||||
|
||||
# Expect no exception if _compute_gradients is called.
|
||||
var0 = tf.Variable([0.0])
|
||||
data0 = tf.Variable([[0.0]])
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
||||
opt.apply_gradients(grads_and_vars)
|
||||
|
||||
|
||||
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
"""Tests for get_gradient method.
|
||||
|
@ -247,7 +270,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
bias_value / global_norm,
|
||||
atol=0.001)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches.
|
||||
# Parameters for testing: optimizer, l2_norm_clip, noise_multiplier,
|
||||
# num_microbatches.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||
4.0, 1),
|
||||
|
@ -285,6 +309,22 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
|||
np.std(kernel_value),
|
||||
l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),
|
||||
('DPAdagrad', dp_optimizer_keras.DPKerasAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer_keras.DPKerasAdamOptimizer))
|
||||
def testAssertOnNoCallOfGetGradients(self, cls):
|
||||
"""Tests that assertion fails when DP gradients are not computed."""
|
||||
opt = cls(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
grads_and_vars = tf.Variable([0.0])
|
||||
opt.apply_gradients(grads_and_vars)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue