From 33bbc87ff269ab69bd0cb1a46f6ca840d9d93948 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Apr 2023 22:16:51 -0700 Subject: [PATCH] Use better group privacy bound in computing user level privacy [TF Privacy] PiperOrigin-RevId: 526852999 --- .../analysis/compute_dp_sgd_privacy_lib.py | 53 +++++++++++++------ .../analysis/compute_dp_sgd_privacy_test.py | 44 +++++++++------ 2 files changed, 65 insertions(+), 32 deletions(-) diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py index f9b708e..ffbaf8f 100644 --- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py +++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py @@ -28,6 +28,11 @@ class UserLevelDPComputationError(Exception): """Error raised if user-level epsilon computation fails.""" +def _logexpm1(x: float) -> float: + """Returns log(exp(x) - 1).""" + return x + math.log(-math.expm1(-x)) + + def _compute_dp_sgd_user_privacy( num_epochs: float, noise_multiplier: float, @@ -100,21 +105,32 @@ def _compute_dp_sgd_user_privacy( # log(G(F(example_delta), example_delta)) - log(user_delta) = 0 # Then we can return user_eps = H(F(example_delta)). - log_k = math.log(max_examples_per_user) target_user_log_delta = math.log(user_delta) - def user_log_delta_gap(example_log_delta): - example_eps = _compute_dp_sgd_example_privacy( - num_epochs, - noise_multiplier, - math.exp(example_log_delta), - used_microbatching, - poisson_subsampling_probability, - ) + # We store all example_eps computed for any example_delta in the following + # method. This is done so that we don't have to recompute values for the same + # delta. + epsilon_cache = dict() - # Estimate user_eps, user_log_delta using Vadhan Lemma 2.2. + def user_log_delta_gap(example_log_delta): + if example_log_delta not in epsilon_cache: + epsilon_cache[example_log_delta] = _compute_dp_sgd_example_privacy( + num_epochs, + noise_multiplier, + math.exp(example_log_delta), + used_microbatching, + poisson_subsampling_probability, + ) + example_eps = epsilon_cache[example_log_delta] + + # Estimate user_eps, user_log_delta using Vadhan Lemma 2.2, using a tighter + # bound seen in the penultimate line of the proof, given as + # user_delta = (example_delta * (exp(k * example_eps) - 1) + # / (exp(example_eps) - 1)) user_eps = max_examples_per_user * example_eps - user_log_delta = log_k + user_eps + example_log_delta + user_log_delta = ( + example_log_delta + _logexpm1(user_eps) - _logexpm1(example_eps) + ) return user_log_delta - target_user_log_delta # We need bounds on the example-level delta. The supplied user-level delta @@ -164,9 +180,16 @@ def _compute_dp_sgd_user_privacy( ) # Vadhan (2017) "The complexity of differential privacy" Lemma 2.2. - # user_delta = k * exp(k * example_eps) * example_delta - # Given example_delta, we can solve for (k * example_eps) = user_eps. - return max(0, target_user_log_delta - log_k - example_log_delta) + if example_log_delta not in epsilon_cache: + epsilon_cache[example_log_delta] = _compute_dp_sgd_example_privacy( + num_epochs, + noise_multiplier, + math.exp(example_log_delta), + used_microbatching, + poisson_subsampling_probability, + ) + example_eps = epsilon_cache[example_log_delta] + return max_examples_per_user * example_eps def _compute_dp_sgd_example_privacy( @@ -354,7 +377,7 @@ using RDP accounting and group privacy:""", paragraphs.append( textwrap.fill( """\ -No user-level privacy guarantee is possible witout a bound on the number of \ +No user-level privacy guarantee is possible without a bound on the number of \ examples per user.""", width=80, ) diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py index 975bd1b..f1fdf15 100644 --- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py +++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py @@ -134,25 +134,35 @@ class ComputeDpSgdPrivacyTest(parameterized.TestCase): self.assertEqual(user_eps, example_eps) @parameterized.parameters((0.9, 2), (1.1, 3), (2.3, 13)) - def test_user_privacy_epsilon_delta_consistency(self, z, k): + def test_user_privacy_epsilon_delta_consistency( + self, noise_multiplier, max_examples_per_user + ): """Tests example/user epsilons consistent with Vadhan (2017) Lemma 2.2.""" num_epochs = 5 - user_delta = 1e-6 + example_delta = 1e-8 q = 2e-4 - user_eps = _user_privacy( - num_epochs, - noise_multiplier=z, - user_delta=user_delta, - max_examples_per_user=k, - poisson_subsampling_probability=q, - ) example_eps = _example_privacy( num_epochs, - noise_multiplier=z, - example_delta=user_delta / (k * math.exp(user_eps)), + noise_multiplier=noise_multiplier, + example_delta=example_delta, poisson_subsampling_probability=q, ) - self.assertAlmostEqual(user_eps, example_eps * k) + + user_delta = math.exp( + math.log(example_delta) + + compute_dp_sgd_privacy_lib._logexpm1( + max_examples_per_user * example_eps + ) + - compute_dp_sgd_privacy_lib._logexpm1(example_eps) + ) + user_eps = _user_privacy( + num_epochs, + noise_multiplier=noise_multiplier, + user_delta=user_delta, + max_examples_per_user=max_examples_per_user, + poisson_subsampling_probability=q, + ) + self.assertAlmostEqual(user_eps, example_eps * max_examples_per_user) def test_dp_sgd_privacy_statement_no_user_dp(self): statement = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy_statement( @@ -171,7 +181,7 @@ RDP accounting: Epsilon with each example occurring once per epoch: 13.376 Epsilon assuming Poisson sampling (*): 1.616 -No user-level privacy guarantee is possible witout a bound on the number of +No user-level privacy guarantee is possible without a bound on the number of examples per user. (*) Poisson sampling is not usually done in training pipelines, but assuming @@ -201,8 +211,8 @@ RDP accounting: User-level DP with add-or-remove-one adjacency at delta = 1e-06 computed using RDP accounting and group privacy: - Epsilon with each example occurring once per epoch: 113.899 - Epsilon assuming Poisson sampling (*): 8.129 + Epsilon with each example occurring once per epoch: 85.940 + Epsilon assuming Poisson sampling (*): 6.425 (*) Poisson sampling is not usually done in training pipelines, but assuming that the data was randomly shuffled, it is believed the actual epsilon should be @@ -214,11 +224,11 @@ order. def test_dp_sgd_privacy_statement_user_dp_infinite(self): statement = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy_statement( **DP_SGD_STATEMENT_KWARGS, - max_examples_per_user=9, + max_examples_per_user=10, ) expected_statement = """\ DP-SGD performed over 10000 examples with 64 examples per iteration, noise -multiplier 2.0 for 5.0 epochs with microbatching, and at most 9 examples per +multiplier 2.0 for 5.0 epochs with microbatching, and at most 10 examples per user. This privacy guarantee protects the release of all model checkpoints in addition