From c5c807807ff4ee5ee2c602cc9a83c98e515e64e8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 22 Apr 2020 08:33:34 -0700
Subject: [PATCH] Add assert that the training is private.

In Keras training in TF 2.0+, compute_gradients() is not called, but
apply_gradients() is. Without a call to compute_gradients(), the DP
gradient is never computed and an ordinary (non-private) gradient is
used instead.

PiperOrigin-RevId: 307822742
---
 .../privacy/optimizers/dp_optimizer.py      | 11 +++++++++++
 .../privacy/optimizers/dp_optimizer_test.py | 18 ++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
index e05839a..91b72da 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py
@@ -67,6 +67,7 @@ def make_optimizer_class(cls):
       # Beware: When num_microbatches is large (>100), enabling this parameter
       # may cause an OOM error.
       self._unroll_microbatches = unroll_microbatches
+      self._was_compute_gradients_called = False
 
     def compute_gradients(self,
                           loss,
@@ -76,6 +77,7 @@ def make_optimizer_class(cls):
                           colocate_gradients_with_ops=False,
                           grad_loss=None,
                           gradient_tape=None):
+      self._was_compute_gradients_called = True
       if callable(loss):
         # TF is running in Eager mode, check we received a vanilla tape.
         if not gradient_tape:
@@ -175,6 +177,15 @@ def make_optimizer_class(cls):
 
       return list(zip(final_grads, var_list))
 
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+      assert self._was_compute_gradients_called, (
+          'compute_gradients() on the differentially private optimizer was '
+          'not called, which means the training is not differentially '
+          'private. This happens, for example, in Keras training in '
+          'TensorFlow 2.0+.')
+      return super(DPOptimizerClass,
+                   self).apply_gradients(grads_and_vars, global_step, name)
+
   return DPOptimizerClass
 
 
diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py
index c9c214d..ebd2261 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py
@@ -238,6 +238,24 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
 
     # Test standard deviation is close to l2_norm_clip * noise_multiplier.
     self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
 
+  @parameterized.named_parameters(
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamOptimizer))
+  def testAssertOnNoCallOfComputeGradients(self, cls):
+    dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
+    opt = cls(dp_sum_query, num_microbatches=1, learning_rate=1.0)
+
+    with self.assertRaises(AssertionError):
+      grads_and_vars = tf.Variable([0.0])
+      opt.apply_gradients(grads_and_vars)
+
+    # Expect no exception if compute_gradients is called.
+    var0 = tf.Variable([0.0])
+    data0 = tf.Variable([[0.0]])
+    grads_and_vars = opt.compute_gradients(self._loss(data0, var0), [var0])
+    opt.apply_gradients(grads_and_vars)
+
 if __name__ == '__main__':
   tf.test.main()
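
For illustration, a minimal sketch of the failure mode this assert guards
against, mirroring the new test above. The import paths are assumed to match
tensorflow_privacy at the time of this change, and the DPQuery parameters and
dummy gradient are placeholders, not recommended settings.

import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer

# Placeholder query: l2_norm_clip=1.0, stddev=1.0 (illustrative values only).
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 1.0)
opt = dp_optimizer.DPGradientDescentOptimizer(
    dp_sum_query, num_microbatches=1, learning_rate=0.1)

var = tf.Variable([0.0])
grads_and_vars = [(tf.constant([1.0]), var)]

# Keras training in TF 2.x calls apply_gradients() without ever calling
# compute_gradients(), so the clipped-and-noised DP gradient is never formed.
# With this patch, that path now fails loudly instead of silently applying a
# non-private gradient.
opt.apply_gradients(grads_and_vars)  # raises AssertionError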