Adds functions for more accurate privacy accounting.
Adds function for computation of example-level DP epsilon taking into account microbatching and not assuming Poisson subsampling. Adds function for computation of user-level DP in terms of group privacy. PiperOrigin-RevId: 515114010
This commit is contained in:
parent
4e1fc252e4
commit
61dfbcc1f5
2 changed files with 312 additions and 22 deletions
|
@ -15,43 +15,225 @@
|
|||
"""Library for computing privacy values for DP-SGD."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from absl import app
|
||||
from absl import logging
|
||||
import dp_accounting
|
||||
from scipy import optimize
|
||||
|
||||
|
||||
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
|
||||
"""Compute and print results of DP-SGD analysis."""
|
||||
class UserLevelDPComputationError(Exception):
|
||||
"""Error raised if user-level epsilon computation fails."""
|
||||
|
||||
accountant = dp_accounting.rdp.RdpAccountant(orders)
|
||||
|
||||
event = dp_accounting.SelfComposedDpEvent(
|
||||
dp_accounting.PoissonSampledDpEvent(q,
|
||||
dp_accounting.GaussianDpEvent(sigma)),
|
||||
steps)
|
||||
def _compute_dp_sgd_user_privacy(
|
||||
num_epochs: float,
|
||||
noise_multiplier: float,
|
||||
user_delta: float,
|
||||
max_examples_per_user: int,
|
||||
used_microbatching: bool = True,
|
||||
poisson_subsampling_probability: Optional[float] = None,
|
||||
) -> float:
|
||||
"""Computes add-or-remove-one-user DP epsilon using group privacy.
|
||||
|
||||
accountant.compose(event)
|
||||
This privacy guarantee uses add-or-remove-one-user adjacency, and protects
|
||||
release of all model checkpoints in addition to the final model.
|
||||
|
||||
eps, opt_order = accountant.get_epsilon_and_optimal_order(delta)
|
||||
Uses Vadhan (2017) "The complexity of differential privacy" Lemma 2.2.
|
||||
|
||||
print(
|
||||
'DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
|
||||
' over {} steps satisfies'.format(100 * q, sigma, steps),
|
||||
end=' ')
|
||||
print('differential privacy with eps = {:.3g} and delta = {}.'.format(
|
||||
eps, delta))
|
||||
print('The optimal RDP order is {}.'.format(opt_order))
|
||||
# TODO(b/271330804): Consider using RDP to compute group privacy.
|
||||
|
||||
if opt_order == max(orders) or opt_order == min(orders):
|
||||
print('The privacy estimate is likely to be improved by expanding '
|
||||
'the set of orders.')
|
||||
We use a line search to identify an example-level delta which, when the lemma
|
||||
is applied, yields the requested user-level delta, then use it to compute the
|
||||
user-level epsilon.
|
||||
|
||||
return eps, opt_order
|
||||
Args:
|
||||
num_epochs: The number of passes over the data. May be fractional.
|
||||
noise_multiplier: The ratio of the noise to the l2 sensitivity.
|
||||
user_delta: The target user-level delta.
|
||||
max_examples_per_user: Upper bound on the number of examples per user.
|
||||
used_microbatching: If true, increases sensitivity by a factor of two.
|
||||
poisson_subsampling_probability: If not None, gives the probability that
|
||||
each record is chosen in a batch. If None, assumes no subsampling.
|
||||
|
||||
Returns:
|
||||
The add-or-remove-one-user DP epsilon value using group privacy.
|
||||
|
||||
Raises:
|
||||
UserLevelDPComputationError: If line search for example-level delta fails.
|
||||
"""
|
||||
if num_epochs <= 0:
|
||||
raise ValueError(f'num_epochs must be positive. Found {num_epochs}.')
|
||||
if noise_multiplier < 0:
|
||||
raise ValueError(
|
||||
f'noise_multiplier must be non-negative. Found {noise_multiplier}.'
|
||||
)
|
||||
if not 0 <= user_delta <= 1:
|
||||
raise ValueError(f'user_delta must be between 0 and 1. Found {user_delta}.')
|
||||
if max_examples_per_user <= 0:
|
||||
raise ValueError(
|
||||
'max_examples_per_user must be a positive integer. Found '
|
||||
f'{max_examples_per_user}.'
|
||||
)
|
||||
|
||||
if max_examples_per_user == 1:
|
||||
# Don't unnecessarily inflate epsilon if one example per user.
|
||||
return _compute_dp_sgd_example_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier,
|
||||
user_delta,
|
||||
used_microbatching,
|
||||
poisson_subsampling_probability,
|
||||
)
|
||||
|
||||
# The computation below to estimate user_eps works as follows.
|
||||
# We have _compute_dp_sgd_example_privacy which maps
|
||||
# F(example_delta) -> example_eps
|
||||
# Vadhan (2017) "The complexity of differential privacy" Lemma 2.2 gives us
|
||||
# G(example_eps, example_delta) -> user_delta
|
||||
# H(example_eps) -> user_eps.
|
||||
# We first identify an example_delta such that
|
||||
# G(F(example_delta), example_delta) = user_delta
|
||||
# Specifically, we use a line search in log space to solve for
|
||||
# log(G(F(example_delta), example_delta)) - log(user_delta) = 0
|
||||
# Then we can return user_eps = H(F(example_delta)).
|
||||
|
||||
log_k = math.log(max_examples_per_user)
|
||||
target_user_log_delta = math.log(user_delta)
|
||||
|
||||
def user_log_delta_gap(example_log_delta):
|
||||
example_eps = _compute_dp_sgd_example_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier,
|
||||
math.exp(example_log_delta),
|
||||
used_microbatching,
|
||||
poisson_subsampling_probability,
|
||||
)
|
||||
|
||||
# Estimate user_eps, user_log_delta using Vadhan Lemma 2.2.
|
||||
user_eps = max_examples_per_user * example_eps
|
||||
user_log_delta = log_k + user_eps + example_log_delta
|
||||
return user_log_delta - target_user_log_delta
|
||||
|
||||
# We need bounds on the example-level delta. The supplied user-level delta
|
||||
# is an upper bound. Search exponentially toward zero for lower bound.
|
||||
example_log_delta_max = target_user_log_delta
|
||||
example_log_delta_min = example_log_delta_max - math.log(10)
|
||||
user_log_delta_gap_min = user_log_delta_gap(example_log_delta_min)
|
||||
while user_log_delta_gap_min > 0:
|
||||
# Assuming that _compute_dp_sgd_example_privacy is decreasing in
|
||||
# example_delta, it is not difficult to show that if user_delta_min
|
||||
# corresponding to example_delta_min is too large, then we must reduce
|
||||
# example_delta by at least a factor of (user_delta / user_delta_min).
|
||||
# In other words, if example_log_delta_min is an upper bound, then so is
|
||||
# example_log_delta_min - user_log_delta_gap_min.
|
||||
example_log_delta_max = example_log_delta_min - user_log_delta_gap_min
|
||||
example_log_delta_min = example_log_delta_max - math.log(10)
|
||||
user_log_delta_gap_min = user_log_delta_gap(example_log_delta_min)
|
||||
if not math.isfinite(user_log_delta_gap_min):
|
||||
# User-level (epsilon, delta) DP is not achievable. This can happen
|
||||
# because as example_delta decreases, example_eps increases. So it is
|
||||
# possible for user_delta (which increases in both example_delta and
|
||||
# example_eps) to diverge to infinity as example_delta goes to zero.
|
||||
logging.warn(
|
||||
(
|
||||
'No upper bound on user-level DP epsilon can be computed with %s '
|
||||
'examples per user.'
|
||||
),
|
||||
max_examples_per_user,
|
||||
)
|
||||
return math.inf
|
||||
|
||||
# By the same logic, we can improve on the lower bound we just found, before
|
||||
# even starting the line search. We actually could do a custom line search
|
||||
# that makes use of this at each step, but brentq should be fast enough.
|
||||
example_log_delta_min -= user_log_delta_gap_min
|
||||
|
||||
example_log_delta, result = optimize.brentq(
|
||||
user_log_delta_gap,
|
||||
example_log_delta_min,
|
||||
example_log_delta_max,
|
||||
full_output=True,
|
||||
)
|
||||
|
||||
if not result.converged:
|
||||
raise UserLevelDPComputationError(
|
||||
'Optimization failed trying to compute user-level DP epsilon.'
|
||||
)
|
||||
|
||||
# Vadhan (2017) "The complexity of differential privacy" Lemma 2.2.
|
||||
# user_delta = k * exp(k * example_eps) * example_delta
|
||||
# Given example_delta, we can solve for (k * example_eps) = user_eps.
|
||||
return max(0, target_user_log_delta - log_k - example_log_delta)
|
||||
|
||||
|
||||
def _compute_dp_sgd_example_privacy(
|
||||
num_epochs: float,
|
||||
noise_multiplier: float,
|
||||
example_delta: float,
|
||||
used_microbatching: bool = True,
|
||||
poisson_subsampling_probability: Optional[float] = None,
|
||||
) -> float:
|
||||
"""Computes add-or-remove-one-example DP epsilon.
|
||||
|
||||
This privacy guarantee uses add-or-remove-one-example adjacency, and protects
|
||||
release of all model checkpoints in addition to the final model.
|
||||
|
||||
Args:
|
||||
num_epochs: The number of passes over the data.
|
||||
noise_multiplier: The ratio of the noise to the l2 sensitivity.
|
||||
example_delta: The target delta.
|
||||
used_microbatching: If true, increases sensitivity by a factor of two.
|
||||
poisson_subsampling_probability: If not None, gives the probability that
|
||||
each record is chosen in a batch. If None, assumes no subsampling.
|
||||
|
||||
Returns:
|
||||
The epsilon value.
|
||||
"""
|
||||
if num_epochs <= 0:
|
||||
raise ValueError(f'num_epochs must be positive. Found {num_epochs}.')
|
||||
if noise_multiplier < 0:
|
||||
raise ValueError(
|
||||
f'noise_multiplier must be non-negative. Found {noise_multiplier}.'
|
||||
)
|
||||
if not 0 <= example_delta <= 1:
|
||||
raise ValueError(f'delta must be between 0 and 1. Found {example_delta}.')
|
||||
|
||||
if used_microbatching:
|
||||
# TODO(b/271462792)
|
||||
noise_multiplier /= 2
|
||||
|
||||
event_ = dp_accounting.GaussianDpEvent(noise_multiplier)
|
||||
if poisson_subsampling_probability is not None:
|
||||
event_ = dp_accounting.PoissonSampledDpEvent(
|
||||
sampling_probability=poisson_subsampling_probability, event=event_
|
||||
)
|
||||
count = int(math.ceil(num_epochs / poisson_subsampling_probability))
|
||||
else:
|
||||
count = int(math.ceil(num_epochs))
|
||||
event_ = dp_accounting.SelfComposedDpEvent(count=count, event=event_)
|
||||
return (
|
||||
dp_accounting.rdp.RdpAccountant() # TODO(b/271341062)
|
||||
.compose(event_)
|
||||
.get_epsilon(example_delta)
|
||||
)
|
||||
|
||||
|
||||
def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
|
||||
"""Compute epsilon based on the given hyperparameters.
|
||||
|
||||
This function is deprecated. It does not account for doubling of sensitivity
|
||||
with microbatching, and assumes Poisson subsampling, which is rarely used in
|
||||
practice. (See "How to DP-fy ML: A Practical Guide to Machine Learning with
|
||||
Differential Privacy", https://arxiv.org/abs/2303.00654, Sec 5.6.) Most users
|
||||
should call `compute_dp_sgd_privacy_statement` (which will be added shortly),
|
||||
which provides appropriate context for the guarantee (see the reporting
|
||||
recommendations in "How to DP-fy ML", Sec 5.3). If you need a numeric epsilon
|
||||
value under specific assumptions, it is recommended to use the `dp_accounting`
|
||||
libraries directly to compute epsilon, with the precise and correct
|
||||
assumptions of your application.
|
||||
|
||||
Args:
|
||||
n: Number of examples in the training data.
|
||||
batch_size: Batch size used in training.
|
||||
|
@ -60,13 +242,31 @@ def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
|
|||
delta: Value of delta for which to compute epsilon.
|
||||
|
||||
Returns:
|
||||
Value of epsilon corresponding to input hyperparameters.
|
||||
A 2-tuple containing the value of epsilon and the optimal RDP order.
|
||||
"""
|
||||
# TODO(b/265168958): Update this text for `compute_dp_sgd_privacy_statement`.
|
||||
logging.warn(
|
||||
'`compute_dp_sgd_privacy` is deprecated. It does not account '
|
||||
'for doubling of sensitivity with microbatching, and assumes Poisson '
|
||||
'subsampling, which is rarely used in practice. Please use the '
|
||||
'`dp_accounting` libraries directly to compute epsilon, using the '
|
||||
'precise and correct assumptions of your application.'
|
||||
)
|
||||
|
||||
q = batch_size / n # q - the sampling ratio.
|
||||
if q > 1:
|
||||
raise app.UsageError('n must be larger than the batch size.')
|
||||
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
||||
list(range(5, 64)) + [128, 256, 512])
|
||||
steps = int(math.ceil(epochs * n / batch_size))
|
||||
accountant = dp_accounting.rdp.RdpAccountant(orders)
|
||||
|
||||
return apply_dp_sgd_analysis(q, noise_multiplier, steps, orders, delta)
|
||||
event = dp_accounting.SelfComposedDpEvent(
|
||||
dp_accounting.PoissonSampledDpEvent(
|
||||
sampling_probability=q,
|
||||
event=dp_accounting.GaussianDpEvent(noise_multiplier),
|
||||
),
|
||||
steps,
|
||||
)
|
||||
|
||||
return accountant.compose(event).get_epsilon_and_optimal_order(delta)
|
||||
|
|
|
@ -21,6 +21,10 @@ from absl.testing import parameterized
|
|||
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
|
||||
|
||||
|
||||
_example_privacy = compute_dp_sgd_privacy_lib._compute_dp_sgd_example_privacy
|
||||
_user_privacy = compute_dp_sgd_privacy_lib._compute_dp_sgd_user_privacy
|
||||
|
||||
|
||||
class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -55,6 +59,92 @@ class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
|||
(eps * sigma + .5 / sigma) / math.sqrt(2))
|
||||
self.assertLessEqual(low_delta, delta)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('num_epochs_negative', dict(num_epochs=-1.0)),
|
||||
('noise_multiplier_negative', dict(noise_multiplier=-1.0)),
|
||||
('example_delta_negative', dict(example_delta=-0.5)),
|
||||
('example_delta_excessive', dict(example_delta=1.5)),
|
||||
)
|
||||
def test_compute_dp_sgd_example_privacy_bad_args(self, override_args):
|
||||
args = dict(num_epochs=1.0, noise_multiplier=1.0, example_delta=1.0)
|
||||
args.update(override_args)
|
||||
with self.assertRaises(ValueError):
|
||||
_example_privacy(**args)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('no_microbatching_no_subsampling', False, None, 10.8602036),
|
||||
('microbatching_no_subsampling', True, None, 26.2880374),
|
||||
('no_microbatching_with_subsampling', False, 1e-2, 3.2391922),
|
||||
('microbatching_with_subsampling', True, 1e-2, 22.5970358),
|
||||
)
|
||||
def test_compute_dp_sgd_example_privacy(
|
||||
self, used_microbatching, poisson_subsampling_probability, expected_eps
|
||||
):
|
||||
num_epochs = 1.2
|
||||
noise_multiplier = 0.7
|
||||
example_delta = 1e-5
|
||||
eps = _example_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier,
|
||||
example_delta,
|
||||
used_microbatching,
|
||||
poisson_subsampling_probability,
|
||||
)
|
||||
self.assertAlmostEqual(eps, expected_eps)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('num_epochs_negative', dict(num_epochs=-1.0)),
|
||||
('noise_multiplier_negative', dict(noise_multiplier=-1.0)),
|
||||
('example_delta_negative', dict(user_delta=-0.5)),
|
||||
('example_delta_excessive', dict(user_delta=1.5)),
|
||||
('max_examples_per_user_negative', dict(max_examples_per_user=-1)),
|
||||
)
|
||||
def test_compute_dp_sgd_user_privacy_bad_args(self, override_args):
|
||||
args = dict(
|
||||
num_epochs=1.0,
|
||||
noise_multiplier=1.0,
|
||||
user_delta=1.0,
|
||||
max_examples_per_user=3,
|
||||
)
|
||||
args.update(override_args)
|
||||
with self.assertRaises(ValueError):
|
||||
_user_privacy(**args)
|
||||
|
||||
def test_user_privacy_one_example_per_user(self):
|
||||
num_epochs = 1.2
|
||||
noise_multiplier = 0.7
|
||||
delta = 1e-5
|
||||
|
||||
example_eps = _example_privacy(num_epochs, noise_multiplier, delta)
|
||||
user_eps = _user_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier,
|
||||
delta,
|
||||
max_examples_per_user=1,
|
||||
)
|
||||
self.assertEqual(user_eps, example_eps)
|
||||
|
||||
@parameterized.parameters((0.9, 2), (1.1, 3), (2.3, 13))
|
||||
def test_user_privacy_epsilon_delta_consistency(self, z, k):
|
||||
"""Tests example/user epsilons consistent with Vadhan (2017) Lemma 2.2."""
|
||||
num_epochs = 5
|
||||
user_delta = 1e-6
|
||||
q = 2e-4
|
||||
user_eps = _user_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier=z,
|
||||
user_delta=user_delta,
|
||||
max_examples_per_user=k,
|
||||
poisson_subsampling_probability=q,
|
||||
)
|
||||
example_eps = _example_privacy(
|
||||
num_epochs,
|
||||
noise_multiplier=z,
|
||||
example_delta=user_delta / (k * math.exp(user_eps)),
|
||||
poisson_subsampling_probability=q,
|
||||
)
|
||||
self.assertAlmostEqual(user_eps, example_eps * k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue