Improved conversion from Renyi DP to approx DP
PiperOrigin-RevId: 349557544
This commit is contained in:
parent
8d53d8cc59
commit
be8175bfac
3 changed files with 150 additions and 22 deletions
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
@ -26,16 +28,36 @@ from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
|
|||
class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('Test0', 60000, 150, 1.3, 15, 1e-5, 0.941870567, 19.0),
|
||||
('Test1', 100000, 100, 1.0, 30, 1e-7, 1.70928734, 13.0),
|
||||
('Test2', 100000000, 1024, 0.1, 10, 1e-7, 5907984.81339406, 1.25),
|
||||
('Test0', 60000, 150, 1.3, 15, 1e-5, 0.7242234026109595, 19.0),
|
||||
('Test1', 100000, 100, 1.0, 30, 1e-7, 1.4154988495444845, 13.0),
|
||||
('Test2', 100000000, 1024, 0.1, 10, 1e-7, 5907982.31138195, 1.25),
|
||||
)
|
||||
def test_compute_dp_sgd_privacy(self, n, batch_size, noise_multiplier, epochs,
|
||||
delta, expected_eps, expected_order):
|
||||
eps, order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
|
||||
n, batch_size, noise_multiplier, epochs, delta)
|
||||
self.assertAlmostEqual(eps, expected_eps)
|
||||
self.assertAlmostEqual(order, expected_order)
|
||||
self.assertEqual(order, expected_order)
|
||||
|
||||
# We perform an additional sanity check on the hard-coded test values.
|
||||
# We do a back-of-the-envelope calculation to obtain a lower bound.
|
||||
# Specifically, we make the approximation that subsampling a q-fraction is
|
||||
# equivalent to multiplying noise scale by 1/q.
|
||||
# This is only an approximation, but can be justified by the central limit
|
||||
# theorem in the Gaussian Differential Privacy framework; see
|
||||
# https://arxiv.org/1911.11607
|
||||
# The approximation error is one-sided and provides a lower bound, which is
|
||||
# the basis of this sanity check. This is confirmed in the above paper.
|
||||
q = batch_size / n
|
||||
steps = epochs * n / batch_size
|
||||
sigma = noise_multiplier * math.sqrt(steps) /q
|
||||
# We compute the optimal guarantee for Gaussian
|
||||
# using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2).
|
||||
low_delta = .5*math.erfc((eps*sigma-.5/sigma)/math.sqrt(2))
|
||||
if eps < 100: # Skip this if it causes overflow; error is minor.
|
||||
low_delta -= math.exp(eps)*.5*math.erfc((eps*sigma+.5/sigma)/math.sqrt(2))
|
||||
self.assertLessEqual(low_delta, delta)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -188,12 +188,34 @@ def _compute_delta(orders, rdp, eps):
|
|||
orders_vec = np.atleast_1d(orders)
|
||||
rdp_vec = np.atleast_1d(rdp)
|
||||
|
||||
if eps < 0:
|
||||
raise ValueError("Value of privacy loss bound epsilon must be >=0.")
|
||||
if len(orders_vec) != len(rdp_vec):
|
||||
raise ValueError("Input lists must have the same length.")
|
||||
|
||||
deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
|
||||
idx_opt = np.argmin(deltas)
|
||||
return min(deltas[idx_opt], 1.), orders_vec[idx_opt]
|
||||
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
|
||||
# delta = min( np.exp((rdp_vec - eps) * (orders_vec - 1)) )
|
||||
|
||||
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
|
||||
logdeltas = [] # work in log space to avoid overflows
|
||||
for (a, r) in zip(orders_vec, rdp_vec):
|
||||
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
|
||||
if r < 0: raise ValueError("Renyi divergence must be >=0.")
|
||||
# For small alpha, we are better of with bound via KL divergence:
|
||||
# delta <= sqrt(1-exp(-KL)).
|
||||
# Take a min of the two bounds.
|
||||
logdelta = 0.5*math.log1p(-math.exp(-r))
|
||||
if a > 1.01:
|
||||
# This bound is not numerically stable as alpha->1.
|
||||
# Thus we have a min value for alpha.
|
||||
# The bound is also not useful for small alpha, so doesn't matter.
|
||||
rdp_bound = (a - 1) * (r - eps + math.log1p(-1/a)) - math.log(a)
|
||||
logdelta = min(logdelta, rdp_bound)
|
||||
|
||||
logdeltas.append(logdelta)
|
||||
|
||||
idx_opt = np.argmin(logdeltas)
|
||||
return min(math.exp(logdeltas[idx_opt]), 1.), orders_vec[idx_opt]
|
||||
|
||||
|
||||
def _compute_eps(orders, rdp, delta):
|
||||
|
@ -214,13 +236,37 @@ def _compute_eps(orders, rdp, delta):
|
|||
orders_vec = np.atleast_1d(orders)
|
||||
rdp_vec = np.atleast_1d(rdp)
|
||||
|
||||
if delta <= 0:
|
||||
raise ValueError("Privacy failure probability bound delta must be >0.")
|
||||
if len(orders_vec) != len(rdp_vec):
|
||||
raise ValueError("Input lists must have the same length.")
|
||||
|
||||
eps = rdp_vec - math.log(delta) / (orders_vec - 1)
|
||||
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
|
||||
# eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) )
|
||||
|
||||
idx_opt = np.nanargmin(eps) # Ignore NaNs
|
||||
return eps[idx_opt], orders_vec[idx_opt]
|
||||
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4).
|
||||
# Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
|
||||
eps_vec = []
|
||||
for (a, r) in zip(orders_vec, rdp_vec):
|
||||
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
|
||||
if r < 0: raise ValueError("Renyi divergence must be >=0.")
|
||||
|
||||
if delta**2 + math.expm1(-r) >= 0:
|
||||
# In this case, we can simply bound via KL divergence:
|
||||
# delta <= sqrt(1-exp(-KL)).
|
||||
eps = 0 # No need to try further computation if we have eps = 0.
|
||||
elif a > 1.01:
|
||||
# This bound is not numerically stable as alpha->1.
|
||||
# Thus we have a min value of alpha.
|
||||
# The bound is also not useful for small alpha, so doesn't matter.
|
||||
eps = r + math.log1p(-1/a) - math.log(delta * a) / (a - 1)
|
||||
else:
|
||||
# In this case we can't do anything. E.g., asking for delta = 0.
|
||||
eps = np.inf
|
||||
eps_vec.append(eps)
|
||||
|
||||
idx_opt = np.argmin(eps_vec)
|
||||
return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
|
||||
|
||||
|
||||
def _compute_rdp(q, sigma, alpha):
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
||||
from absl.testing import absltest
|
||||
|
@ -130,19 +131,36 @@ class TestGaussianMoments(parameterized.TestCase):
|
|||
|
||||
def test_get_privacy_spent_check_target_delta(self):
|
||||
orders = range(2, 33)
|
||||
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
|
||||
rdp = [1.1 for o in orders] # Constant corresponds to pure DP.
|
||||
eps, _, opt_order = rdp_accountant.get_privacy_spent(
|
||||
orders, rdp, target_delta=1e-5)
|
||||
self.assertAlmostEqual(eps, 1.258575, places=5)
|
||||
self.assertEqual(opt_order, 20)
|
||||
# Since rdp is constant, it should always pick the largest order.
|
||||
self.assertEqual(opt_order, 32)
|
||||
# Knowing the optimal order, we can calculate eps by hand.
|
||||
self.assertAlmostEqual(eps, 1.32783806176)
|
||||
|
||||
# Second test for Gaussian noise (with no subsampling):
|
||||
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of orders.
|
||||
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
|
||||
# Scale is chosen to obtain exactly (1,1e-6)-DP.
|
||||
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
|
||||
self.assertAlmostEqual(eps, 1)
|
||||
|
||||
def test_get_privacy_spent_check_target_eps(self):
|
||||
orders = range(2, 33)
|
||||
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
|
||||
rdp = [1.1 for o in orders] # Constant corresponds to pure DP.
|
||||
_, delta, opt_order = rdp_accountant.get_privacy_spent(
|
||||
orders, rdp, target_eps=1.258575)
|
||||
orders, rdp, target_eps=1.32783806176)
|
||||
# Since rdp is constant, it should always pick the largest order.
|
||||
self.assertEqual(opt_order, 32)
|
||||
self.assertAlmostEqual(delta, 1e-5)
|
||||
self.assertEqual(opt_order, 20)
|
||||
|
||||
# Second test for Gaussian noise (with no subsampling):
|
||||
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of order.
|
||||
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
|
||||
# Scale is chosen to obtain exactly (1,1e-6)-DP.
|
||||
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_eps=1)
|
||||
self.assertAlmostEqual(delta, 1e-6)
|
||||
|
||||
def test_check_composition(self):
|
||||
orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
|
||||
|
@ -153,17 +171,59 @@ class TestGaussianMoments(parameterized.TestCase):
|
|||
steps=40000,
|
||||
orders=orders)
|
||||
|
||||
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
|
||||
target_delta=1e-6)
|
||||
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
|
||||
|
||||
rdp += rdp_accountant.compute_rdp(q=0.1,
|
||||
noise_multiplier=2,
|
||||
steps=100,
|
||||
orders=orders)
|
||||
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
|
||||
target_delta=1e-5)
|
||||
self.assertAlmostEqual(eps, 8.509656, places=5)
|
||||
self.assertEqual(opt_order, 2.5)
|
||||
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-5)
|
||||
# These tests use the old RDP -> approx DP conversion
|
||||
# self.assertAlmostEqual(eps, 8.509656, places=5)
|
||||
# self.assertEqual(opt_order, 2.5)
|
||||
# But these still provide an upper bound
|
||||
self.assertLessEqual(eps, 8.509656)
|
||||
|
||||
def test_get_privacy_spent_consistency(self):
|
||||
orders = range(2, 50) # Large range of orders (helps test for overflows).
|
||||
for q in [0.01, 0.1, 0.8, 1.]: # Different subsampling rates.
|
||||
for multiplier in [0.1, 1., 3., 10., 100.]: # Different noise scales.
|
||||
rdp = rdp_accountant.compute_rdp(q, multiplier, 1, orders)
|
||||
for delta in [.9, .5, .1, .01, 1e-3, 1e-4, 1e-5, 1e-6, 1e-9, 1e-12]:
|
||||
eps1, delta1, ord1 = rdp_accountant.get_privacy_spent(
|
||||
orders, rdp, target_delta=delta)
|
||||
eps2, delta2, ord2 = rdp_accountant.get_privacy_spent(
|
||||
orders, rdp, target_eps=eps1)
|
||||
self.assertEqual(delta1, delta)
|
||||
self.assertEqual(eps2, eps1)
|
||||
if eps1 != 0:
|
||||
self.assertEqual(ord1, ord2)
|
||||
self.assertAlmostEqual(delta, delta2)
|
||||
else: # This is a degenerate case; we won't have consistency.
|
||||
self.assertLessEqual(delta2, delta)
|
||||
|
||||
def test_get_privacy_spent_gaussian(self):
|
||||
# Compare the optimal bound for Gaussian with the one derived from RDP.
|
||||
# Also compare the RDP upper bound with the "standard" upper bound.
|
||||
orders = [0.1*x for x in range(10, 505)]
|
||||
eps_vec = [0.1*x for x in range(500)]
|
||||
rdp = rdp_accountant.compute_rdp(1, 1, 1, orders)
|
||||
for eps in eps_vec:
|
||||
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp,
|
||||
target_eps=eps)
|
||||
# For comparison, we compute the optimal guarantee for Gaussian
|
||||
# using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2).
|
||||
delta0 = math.erfc((eps-.5)/math.sqrt(2))/2
|
||||
delta0 = delta0 - math.exp(eps)*math.erfc((eps+.5)/math.sqrt(2))/2
|
||||
self.assertLessEqual(delta0, delta+1e-300) # need tolerance 10^-300
|
||||
|
||||
# Compute the "standard" upper bound, which should be an upper bound.
|
||||
# Note, if orders is too sparse, this will NOT be an upper bound.
|
||||
if eps >= 0.5:
|
||||
delta1 = math.exp(-0.5*(eps-0.5)**2)
|
||||
else:
|
||||
delta1 = 1
|
||||
self.assertLessEqual(delta, delta1+1e-300)
|
||||
|
||||
def test_compute_rdp_from_ledger(self):
|
||||
orders = range(2, 33)
|
||||
|
|
Loading…
Reference in a new issue