From 388f46ffa0ee0d6d4573ad304b20c26afbb94fd9 Mon Sep 17 00:00:00 2001 From: Galen Andrew Date: Mon, 20 Sep 2021 17:19:29 -0700 Subject: [PATCH] Adds RdpAccountant: implementation of PrivacyAccountant for RDP. Also adds UnsupportedEventError for handling unsupported events by PrivacyAccountant. PiperOrigin-RevId: 397878895 --- .../privacy/analysis/privacy_accountant.py | 25 +- .../analysis/privacy_accountant_test.py | 101 ++++ .../analysis/rdp_privacy_accountant.py | 572 ++++++++++++++++++ .../analysis/rdp_privacy_accountant_test.py | 307 ++++++++++ 4 files changed, 997 insertions(+), 8 deletions(-) create mode 100644 tensorflow_privacy/privacy/analysis/privacy_accountant_test.py create mode 100644 tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py create mode 100644 tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py diff --git a/tensorflow_privacy/privacy/analysis/privacy_accountant.py b/tensorflow_privacy/privacy/analysis/privacy_accountant.py index 9235156..578ef0a 100644 --- a/tensorflow_privacy/privacy/analysis/privacy_accountant.py +++ b/tensorflow_privacy/privacy/analysis/privacy_accountant.py @@ -16,8 +16,8 @@ import abc import enum -from tensorflow_privacy.privacy.dp_event import dp_event -from tensorflow_privacy.privacy.dp_event import dp_event_builder +from tensorflow_privacy.privacy.analysis import dp_event +from tensorflow_privacy.privacy.analysis import dp_event_builder class NeighboringRelation(enum.Enum): @@ -25,6 +25,10 @@ class NeighboringRelation(enum.Enum): REPLACE_ONE = 2 +class UnsupportedEventError(Exception): + """Exception to raise if _compose is called on unsupported event type.""" + + class PrivacyAccountant(metaclass=abc.ABCMeta): """Abstract base class for privacy accountants.""" @@ -43,7 +47,7 @@ class PrivacyAccountant(metaclass=abc.ABCMeta): return self._neighboring_relation @abc.abstractmethod - def is_supported(self, event: dp_event.DpEvent) -> bool: + def supports(self, event: dp_event.DpEvent) -> bool: """Checks whether the `DpEvent` can be processed by this accountant. In general this will require recursively checking the structure of the @@ -59,7 +63,7 @@ class PrivacyAccountant(metaclass=abc.ABCMeta): @abc.abstractmethod def _compose(self, event: dp_event.DpEvent, count: int = 1): - """Update internal state to account for application of a `DpEvent`. + """Updates internal state to account for application of a `DpEvent`. Calls to `get_epsilon` or `get_delta` after calling `_compose` will return values that account for this `DpEvent`. @@ -70,7 +74,7 @@ class PrivacyAccountant(metaclass=abc.ABCMeta): """ def compose(self, event: dp_event.DpEvent, count: int = 1): - """Update internal state to account for application of a `DpEvent`. + """Updates internal state to account for application of a `DpEvent`. Calls to `get_epsilon` or `get_delta` after calling `compose` will return values that account for this `DpEvent`. @@ -80,10 +84,15 @@ class PrivacyAccountant(metaclass=abc.ABCMeta): count: The number of times to compose the event. Raises: - TypeError: `event` is not supported by this `PrivacyAccountant`. + UnsupportedEventError: `event` is not supported by this + `PrivacyAccountant`. """ - if not self.is_supported(event): - raise TypeError(f'`DpEvent` {event} is of unsupported type.') + if not isinstance(event, dp_event.DpEvent): + raise TypeError(f'`event` must be `DpEvent`. Found {type(event)}.') + + if not self.supports(event): + raise UnsupportedEventError('Unsupported event: {event}.') + self._ledger.compose(event, count) self._compose(event, count) diff --git a/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py b/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py new file mode 100644 index 0000000..344f3e4 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/privacy_accountant_test.py @@ -0,0 +1,101 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class for tests of `PrivacyAccountant` classes. + +Checks that a class derived from `PrivacyAccountant` has the correct behavior +for standard `DpEvent` classes. +""" + +from typing import Collection + +from absl.testing import absltest + +from tensorflow_privacy.privacy.analysis import dp_event +from tensorflow_privacy.privacy.analysis import privacy_accountant + + +class PrivacyAccountantTest(absltest.TestCase): + + def _make_test_accountants( + self) -> Collection[privacy_accountant.PrivacyAccountant]: + """Makes a list of accountants to test. + + Subclasses should define this to return a list of accountants to be tested. + + Returns: + A list of accountants to test. + """ + return [] + + def test_make_test_accountants(self): + self.assertNotEmpty(self._make_test_accountants()) + + def test_unsupported(self): + + class UnknownDpEvent(dp_event.DpEvent): + pass + + for accountant in self._make_test_accountants(): + for unsupported in [dp_event.UnsupportedDpEvent(), UnknownDpEvent()]: + self.assertFalse(accountant.supports(unsupported)) + self.assertFalse( + accountant.supports(dp_event.SelfComposedDpEvent(unsupported, 10))) + self.assertFalse( + accountant.supports(dp_event.ComposedDpEvent([unsupported]))) + + def test_no_events(self): + for accountant in self._make_test_accountants(): + self.assertEqual(accountant.get_epsilon(1e-12), 0) + self.assertEqual(accountant.get_epsilon(0), 0) + self.assertEqual(accountant.get_epsilon(1), 0) + try: + self.assertEqual(accountant.get_delta(1e-12), 0) + self.assertEqual(accountant.get_delta(0), 0) + self.assertEqual(accountant.get_delta(float('inf')), 0) + except NotImplementedError: + # Implementing `get_delta` is optional. + pass + + def test_no_op(self): + for accountant in self._make_test_accountants(): + event = dp_event.NoOpDpEvent() + self.assertTrue(accountant.supports(event)) + accountant._compose(event) + self.assertEqual(accountant.get_epsilon(1e-12), 0) + self.assertEqual(accountant.get_epsilon(0), 0) + self.assertEqual(accountant.get_epsilon(1), 0) + try: + self.assertEqual(accountant.get_delta(1e-12), 0) + self.assertEqual(accountant.get_delta(0), 0) + self.assertEqual(accountant.get_delta(float('inf')), 0) + except NotImplementedError: + # Implementing `get_delta` is optional. + pass + + def test_non_private(self): + for accountant in self._make_test_accountants(): + event = dp_event.NonPrivateDpEvent() + self.assertTrue(accountant.supports(event)) + accountant._compose(event) + self.assertEqual(accountant.get_epsilon(0.99), float('inf')) + self.assertEqual(accountant.get_epsilon(0), float('inf')) + self.assertEqual(accountant.get_epsilon(1), float('inf')) + try: + self.assertEqual(accountant.get_delta(100), 1) + self.assertEqual(accountant.get_delta(0), 1) + self.assertEqual(accountant.get_delta(float('inf')), 1) + except NotImplementedError: + # Implementing `get_delta` is optional. + pass diff --git a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py new file mode 100644 index 0000000..2bbc327 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant.py @@ -0,0 +1,572 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Privacy accountant that uses Renyi differential privacy.""" + +import math +from typing import Collection, Optional + +import numpy as np +from scipy import special +import six +from tensorflow_privacy.privacy.analysis import dp_event +from tensorflow_privacy.privacy.analysis import privacy_accountant + +NeighborRel = privacy_accountant.NeighboringRelation + + +def _log_add(logx, logy): + """Adds two numbers in the log space.""" + a, b = min(logx, logy), max(logx, logy) + if a == -np.inf: # adding 0 + return b + # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b) + return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1) + + +def _log_sub(logx, logy): + """Subtracts two numbers in the log space. Answer must be non-negative.""" + if logx < logy: + raise ValueError('The result of subtraction must be non-negative.') + if logy == -np.inf: # subtracting 0 + return logx + if logx == logy: + return -np.inf # 0 is represented as -np.inf in the log space. + + try: + # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y). + return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1 + except OverflowError: + return logx + + +def _log_sub_sign(logx, logy): + """Returns log(exp(logx)-exp(logy)) and its sign.""" + if logx > logy: + s = True + mag = logx + np.log(1 - np.exp(logy - logx)) + elif logx < logy: + s = False + mag = logy + np.log(1 - np.exp(logx - logy)) + else: + s = True + mag = -np.inf + + return s, mag + + +def _log_comb(n, k): + """Computes log of binomial coefficient.""" + return (special.gammaln(n + 1) - special.gammaln(k + 1) - + special.gammaln(n - k + 1)) + + +def _compute_log_a_int(q, sigma, alpha): + """Computes log(A_alpha) for integer alpha, 0 < q < 1.""" + assert isinstance(alpha, six.integer_types) + + # Initialize with 0 in the log space. + log_a = -np.inf + + for i in range(alpha + 1): + log_coef_i = ( + _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q)) + + s = log_coef_i + (i * i - i) / (2 * (sigma**2)) + log_a = _log_add(log_a, s) + + return float(log_a) + + +def _compute_log_a_frac(q, sigma, alpha): + """Computes log(A_alpha) for fractional alpha, 0 < q < 1.""" + # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are + # initialized to 0 in the log space: + log_a0, log_a1 = -np.inf, -np.inf + i = 0 + + z0 = sigma**2 * math.log(1 / q - 1) + .5 + + while True: # do ... until loop + coef = special.binom(alpha, i) + log_coef = math.log(abs(coef)) + j = alpha - i + + log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q) + log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q) + + log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma)) + log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma)) + + log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0 + log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1 + + if coef > 0: + log_a0 = _log_add(log_a0, log_s0) + log_a1 = _log_add(log_a1, log_s1) + else: + log_a0 = _log_sub(log_a0, log_s0) + log_a1 = _log_sub(log_a1, log_s1) + + i += 1 + if max(log_s0, log_s1) < -30: + break + + return _log_add(log_a0, log_a1) + + +def _log_erfc(x): + """Computes log(erfc(x)) with high accuracy for large x.""" + try: + return math.log(2) + special.log_ndtr(-x * 2**.5) + except NameError: + # If log_ndtr is not available, approximate as follows: + r = special.erfc(x) + if r == 0.0: + # Using the Laurent series at infinity for the tail of the erfc function: + # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5) + # To verify in Mathematica: + # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}] + return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 + + .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8) + else: + return math.log(r) + + +def _compute_delta(orders, rdp, epsilon): + """Compute delta given a list of RDP values and target epsilon. + + Args: + orders: An array of orders. + rdp: An array of RDP guarantees. + epsilon: The target epsilon. + + Returns: + Optimal delta. + + Raises: + ValueError: If input is malformed. + + """ + if epsilon < 0: + raise ValueError(f'Epsilon cannot be negative. Found {epsilon}.') + if len(orders) != len(rdp): + raise ValueError('Input lists must have the same length.') + + # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): + # delta = min( np.exp((rdp - epsilon) * (orders - 1)) ) + + # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4): + logdeltas = [] # work in log space to avoid overflows + for (a, r) in zip(orders, rdp): + if a < 1: + raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.') + if r < 0: + raise ValueError(f'Renyi divergence cannot be negative. Found {r}.') + # For small alpha, we are better of with bound via KL divergence: + # delta <= sqrt(1-exp(-KL)). + # Take a min of the two bounds. + if r == 0: + logdelta = -np.inf + else: + logdelta = 0.5 * math.log1p(-math.exp(-r)) + if a > 1.01: + # This bound is not numerically stable as alpha->1. + # Thus we have a min value for alpha. + # The bound is also not useful for small alpha, so doesn't matter. + rdp_bound = (a - 1) * (r - epsilon + math.log1p(-1 / a)) - math.log(a) + logdelta = min(logdelta, rdp_bound) + + logdeltas.append(logdelta) + + return min(math.exp(np.min(logdeltas)), 1.) + + +def _compute_epsilon(orders, rdp, delta): + """Compute epsilon given a list of RDP values and target delta. + + Args: + orders: An array of orders. + rdp: An array of RDP guarantees. + delta: The target delta. Must be >= 0. + + Returns: + Optimal epsilon. + + Raises: + ValueError: If input is malformed. + + """ + if delta < 0: + raise ValueError(f'Delta cannot be negative. Found {delta}.') + + if delta == 0: + if all(r == 0 for r in rdp): + return 0 + else: + return np.inf + + if len(orders) != len(rdp): + raise ValueError('Input lists must have the same length.') + + # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): + # epsilon = min( rdp - math.log(delta) / (orders - 1) ) + + # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). + # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). + eps = [] + for (a, r) in zip(orders, rdp): + if a < 1: + raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.') + if r < 0: + raise ValueError(f'Renyi divergence cannot be negative. Found {r}.') + + if delta**2 + math.expm1(-r) > 0: + # In this case, we can simply bound via KL divergence: + # delta <= sqrt(1-exp(-KL)). + epsilon = 0 # No need to try further computation if we have epsilon = 0. + elif a > 1.01: + # This bound is not numerically stable as alpha->1. + # Thus we have a min value of alpha. + # The bound is also not useful for small alpha, so doesn't matter. + epsilon = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) + else: + # In this case we can't do anything. E.g., asking for delta = 0. + epsilon = np.inf + eps.append(epsilon) + + return max(0, np.min(eps)) + + +def _stable_inplace_diff_in_log(vec, signs, n=-1): + """Replaces the first n-1 dims of vec with the log of abs difference operator. + + Args: + vec: numpy array of floats with size larger than 'n' + signs: Optional numpy array of bools with the same size as vec in case one + needs to compute partial differences vec and signs jointly describe a + vector of real numbers' sign and abs in log scale. + n: Optonal upper bound on number of differences to compute. If negative, all + differences are computed. + + Returns: + The first n-1 dimension of vec and signs will store the log-abs and sign of + the difference. + + Raises: + ValueError: If input is malformed. + """ + + assert vec.shape == signs.shape + if n < 0: + n = np.max(vec.shape) - 1 + else: + assert np.max(vec.shape) >= n + 1 + for j in range(0, n, 1): + if signs[j] == signs[j + 1]: # When the signs are the same + # if the signs are both positive, then we can just use the standard one + signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j]) + # otherwise, we do that but toggle the sign + if not signs[j + 1]: + signs[j] = ~signs[j] + else: # When the signs are different. + vec[j] = _log_add(vec[j], vec[j + 1]) + signs[j] = signs[j + 1] + + +def _get_forward_diffs(fun, n): + """Computes up to nth order forward difference evaluated at 0. + + See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf + + Args: + fun: Function to compute forward differences of. + n: Number of differences to compute. + + Returns: + Pair (deltas, signs_deltas) of the log deltas and their signs. + """ + func_vec = np.zeros(n + 3) + signs_func_vec = np.ones(n + 3, dtype=bool) + + # ith coordinate of deltas stores log(abs(ith order discrete derivative)) + deltas = np.zeros(n + 2) + signs_deltas = np.zeros(n + 2, dtype=bool) + for i in range(1, n + 3, 1): + func_vec[i] = fun(1.0 * (i - 1)) + for i in range(0, n + 2, 1): + # Diff in log scale + _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i) + deltas[i] = func_vec[0] + signs_deltas[i] = signs_func_vec[0] + return deltas, signs_deltas + + +def _compute_log_a(q, noise_multiplier, alpha): + if float(alpha).is_integer(): + return _compute_log_a_int(q, noise_multiplier, int(alpha)) + else: + return _compute_log_a_frac(q, noise_multiplier, alpha) + + +def _compute_rdp_poisson_subsampled_gaussian(q, noise_multiplier, orders): + """Computes RDP of the Poisson sampled Gaussian mechanism. + + Args: + q: The sampling rate. + noise_multiplier: The ratio of the standard deviation of the Gaussian noise + to the l2-sensitivity of the function to which it is added. + orders: An array of RDP orders. + + Returns: + The RDPs at all orders. Can be `np.inf`. + """ + + def compute_one_order(q, alpha): + if np.isinf(alpha) or noise_multiplier == 0: + return np.inf + + if q == 0: + return 0 + + if q == 1.: + return alpha / (2 * noise_multiplier**2) + + return _compute_log_a(q, noise_multiplier, alpha) / (alpha - 1) + + return np.array([compute_one_order(q, order) for order in orders]) + + +def _compute_rdp_sample_wor_gaussian(q, noise_multiplier, orders): + """Computes RDP of Gaussian mechanism using sampling without replacement. + + This function applies to the following schemes: + 1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n. + 2. ``Replace one data point'' version of differential privacy, i.e., n is + considered public information. + + Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened + version applies subsampled-Gaussian mechanism.) + - Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and + Analytical Moments Accountant." AISTATS'2019. + + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + noise_multiplier: The ratio of the standard deviation of the Gaussian noise + to the l2-sensitivity of the function to which it is added. + orders: An array of RDP orders. + + Returns: + The RDPs at all orders, can be np.inf. + """ + return np.array([ + _compute_rdp_sample_wor_gaussian_scalar(q, noise_multiplier, order) + for order in orders + ]) + + +def _compute_rdp_sample_wor_gaussian_scalar(q, sigma, alpha): + """Compute RDP of the Sampled Gaussian mechanism at order alpha. + + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + sigma: The std of the additive Gaussian noise. + alpha: The order at which RDP is computed. + + Returns: + RDP at alpha, can be np.inf. + """ + + assert (q <= 1) and (q >= 0) and (alpha >= 1) + + if q == 0: + return 0 + + if q == 1.: + return alpha / (2 * sigma**2) + + if np.isinf(alpha): + return np.inf + + if float(alpha).is_integer(): + return _compute_rdp_sample_wor_gaussian_int(q, sigma, int(alpha)) / ( + alpha - 1) + else: + # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate + # the CGF and obtain an upper bound + alpha_f = math.floor(alpha) + alpha_c = math.ceil(alpha) + + x = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_f) + y = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_c) + t = alpha - alpha_f + return ((1 - t) * x + t * y) / (alpha - 1) + + +def _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha): + """Compute log(A_alpha) for integer alpha, subsampling without replacement. + + When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly, + otherwise compute the bound with Stirling approximation. + + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + sigma: The std of the additive Gaussian noise. + alpha: The order at which RDP is computed. + + Returns: + RDP at alpha, can be np.inf. + """ + + max_alpha = 256 + assert isinstance(alpha, six.integer_types) + + if np.isinf(alpha): + return np.inf + elif alpha == 1: + return 0 + + def cgf(x): + # Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2) + return x * 1.0 * (x + 1) / (2.0 * sigma**2) + + def func(x): + # Return the rdp of Gaussian mechanism + return 1.0 * x / (2.0 * sigma**2) + + # Initialize with 1 in the log space. + log_a = 0 + # Calculates the log term when alpha = 2 + log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0))) + if alpha <= max_alpha: + # We need forward differences of exp(cgf) + # The following line is the numerically stable way of implementing it. + # The output is in polar form with logarithmic magnitude + deltas, _ = _get_forward_diffs(cgf, alpha) + # Compute the bound exactly requires book keeping of O(alpha**2) + + for i in range(2, alpha + 1): + if i == 2: + s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum( + np.log(4) + log_f2m1, + func(2.0) + np.log(2)) + elif i > 2: + delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1] + delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1] + s = np.log(4) + 0.5 * (delta_lo + delta_hi) + s = np.minimum(s, np.log(2) + cgf(i - 1)) + s += i * np.log(q) + _log_comb(alpha, i) + log_a = _log_add(log_a, s) + return float(log_a) + else: + # Compute the bound with stirling approximation. Everything is O(x) now. + for i in range(2, alpha + 1): + if i == 2: + s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum( + np.log(4) + log_f2m1, + func(2.0) + np.log(2)) + else: + s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i) + log_a = _log_add(log_a, s) + + return log_a + + +class RdpAccountant(privacy_accountant.PrivacyAccountant): + """Privacy accountant that uses Renyi differential privacy.""" + + def __init__( + self, + orders: Optional[Collection[float]] = None, + neighboring_relation: NeighborRel = NeighborRel.ADD_OR_REMOVE_ONE, + ): + super(RdpAccountant, self).__init__(neighboring_relation) + if orders is None: + # Default orders chosen to give good coverage for Gaussian mechanism in + # the privacy regime of interest. In the future, more orders might be + # added, in particular, fractional orders between 1.0 and 10.0 or so. + orders = [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 24, 28, 32, 48, 64, 128, + 256, 512, 1024 + ] + self._orders = np.array(orders) + self._rdp = np.zeros_like(orders, dtype=np.float64) + + def supports(self, event: dp_event.DpEvent) -> bool: + return self._maybe_compose(event, 0, False) + + def _compose(self, event: dp_event.DpEvent, count: int = 1): + self._maybe_compose(event, count, True) + + def _maybe_compose(self, event: dp_event.DpEvent, count: int, + do_compose: bool) -> bool: + """Traverses `event` and performs composition if `do_compose` is True. + + If `do_compose` is False, can be used to check whether composition is + supported. + + Args: + event: A `DpEvent` to process. + count: The number of times to compose the event. + do_compose: Whether to actually perform the composition. + + Returns: + True if event is supported, otherwise False. + """ + + if isinstance(event, dp_event.NoOpDpEvent): + return True + elif isinstance(event, dp_event.NonPrivateDpEvent): + if do_compose: + self._rdp += np.inf + return True + elif isinstance(event, dp_event.SelfComposedDpEvent): + return self._maybe_compose(event.event, event.count * count, do_compose) + elif isinstance(event, dp_event.ComposedDpEvent): + return all( + self._maybe_compose(e, count, do_compose) for e in event.events) + elif isinstance(event, dp_event.GaussianDpEvent): + if do_compose: + self._rdp += count * _compute_rdp_poisson_subsampled_gaussian( + q=1.0, noise_multiplier=event.noise_multiplier, orders=self._orders) + return True + elif isinstance(event, dp_event.PoissonSampledDpEvent): + if (self._neighboring_relation is not NeighborRel.ADD_OR_REMOVE_ONE or + not isinstance(event.event, dp_event.GaussianDpEvent)): + return False + if do_compose: + self._rdp += count * _compute_rdp_poisson_subsampled_gaussian( + q=event.sampling_probability, + noise_multiplier=event.event.noise_multiplier, + orders=self._orders) + return True + elif isinstance(event, dp_event.FixedBatchSampledWorDpEvent): + if (self._neighboring_relation is not NeighborRel.REPLACE_ONE or + not isinstance(event.event, dp_event.GaussianDpEvent)): + return False + if do_compose: + self._rdp += count * _compute_rdp_sample_wor_gaussian( + q=event.batch_size / event.dataset_size, + noise_multiplier=event.event.noise_multiplier, + orders=self._orders) + return True + else: + # Unsupported event (including `UnsupportedDpEvent`). + return False + + def get_epsilon(self, target_delta: float) -> float: + return _compute_epsilon(self._orders, self._rdp, target_delta) + + def get_delta(self, target_epsilon: float) -> float: + return _compute_delta(self._orders, self._rdp, target_epsilon) diff --git a/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py new file mode 100644 index 0000000..817d41c --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/rdp_privacy_accountant_test.py @@ -0,0 +1,307 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rdp_privacy_accountant.""" + +import math +import sys + +from absl.testing import absltest +from absl.testing import parameterized +import mpmath +import numpy as np + +from tensorflow_privacy.privacy.analysis import dp_event +from tensorflow_privacy.privacy.analysis import privacy_accountant +from tensorflow_privacy.privacy.analysis import privacy_accountant_test +from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant + + +def _get_test_rdp(event, count=1): + accountant = rdp_privacy_accountant.RdpAccountant(orders=[2.71828]) + accountant.compose(event, count) + return accountant._rdp[0] + + +def _log_float_mp(x): + # Convert multi-precision input to float log space. + if x >= sys.float_info.min: + return float(mpmath.log(x)) + else: + return -np.inf + + +def _compute_a_mp(sigma, q, alpha): + """Compute A_alpha for arbitrary alpha by numerical integration.""" + + def mu0(x): + return mpmath.npdf(x, mu=0, sigma=sigma) + + def _mu_over_mu0(x, q, sigma): + return (1 - q) + q * mpmath.exp((2 * x - 1) / (2 * sigma**2)) + + def a_alpha_fn(z): + return mu0(z) * _mu_over_mu0(z, q, sigma)**alpha + + bounds = (-mpmath.inf, mpmath.inf) + a_alpha, _ = mpmath.quad(a_alpha_fn, bounds, error=True, maxdegree=8) + return a_alpha + + +class RdpPrivacyAccountantTest(privacy_accountant_test.PrivacyAccountantTest, + parameterized.TestCase): + + def _make_test_accountants(self): + return [ + rdp_privacy_accountant.RdpAccountant( + [2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE), + rdp_privacy_accountant.RdpAccountant( + [2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE) + ] + + def test_supports(self): + aor_accountant = rdp_privacy_accountant.RdpAccountant( + [2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE) + ro_accountant = rdp_privacy_accountant.RdpAccountant( + [2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE) + + event = dp_event.GaussianDpEvent(1.0) + self.assertTrue(aor_accountant.supports(event)) + self.assertTrue(ro_accountant.supports(event)) + + event = dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(1.0), 6) + self.assertTrue(aor_accountant.supports(event)) + self.assertTrue(ro_accountant.supports(event)) + + event = dp_event.ComposedDpEvent( + [dp_event.GaussianDpEvent(1.0), + dp_event.GaussianDpEvent(2.0)]) + self.assertTrue(aor_accountant.supports(event)) + self.assertTrue(ro_accountant.supports(event)) + + event = dp_event.PoissonSampledDpEvent(0.1, dp_event.GaussianDpEvent(1.0)) + self.assertTrue(aor_accountant.supports(event)) + self.assertFalse(ro_accountant.supports(event)) + + event = dp_event.FixedBatchSampledWorDpEvent(1000, 10, + dp_event.GaussianDpEvent(1.0)) + self.assertFalse(aor_accountant.supports(event)) + self.assertTrue(ro_accountant.supports(event)) + + event = dp_event.FixedBatchSampledWrDpEvent(1000, 10, + dp_event.GaussianDpEvent(1.0)) + self.assertFalse(aor_accountant.supports(event)) + self.assertFalse(ro_accountant.supports(event)) + + def test_rdp_composition(self): + base_event = dp_event.GaussianDpEvent(3.14159) + base_rdp = _get_test_rdp(base_event) + + rdp_with_count = _get_test_rdp(base_event, count=6) + self.assertAlmostEqual(rdp_with_count, base_rdp * 6) + + rdp_with_self_compose = _get_test_rdp( + dp_event.SelfComposedDpEvent(base_event, 6)) + self.assertAlmostEqual(rdp_with_self_compose, base_rdp * 6) + + rdp_with_self_compose_and_count = _get_test_rdp( + dp_event.SelfComposedDpEvent(base_event, 2), count=3) + self.assertAlmostEqual(rdp_with_self_compose_and_count, base_rdp * 6) + + rdp_with_compose = _get_test_rdp(dp_event.ComposedDpEvent([base_event] * 6)) + self.assertAlmostEqual(rdp_with_compose, base_rdp * 6) + + rdp_with_compose_and_self_compose = _get_test_rdp( + dp_event.ComposedDpEvent([ + dp_event.SelfComposedDpEvent(base_event, 1), + dp_event.SelfComposedDpEvent(base_event, 2), + dp_event.SelfComposedDpEvent(base_event, 3) + ])) + self.assertAlmostEqual(rdp_with_compose_and_self_compose, base_rdp * 6) + + base_event_2 = dp_event.GaussianDpEvent(1.61803) + base_rdp_2 = _get_test_rdp(base_event_2) + rdp_with_heterogeneous_compose = _get_test_rdp( + dp_event.ComposedDpEvent([base_event, base_event_2])) + self.assertAlmostEqual(rdp_with_heterogeneous_compose, + base_rdp + base_rdp_2) + + def test_zero_poisson_sample(self): + accountant = rdp_privacy_accountant.RdpAccountant([3.14159]) + accountant.compose( + dp_event.PoissonSampledDpEvent(0, dp_event.GaussianDpEvent(1.0))) + self.assertEqual(accountant.get_epsilon(1e-10), 0) + self.assertEqual(accountant.get_delta(1e-10), 0) + + def test_zero_fixed_batch_sample(self): + accountant = rdp_privacy_accountant.RdpAccountant( + [3.14159], privacy_accountant.NeighboringRelation.REPLACE_ONE) + accountant.compose( + dp_event.FixedBatchSampledWorDpEvent(1000, 0, + dp_event.GaussianDpEvent(1.0))) + self.assertEqual(accountant.get_epsilon(1e-10), 0) + self.assertEqual(accountant.get_delta(1e-10), 0) + + def test_epsilon_non_private_gaussian(self): + accountant = rdp_privacy_accountant.RdpAccountant([3.14159]) + accountant.compose(dp_event.GaussianDpEvent(0)) + self.assertEqual(accountant.get_epsilon(1e-1), np.inf) + + def test_compute_rdp_gaussian(self): + alpha = 3.14159 + sigma = 2.71828 + event = dp_event.GaussianDpEvent(sigma) + accountant = rdp_privacy_accountant.RdpAccountant(orders=[alpha]) + accountant.compose(event) + self.assertAlmostEqual(accountant._rdp[0], alpha / (2 * sigma**2)) + + def test_compute_rdp_poisson_sampled_gaussian(self): + orders = [1.5, 2.5, 5, 50, 100, np.inf] + noise_multiplier = 2.5 + sampling_probability = 0.01 + count = 50 + event = dp_event.SelfComposedDpEvent( + dp_event.PoissonSampledDpEvent( + sampling_probability, dp_event.GaussianDpEvent(noise_multiplier)), + count) + accountant = rdp_privacy_accountant.RdpAccountant(orders=orders) + accountant.compose(event) + self.assertTrue( + np.allclose( + accountant._rdp, [ + 6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, + np.inf + ], + rtol=1e-4)) + + def test_compute_epsilon_delta_pure_dp(self): + orders = range(2, 33) + rdp = [1.1 for o in orders] # Constant corresponds to pure DP. + + epsilon = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-5) + # Compare with epsilon computed by hand. + self.assertAlmostEqual(epsilon, 1.32783806176) + + delta = rdp_privacy_accountant._compute_delta( + orders, rdp, epsilon=1.32783806176) + self.assertAlmostEqual(delta, 1e-5) + + def test_compute_epsilon_delta_gaussian(self): + orders = [0.001 * i for i in range(1000, 100000)] + + # noise multiplier is chosen to obtain exactly (1,1e-6)-DP. + rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian( + 1, 4.530877117, orders) + + eps = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-6) + self.assertAlmostEqual(eps, 1) + + delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=1) + self.assertAlmostEqual(delta, 1e-6) + + params = ({ + 'q': 1e-7, + 'sigma': .1, + 'order': 1.01 + }, { + 'q': 1e-6, + 'sigma': .1, + 'order': 256 + }, { + 'q': 1e-5, + 'sigma': .1, + 'order': 256.1 + }, { + 'q': 1e-6, + 'sigma': 1, + 'order': 27 + }, { + 'q': 1e-4, + 'sigma': 1., + 'order': 1.5 + }, { + 'q': 1e-3, + 'sigma': 1., + 'order': 2 + }, { + 'q': .01, + 'sigma': 10, + 'order': 20 + }, { + 'q': .1, + 'sigma': 100, + 'order': 20.5 + }, { + 'q': .99, + 'sigma': .1, + 'order': 256 + }, { + 'q': .999, + 'sigma': 100, + 'order': 256.1 + }) + + # pylint:disable=undefined-variable + @parameterized.parameters(p for p in params) + def test_compute_log_a_equals_mp(self, q, sigma, order): + # Compare the cheap computation of log(A) with an expensive, multi-precision + # computation. + log_a = rdp_privacy_accountant._compute_log_a(q, sigma, order) + log_a_mp = _log_float_mp(_compute_a_mp(sigma, q, order)) + np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4) + + def test_delta_bounds_gaussian(self): + # Compare the optimal bound for Gaussian with the one derived from RDP. + # Also compare the RDP upper bound with the "standard" upper bound. + orders = [0.1 * x for x in range(10, 505)] + eps_vec = [0.1 * x for x in range(500)] + rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian( + 1, 1, orders) + for eps in eps_vec: + delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=eps) + # For comparison, we compute the optimal guarantee for Gaussian + # using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2). + delta0 = math.erfc((eps - .5) / math.sqrt(2)) / 2 + delta0 = delta0 - math.exp(eps) * math.erfc((eps + .5) / math.sqrt(2)) / 2 + self.assertLessEqual(delta0, delta + 1e-300) # need tolerance 10^-300 + + # Compute the "standard" upper bound, which should be an upper bound. + # Note, if orders is too sparse, this will NOT be an upper bound. + if eps >= 0.5: + delta1 = math.exp(-0.5 * (eps - 0.5)**2) + else: + delta1 = 1 + self.assertLessEqual(delta, delta1 + 1e-300) + + def test_epsilon_delta_consistency(self): + orders = range(2, 50) # Large range of orders (helps test for overflows). + for q in [0, 0.01, 0.1, 0.8, 1.]: + for multiplier in [0.0, 0.1, 1., 10., 100.]: + event = dp_event.PoissonSampledDpEvent( + q, dp_event.GaussianDpEvent(multiplier)) + accountant = rdp_privacy_accountant.RdpAccountant(orders) + accountant.compose(event) + for delta in [.99, .9, .1, .01, 1e-3, 1e-5, 1e-9, 1e-12]: + epsilon = accountant.get_epsilon(delta) + delta2 = accountant.get_delta(epsilon) + if np.isposinf(epsilon): + self.assertEqual(delta2, 1.0) + elif epsilon == 0: + self.assertLessEqual(delta2, delta) + else: + self.assertAlmostEqual(delta, delta2) + + +if __name__ == '__main__': + absltest.main()