From be8175bfaca97916417d6a251b00444577dbf197 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 Dec 2020 07:42:32 -0800 Subject: [PATCH] Improved conversion from Renyi DP to approx DP PiperOrigin-RevId: 349557544 --- .../analysis/compute_dp_sgd_privacy_test.py | 30 ++++++- .../privacy/analysis/rdp_accountant.py | 58 +++++++++++-- .../privacy/analysis/rdp_accountant_test.py | 84 ++++++++++++++++--- 3 files changed, 150 insertions(+), 22 deletions(-) diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py index 7f4b66c..267bd06 100644 --- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py +++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + from absl.testing import absltest from absl.testing import parameterized @@ -26,16 +28,36 @@ from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib class ComputeDpSgdPrivacyTest(parameterized.TestCase): @parameterized.named_parameters( - ('Test0', 60000, 150, 1.3, 15, 1e-5, 0.941870567, 19.0), - ('Test1', 100000, 100, 1.0, 30, 1e-7, 1.70928734, 13.0), - ('Test2', 100000000, 1024, 0.1, 10, 1e-7, 5907984.81339406, 1.25), + ('Test0', 60000, 150, 1.3, 15, 1e-5, 0.7242234026109595, 19.0), + ('Test1', 100000, 100, 1.0, 30, 1e-7, 1.4154988495444845, 13.0), + ('Test2', 100000000, 1024, 0.1, 10, 1e-7, 5907982.31138195, 1.25), ) def test_compute_dp_sgd_privacy(self, n, batch_size, noise_multiplier, epochs, delta, expected_eps, expected_order): eps, order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy( n, batch_size, noise_multiplier, epochs, delta) self.assertAlmostEqual(eps, expected_eps) - self.assertAlmostEqual(order, expected_order) + self.assertEqual(order, expected_order) + + # We perform an additional sanity check on the hard-coded test values. + # We do a back-of-the-envelope calculation to obtain a lower bound. + # Specifically, we make the approximation that subsampling a q-fraction is + # equivalent to multiplying noise scale by 1/q. + # This is only an approximation, but can be justified by the central limit + # theorem in the Gaussian Differential Privacy framework; see + # https://arxiv.org/1911.11607 + # The approximation error is one-sided and provides a lower bound, which is + # the basis of this sanity check. This is confirmed in the above paper. + q = batch_size / n + steps = epochs * n / batch_size + sigma = noise_multiplier * math.sqrt(steps) /q + # We compute the optimal guarantee for Gaussian + # using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2). + low_delta = .5*math.erfc((eps*sigma-.5/sigma)/math.sqrt(2)) + if eps < 100: # Skip this if it causes overflow; error is minor. + low_delta -= math.exp(eps)*.5*math.erfc((eps*sigma+.5/sigma)/math.sqrt(2)) + self.assertLessEqual(low_delta, delta) + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index 5e6e927..59e33c0 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -188,12 +188,34 @@ def _compute_delta(orders, rdp, eps): orders_vec = np.atleast_1d(orders) rdp_vec = np.atleast_1d(rdp) + if eps < 0: + raise ValueError("Value of privacy loss bound epsilon must be >=0.") if len(orders_vec) != len(rdp_vec): raise ValueError("Input lists must have the same length.") - deltas = np.exp((rdp_vec - eps) * (orders_vec - 1)) - idx_opt = np.argmin(deltas) - return min(deltas[idx_opt], 1.), orders_vec[idx_opt] + # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): + # delta = min( np.exp((rdp_vec - eps) * (orders_vec - 1)) ) + + # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4): + logdeltas = [] # work in log space to avoid overflows + for (a, r) in zip(orders_vec, rdp_vec): + if a < 1: raise ValueError("Renyi divergence order must be >=1.") + if r < 0: raise ValueError("Renyi divergence must be >=0.") + # For small alpha, we are better of with bound via KL divergence: + # delta <= sqrt(1-exp(-KL)). + # Take a min of the two bounds. + logdelta = 0.5*math.log1p(-math.exp(-r)) + if a > 1.01: + # This bound is not numerically stable as alpha->1. + # Thus we have a min value for alpha. + # The bound is also not useful for small alpha, so doesn't matter. + rdp_bound = (a - 1) * (r - eps + math.log1p(-1/a)) - math.log(a) + logdelta = min(logdelta, rdp_bound) + + logdeltas.append(logdelta) + + idx_opt = np.argmin(logdeltas) + return min(math.exp(logdeltas[idx_opt]), 1.), orders_vec[idx_opt] def _compute_eps(orders, rdp, delta): @@ -214,13 +236,37 @@ def _compute_eps(orders, rdp, delta): orders_vec = np.atleast_1d(orders) rdp_vec = np.atleast_1d(rdp) + if delta <= 0: + raise ValueError("Privacy failure probability bound delta must be >0.") if len(orders_vec) != len(rdp_vec): raise ValueError("Input lists must have the same length.") - eps = rdp_vec - math.log(delta) / (orders_vec - 1) + # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): + # eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) ) - idx_opt = np.nanargmin(eps) # Ignore NaNs - return eps[idx_opt], orders_vec[idx_opt] + # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). + # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). + eps_vec = [] + for (a, r) in zip(orders_vec, rdp_vec): + if a < 1: raise ValueError("Renyi divergence order must be >=1.") + if r < 0: raise ValueError("Renyi divergence must be >=0.") + + if delta**2 + math.expm1(-r) >= 0: + # In this case, we can simply bound via KL divergence: + # delta <= sqrt(1-exp(-KL)). + eps = 0 # No need to try further computation if we have eps = 0. + elif a > 1.01: + # This bound is not numerically stable as alpha->1. + # Thus we have a min value of alpha. + # The bound is also not useful for small alpha, so doesn't matter. + eps = r + math.log1p(-1/a) - math.log(delta * a) / (a - 1) + else: + # In this case we can't do anything. E.g., asking for delta = 0. + eps = np.inf + eps_vec.append(eps) + + idx_opt = np.argmin(eps_vec) + return max(0, eps_vec[idx_opt]), orders_vec[idx_opt] def _compute_rdp(q, sigma, alpha): diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py index a10a301..eda62bb 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import sys from absl.testing import absltest @@ -130,19 +131,36 @@ class TestGaussianMoments(parameterized.TestCase): def test_get_privacy_spent_check_target_delta(self): orders = range(2, 33) - rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + rdp = [1.1 for o in orders] # Constant corresponds to pure DP. eps, _, opt_order = rdp_accountant.get_privacy_spent( orders, rdp, target_delta=1e-5) - self.assertAlmostEqual(eps, 1.258575, places=5) - self.assertEqual(opt_order, 20) + # Since rdp is constant, it should always pick the largest order. + self.assertEqual(opt_order, 32) + # Knowing the optimal order, we can calculate eps by hand. + self.assertAlmostEqual(eps, 1.32783806176) + + # Second test for Gaussian noise (with no subsampling): + orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of orders. + rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders) + # Scale is chosen to obtain exactly (1,1e-6)-DP. + eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6) + self.assertAlmostEqual(eps, 1) def test_get_privacy_spent_check_target_eps(self): orders = range(2, 33) - rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + rdp = [1.1 for o in orders] # Constant corresponds to pure DP. _, delta, opt_order = rdp_accountant.get_privacy_spent( - orders, rdp, target_eps=1.258575) + orders, rdp, target_eps=1.32783806176) + # Since rdp is constant, it should always pick the largest order. + self.assertEqual(opt_order, 32) self.assertAlmostEqual(delta, 1e-5) - self.assertEqual(opt_order, 20) + + # Second test for Gaussian noise (with no subsampling): + orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of order. + rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders) + # Scale is chosen to obtain exactly (1,1e-6)-DP. + _, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_eps=1) + self.assertAlmostEqual(delta, 1e-6) def test_check_composition(self): orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14., @@ -153,17 +171,59 @@ class TestGaussianMoments(parameterized.TestCase): steps=40000, orders=orders) - eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, - target_delta=1e-6) + eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6) rdp += rdp_accountant.compute_rdp(q=0.1, noise_multiplier=2, steps=100, orders=orders) - eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, - target_delta=1e-5) - self.assertAlmostEqual(eps, 8.509656, places=5) - self.assertEqual(opt_order, 2.5) + eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-5) + # These tests use the old RDP -> approx DP conversion + # self.assertAlmostEqual(eps, 8.509656, places=5) + # self.assertEqual(opt_order, 2.5) + # But these still provide an upper bound + self.assertLessEqual(eps, 8.509656) + + def test_get_privacy_spent_consistency(self): + orders = range(2, 50) # Large range of orders (helps test for overflows). + for q in [0.01, 0.1, 0.8, 1.]: # Different subsampling rates. + for multiplier in [0.1, 1., 3., 10., 100.]: # Different noise scales. + rdp = rdp_accountant.compute_rdp(q, multiplier, 1, orders) + for delta in [.9, .5, .1, .01, 1e-3, 1e-4, 1e-5, 1e-6, 1e-9, 1e-12]: + eps1, delta1, ord1 = rdp_accountant.get_privacy_spent( + orders, rdp, target_delta=delta) + eps2, delta2, ord2 = rdp_accountant.get_privacy_spent( + orders, rdp, target_eps=eps1) + self.assertEqual(delta1, delta) + self.assertEqual(eps2, eps1) + if eps1 != 0: + self.assertEqual(ord1, ord2) + self.assertAlmostEqual(delta, delta2) + else: # This is a degenerate case; we won't have consistency. + self.assertLessEqual(delta2, delta) + + def test_get_privacy_spent_gaussian(self): + # Compare the optimal bound for Gaussian with the one derived from RDP. + # Also compare the RDP upper bound with the "standard" upper bound. + orders = [0.1*x for x in range(10, 505)] + eps_vec = [0.1*x for x in range(500)] + rdp = rdp_accountant.compute_rdp(1, 1, 1, orders) + for eps in eps_vec: + _, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp, + target_eps=eps) + # For comparison, we compute the optimal guarantee for Gaussian + # using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2). + delta0 = math.erfc((eps-.5)/math.sqrt(2))/2 + delta0 = delta0 - math.exp(eps)*math.erfc((eps+.5)/math.sqrt(2))/2 + self.assertLessEqual(delta0, delta+1e-300) # need tolerance 10^-300 + + # Compute the "standard" upper bound, which should be an upper bound. + # Note, if orders is too sparse, this will NOT be an upper bound. + if eps >= 0.5: + delta1 = math.exp(-0.5*(eps-0.5)**2) + else: + delta1 = 1 + self.assertLessEqual(delta, delta1+1e-300) def test_compute_rdp_from_ledger(self): orders = range(2, 33)