diff --git a/tensorflow_privacy/privacy/analysis/BUILD b/tensorflow_privacy/privacy/analysis/BUILD
index 63abc69..6e72cb5 100644
--- a/tensorflow_privacy/privacy/analysis/BUILD
+++ b/tensorflow_privacy/privacy/analysis/BUILD
@@ -61,6 +61,11 @@ py_library(
     srcs = ["rdp_accountant.py"],
     srcs_version = "PY3",
     visibility = ["//visibility:public"],
+    deps = [
+        "@com_google_differential_py//python/dp_accounting:dp_event",
+        "@com_google_differential_py//python/dp_accounting:privacy_accountant",
+        "@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
+    ],
 )
 
 py_test(
diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py
index 2db5e9e..f35e442 100644
--- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py
+++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""RDP analysis of the Sampled Gaussian Mechanism.
+"""(Deprecated) RDP analysis of the Sampled Gaussian Mechanism.
+
+The functions in this package have been superseded by more general accounting
+mechanisms in Google's `differential_privacy` package. These functions may at
+some future date be removed.
 
 Functionality for computing Renyi differential privacy (RDP) of an additive
 Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
@@ -37,342 +41,50 @@ The example code would be:
   eps, _, opt_order = rdp_accountant.get_privacy_spent(rdp, target_delta=delta)
 """
 
-import math
-import sys
-
 import numpy as np
-from scipy import special
 
-########################
-# LOG-SPACE ARITHMETIC #
-########################
+from com_google_differential_py.python.dp_accounting import dp_event
+from com_google_differential_py.python.dp_accounting import privacy_accountant
+from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
 
 
-def _log_add(logx, logy):
-  """Add two numbers in the log space."""
-  a, b = min(logx, logy), max(logx, logy)
-  if a == -np.inf:  # adding 0
-    return b
-  # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
-  return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)
-
-
-def _log_sub(logx, logy):
-  """Subtract two numbers in the log space. Answer must be non-negative."""
-  if logx < logy:
-    raise ValueError("The result of subtraction must be non-negative.")
-  if logy == -np.inf:  # subtracting 0
-    return logx
-  if logx == logy:
-    return -np.inf  # 0 is represented as -np.inf in the log space.
-
-  try:
-    # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
-    return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
-  except OverflowError:
-    return logx
-
-
-def _log_sub_sign(logx, logy):
-  """Returns log(exp(logx)-exp(logy)) and its sign."""
-  if logx > logy:
-    s = True
-    mag = logx + np.log(1 - np.exp(logy - logx))
-  elif logx < logy:
-    s = False
-    mag = logy + np.log(1 - np.exp(logx - logy))
-  else:
-    s = True
-    mag = -np.inf
-
-  return s, mag
-
-
-def _log_print(logx):
-  """Pretty print."""
-  if logx < math.log(sys.float_info.max):
-    return "{}".format(math.exp(logx))
-  else:
-    return "exp({})".format(logx)
-
-
-def _log_comb(n, k):
-  return (special.gammaln(n + 1) - special.gammaln(k + 1) -
-          special.gammaln(n - k + 1))
-
-
-def _compute_log_a_int(q, sigma, alpha):
-  """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
-  assert isinstance(alpha, int)
-
-  # Initialize with 0 in the log space.
-  log_a = -np.inf
-
-  for i in range(alpha + 1):
-    log_coef_i = (
-        _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
-
-    s = log_coef_i + (i * i - i) / (2 * (sigma**2))
-    log_a = _log_add(log_a, s)
-
-  return float(log_a)
-
-
-def _compute_log_a_frac(q, sigma, alpha):
-  """Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
-  # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
-  # initialized to 0 in the log space:
-  log_a0, log_a1 = -np.inf, -np.inf
-  i = 0
-
-  z0 = sigma**2 * math.log(1 / q - 1) + .5
-
-  while True:  # do ... until loop
-    coef = special.binom(alpha, i)
-    log_coef = math.log(abs(coef))
-    j = alpha - i
-
-    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
-    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
-
-    log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
-    log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
-
-    log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
-    log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
-
-    if coef > 0:
-      log_a0 = _log_add(log_a0, log_s0)
-      log_a1 = _log_add(log_a1, log_s1)
-    else:
-      log_a0 = _log_sub(log_a0, log_s0)
-      log_a1 = _log_sub(log_a1, log_s1)
-
-    i += 1
-    if max(log_s0, log_s1) < -30:
-      break
-
-  return _log_add(log_a0, log_a1)
-
-
-def _compute_log_a(q, sigma, alpha):
-  """Compute log(A_alpha) for any positive finite alpha."""
-  if float(alpha).is_integer():
-    return _compute_log_a_int(q, sigma, int(alpha))
-  else:
-    return _compute_log_a_frac(q, sigma, alpha)
-
-
-def _log_erfc(x):
-  """Compute log(erfc(x)) with high accuracy for large x."""
-  try:
-    return math.log(2) + special.log_ndtr(-x * 2**.5)
-  except NameError:
-    # If log_ndtr is not available, approximate as follows:
-    r = special.erfc(x)
-    if r == 0.0:
-      # Using the Laurent series at infinity for the tail of the erfc function:
-      # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
-      # To verify in Mathematica:
-      # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
-      return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
-              .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
-    else:
-      return math.log(r)
-
-
-def _compute_delta(orders, rdp, eps):
-  """Compute delta given a list of RDP values and target epsilon.
+def _compute_rdp_from_event(orders, event, count):
+  """Computes RDP from a DpEvent using RdpAccountant.
 
   Args:
-    orders: An array (or a scalar) of orders.
-    rdp: A list (or a scalar) of RDP guarantees.
-    eps: The target epsilon.
+    orders: An array (or a scalar) of RDP orders.
+    event: A DpEvent to compute the RDP of.
+    count: The number of self-compositions.
 
   Returns:
-    Pair of (delta, optimal_order).
-
-  Raises:
-    ValueError: If input is malformed.
-
+    The RDP at all orders. Can be `np.inf`.
""" orders_vec = np.atleast_1d(orders) - rdp_vec = np.atleast_1d(rdp) - if eps < 0: - raise ValueError("Value of privacy loss bound epsilon must be >=0.") - if len(orders_vec) != len(rdp_vec): - raise ValueError("Input lists must have the same length.") - - # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): - # delta = min( np.exp((rdp_vec - eps) * (orders_vec - 1)) ) - - # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4): - logdeltas = [] # work in log space to avoid overflows - for (a, r) in zip(orders_vec, rdp_vec): - if a < 1: - raise ValueError("Renyi divergence order must be >=1.") - if r < 0: - raise ValueError("Renyi divergence must be >=0.") - # For small alpha, we are better of with bound via KL divergence: - # delta <= sqrt(1-exp(-KL)). - # Take a min of the two bounds. - logdelta = 0.5 * math.log1p(-math.exp(-r)) - if a > 1.01: - # This bound is not numerically stable as alpha->1. - # Thus we have a min value for alpha. - # The bound is also not useful for small alpha, so doesn't matter. - rdp_bound = (a - 1) * (r - eps + math.log1p(-1 / a)) - math.log(a) - logdelta = min(logdelta, rdp_bound) - - logdeltas.append(logdelta) - - idx_opt = np.argmin(logdeltas) - return min(math.exp(logdeltas[idx_opt]), 1.), orders_vec[idx_opt] - - -def _compute_eps(orders, rdp, delta): - """Compute epsilon given a list of RDP values and target delta. - - Args: - orders: An array (or a scalar) of orders. - rdp: A list (or a scalar) of RDP guarantees. - delta: The target delta. - - Returns: - Pair of (eps, optimal_order). - - Raises: - ValueError: If input is malformed. - - """ - orders_vec = np.atleast_1d(orders) - rdp_vec = np.atleast_1d(rdp) - - if delta <= 0: - raise ValueError("Privacy failure probability bound delta must be >0.") - if len(orders_vec) != len(rdp_vec): - raise ValueError("Input lists must have the same length.") - - # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): - # eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) ) - - # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). - # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). - eps_vec = [] - for (a, r) in zip(orders_vec, rdp_vec): - if a < 1: - raise ValueError("Renyi divergence order must be >=1.") - if r < 0: - raise ValueError("Renyi divergence must be >=0.") - - if delta**2 + math.expm1(-r) >= 0: - # In this case, we can simply bound via KL divergence: - # delta <= sqrt(1-exp(-KL)). - eps = 0 # No need to try further computation if we have eps = 0. - elif a > 1.01: - # This bound is not numerically stable as alpha->1. - # Thus we have a min value of alpha. - # The bound is also not useful for small alpha, so doesn't matter. - eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) - else: - # In this case we can't do anything. E.g., asking for delta = 0. - eps = np.inf - eps_vec.append(eps) - - idx_opt = np.argmin(eps_vec) - return max(0, eps_vec[idx_opt]), orders_vec[idx_opt] - - -def _stable_inplace_diff_in_log(vec, signs, n=-1): - """Replaces the first n-1 dims of vec with the log of abs difference operator. - - Args: - vec: numpy array of floats with size larger than 'n' - signs: Optional numpy array of bools with the same size as vec in case one - needs to compute partial differences vec and signs jointly describe a - vector of real numbers' sign and abs in log scale. - n: Optonal upper bound on number of differences to compute. If negative, all - differences are computed. 
-
-  Returns:
-    The first n-1 dimension of vec and signs will store the log-abs and sign of
-    the difference.
-
-  Raises:
-    ValueError: If input is malformed.
-  """
-
-  assert vec.shape == signs.shape
-  if n < 0:
-    n = np.max(vec.shape) - 1
+  if isinstance(event, dp_event.SampledWithoutReplacementDpEvent):
+    neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_ONE
+  elif isinstance(event, dp_event.SingleEpochTreeAggregationDpEvent):
+    neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_SPECIAL
   else:
-    assert np.max(vec.shape) >= n + 1
-  for j in range(0, n, 1):
-    if signs[j] == signs[j + 1]:  # When the signs are the same
-      # if the signs are both positive, then we can just use the standard one
-      signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
-      # otherwise, we do that but toggle the sign
-      if not signs[j + 1]:
-        signs[j] = ~signs[j]
-    else:  # When the signs are different.
-      vec[j] = _log_add(vec[j], vec[j + 1])
-      signs[j] = signs[j + 1]
+    neighboring_relation = privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE
 
+  accountant = rdp_privacy_accountant.RdpAccountant(orders_vec,
+                                                    neighboring_relation)
+  accountant.compose(event, count)
+  rdp = accountant._rdp  # pylint: disable=protected-access
 
-def _get_forward_diffs(fun, n):
-  """Computes up to nth order forward difference evaluated at 0.
-
-  See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
-
-  Args:
-    fun: Function to compute forward differences of.
-    n: Number of differences to compute.
-
-  Returns:
-    Pair (deltas, signs_deltas) of the log deltas and their signs.
-  """
-  func_vec = np.zeros(n + 3)
-  signs_func_vec = np.ones(n + 3, dtype=bool)
-
-  # ith coordinate of deltas stores log(abs(ith order discrete derivative))
-  deltas = np.zeros(n + 2)
-  signs_deltas = np.zeros(n + 2, dtype=bool)
-  for i in range(1, n + 3, 1):
-    func_vec[i] = fun(1.0 * (i - 1))
-  for i in range(0, n + 2, 1):
-    # Diff in log scale
-    _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
-    deltas[i] = func_vec[0]
-    signs_deltas[i] = signs_func_vec[0]
-  return deltas, signs_deltas
-
-
-def _compute_rdp(q, sigma, alpha):
-  """Compute RDP of the Sampled Gaussian mechanism at order alpha.
-
-  Args:
-    q: The sampling rate.
-    sigma: The std of the additive Gaussian noise.
-    alpha: The order at which RDP is computed.
-
-  Returns:
-    RDP at alpha, can be np.inf.
-  """
-  if q == 0:
-    return 0
-
-  if q == 1.:
-    return alpha / (2 * sigma**2)
-
-  if np.isinf(alpha):
-    return np.inf
-
-  return _compute_log_a(q, sigma, alpha) / (alpha - 1)
+  if np.isscalar(orders):
+    return rdp[0]
+  else:
+    return rdp
 
 
 def compute_rdp(q, noise_multiplier, steps, orders):
-  """Computes RDP of the Sampled Gaussian Mechanism.
+  """(Deprecated) Computes RDP of the Sampled Gaussian Mechanism.
+
+  This function has been superseded by more general accounting mechanisms in
+  Google's `differential_privacy` package. It may at some future date be
+  removed.
 
   Args:
     q: The sampling rate.
@@ -384,17 +96,18 @@ def compute_rdp(q, noise_multiplier, steps, orders):
   Returns:
     The RDPs at all orders. Can be `np.inf`.
""" - if np.isscalar(orders): - rdp = _compute_rdp(q, noise_multiplier, orders) - else: - rdp = np.array( - [_compute_rdp(q, noise_multiplier, order) for order in orders]) + event = dp_event.PoissonSampledDpEvent( + q, dp_event.GaussianDpEvent(noise_multiplier)) - return rdp * steps + return _compute_rdp_from_event(orders, event, steps) def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders): - """Compute RDP of Gaussian Mechanism using sampling without replacement. + """(Deprecated) Compute RDP of Gaussian Mechanism sampling w/o replacement. + + This function has been superseded by more general accounting mechanisms in + Google's `differential_privacy` package. It may at some future date be + removed. This function applies to the following schemes: 1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n. @@ -416,129 +129,19 @@ def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders): Returns: The RDPs at all orders, can be np.inf. """ - if np.isscalar(orders): - rdp = _compute_rdp_sample_without_replacement_scalar( - q, noise_multiplier, orders) - else: - rdp = np.array([ - _compute_rdp_sample_without_replacement_scalar(q, noise_multiplier, - order) - for order in orders - ]) + event = dp_event.SampledWithoutReplacementDpEvent( + 1, q, dp_event.GaussianDpEvent(noise_multiplier)) - return rdp * steps - - -def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha): - """Compute RDP of the Sampled Gaussian mechanism at order alpha. - - Args: - q: The sampling proportion = m / n. Assume m is an integer <= n. - sigma: The std of the additive Gaussian noise. - alpha: The order at which RDP is computed. - - Returns: - RDP at alpha, can be np.inf. - """ - - assert (q <= 1) and (q >= 0) and (alpha >= 1) - - if q == 0: - return 0 - - if q == 1.: - return alpha / (2 * sigma**2) - - if np.isinf(alpha): - return np.inf - - if float(alpha).is_integer(): - return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / ( - alpha - 1) - else: - # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate - # the CGF and obtain an upper bound - alpha_f = math.floor(alpha) - alpha_c = math.ceil(alpha) - - x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f) - y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c) - t = alpha - alpha_f - return ((1 - t) * x + t * y) / (alpha - 1) - - -def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): - """Compute log(A_alpha) for integer alpha, subsampling without replacement. - - When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly, - otherwise compute the bound with Stirling approximation. - - Args: - q: The sampling proportion = m / n. Assume m is an integer <= n. - sigma: The std of the additive Gaussian noise. - alpha: The order at which RDP is computed. - - Returns: - RDP at alpha, can be np.inf. - """ - - max_alpha = 256 - assert isinstance(alpha, int) - - if np.isinf(alpha): - return np.inf - elif alpha == 1: - return 0 - - def cgf(x): - # Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2) - return x * 1.0 * (x + 1) / (2.0 * sigma**2) - - def func(x): - # Return the rdp of Gaussian mechanism - return 1.0 * x / (2.0 * sigma**2) - - # Initialize with 1 in the log space. 
-  log_a = 0
-  # Calculates the log term when alpha = 2
-  log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
-  if alpha <= max_alpha:
-    # We need forward differences of exp(cgf)
-    # The following line is the numerically stable way of implementing it.
-    # The output is in polar form with logarithmic magnitude
-    deltas, _ = _get_forward_diffs(cgf, alpha)
-    # Compute the bound exactly requires book keeping of O(alpha**2)
-
-    for i in range(2, alpha + 1):
-      if i == 2:
-        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
-            np.log(4) + log_f2m1,
-            func(2.0) + np.log(2))
-      elif i > 2:
-        delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
-        delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
-        s = np.log(4) + 0.5 * (delta_lo + delta_hi)
-        s = np.minimum(s, np.log(2) + cgf(i - 1))
-      s += i * np.log(q) + _log_comb(alpha, i)
-      log_a = _log_add(log_a, s)
-    return float(log_a)
-  else:
-    # Compute the bound with stirling approximation. Everything is O(x) now.
-    for i in range(2, alpha + 1):
-      if i == 2:
-        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
-            np.log(4) + log_f2m1,
-            func(2.0) + np.log(2))
-      else:
-        s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
-      log_a = _log_add(log_a, s)
-
-  return log_a
+  return _compute_rdp_from_event(orders, event, steps)
 
 
 def compute_heterogeneous_rdp(sampling_probabilities, noise_multipliers,
                               steps_list, orders):
-  """Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
+  """(Deprecated) Computes RDP of Heterogeneous Sampled Gaussian Mechanisms.
+
+  This function has been superseded by more general accounting mechanisms in
+  Google's `differential_privacy` package. It may at some future date be
+  removed.
 
   Args:
     sampling_probabilities: A list containing the sampling rates.
@@ -563,7 +166,11 @@ def compute_heterogeneous_rdp(sampling_probabilities, noise_multipliers,
 
 
 def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
-  """Computes delta (or eps) for given eps (or delta) from RDP values.
+  """(Deprecated) Computes delta or eps from RDP values.
+
+  This function has been superseded by more general accounting mechanisms in
+  Google's `differential_privacy` package. It may at some future date be
+  removed.
 
   Args:
     orders: An array (or a scalar) of RDP orders.
@@ -588,9 +195,12 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
     raise ValueError(
         "Exactly one out of eps and delta must be None. (None is).")
 
+  accountant = rdp_privacy_accountant.RdpAccountant(orders)
+  accountant._rdp = rdp  # pylint: disable=protected-access
+
   if target_eps is not None:
-    delta, opt_order = _compute_delta(orders, rdp, target_eps)
+    delta, opt_order = accountant.get_delta_and_optimal_order(target_eps)
     return target_eps, delta, opt_order
   else:
-    eps, opt_order = _compute_eps(orders, rdp, target_delta)
+    eps, opt_order = accountant.get_epsilon_and_optimal_order(target_delta)
     return eps, target_delta, opt_order
diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py
index 4432487..92a1ad2 100644
--- a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py
+++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py
@@ -23,51 +23,58 @@
 import tensorflow as tf
 
 from tensorflow_privacy.privacy.analysis import rdp_accountant
 
+#################################
+# HELPER FUNCTIONS:             #
+# Exact computations using      #
+# multi-precision arithmetic.   #
+#################################
+
+
+def _log_float_mp(x):
+  # Convert multi-precision input to float log space.
+  if x >= sys.float_info.min:
+    return float(mpmath.log(x))
+  else:
+    return -np.inf
+
+
+def _integral_mp(fn, bounds=(-mpmath.inf, mpmath.inf)):
+  integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
+  return integral
+
+
+def _distributions_mp(sigma, q):
+
+  def _mu0(x):
+    return mpmath.npdf(x, mu=0, sigma=sigma)
+
+  def _mu1(x):
+    return mpmath.npdf(x, mu=1, sigma=sigma)
+
+  def _mu(x):
+    return (1 - q) * _mu0(x) + q * _mu1(x)
+
+  return _mu0, _mu  # Closure!
+
+
+def _mu1_over_mu0(x, sigma):
+  # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
+  return mpmath.exp((2 * x - 1) / (2 * sigma**2))
+
+
+def _mu_over_mu0(x, q, sigma):
+  return (1 - q) + q * _mu1_over_mu0(x, sigma)
+
+
+def _compute_a_mp(sigma, q, alpha):
+  """Compute A_alpha for arbitrary alpha by numerical integration."""
+  mu0, _ = _distributions_mp(sigma, q)
+  a_alpha_fn = lambda z: mu0(z) * _mu_over_mu0(z, q, sigma)**alpha
+  a_alpha = _integral_mp(a_alpha_fn)
+  return a_alpha
+
 
 class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
-  #################################
-  # HELPER FUNCTIONS:             #
-  # Exact computations using      #
-  # multi-precision arithmetic.   #
-  #################################
-
-  def _log_float_mp(self, x):
-    # Convert multi-precision input to float log space.
-    if x >= sys.float_info.min:
-      return float(mpmath.log(x))
-    else:
-      return -np.inf
-
-  def _integral_mp(self, fn, bounds=(-mpmath.inf, mpmath.inf)):
-    integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
-    return integral
-
-  def _distributions_mp(self, sigma, q):
-
-    def _mu0(x):
-      return mpmath.npdf(x, mu=0, sigma=sigma)
-
-    def _mu1(x):
-      return mpmath.npdf(x, mu=1, sigma=sigma)
-
-    def _mu(x):
-      return (1 - q) * _mu0(x) + q * _mu1(x)
-
-    return _mu0, _mu  # Closure!
-
-  def _mu1_over_mu0(self, x, sigma):
-    # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
-    return mpmath.exp((2 * x - 1) / (2 * sigma**2))
-
-  def _mu_over_mu0(self, x, q, sigma):
-    return (1 - q) + q * self._mu1_over_mu0(x, sigma)
-
-  def _compute_a_mp(self, sigma, q, alpha):
-    """Compute A_alpha for arbitrary alpha by numerical integration."""
-    mu0, _ = self._distributions_mp(sigma, q)
-    a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
-    a_alpha = self._integral_mp(a_alpha_fn)
-    return a_alpha
 
   # TEST ROUTINES
   def test_compute_heterogeneous_rdp_different_sampling_probabilities(self):
@@ -152,15 +159,6 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
           'order': 256.1
       })
 
-  # pylint:disable=undefined-variable
-  @parameterized.parameters(p for p in params)
-  def test_compute_log_a_equals_mp(self, q, sigma, order):
-    # Compare the cheap computation of log(A) with an expensive, multi-precision
-    # computation.
-    log_a = rdp_accountant._compute_log_a(q, sigma, order)
-    log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order))
-    np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
-
  def test_get_privacy_spent_check_target_delta(self):
    orders = range(2, 33)
    rdp = [1.1 for o in orders]  # Constant corresponds to pure DP.
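
Usage sketch (not part of the patch above): the snippet below shows how the deprecated wrappers kept by this change line up with the dp_accounting primitives they now delegate to. The parameter values are illustrative only, and the direct-API calls (GaussianDpEvent, PoissonSampledDpEvent, RdpAccountant.compose, get_epsilon_and_optimal_order) are assumed to behave exactly as they are used inside this diff.

from com_google_differential_py.python.dp_accounting import dp_event
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
from tensorflow_privacy.privacy.analysis import rdp_accountant

# Illustrative parameters: Poisson sampling rate, noise multiplier, number of
# steps, target delta, and a grid of RDP orders to optimize over.
q, noise_multiplier, steps, delta = 0.01, 1.1, 1000, 1e-5
orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))

# Path 1: the deprecated wrappers kept by this change (unchanged signatures).
rdp = rdp_accountant.compute_rdp(q, noise_multiplier, steps, orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=delta)

# Path 2: the accounting primitives the wrappers now call internally.
# Describe one step as a DpEvent, self-compose it `steps` times, then query.
event = dp_event.PoissonSampledDpEvent(q, dp_event.GaussianDpEvent(noise_multiplier))
accountant = rdp_privacy_accountant.RdpAccountant(orders)
accountant.compose(event, steps)
eps_direct, opt_order_direct = accountant.get_epsilon_and_optimal_order(delta)

# Both paths should report (approximately) the same epsilon at the same order.
print(eps, opt_order, eps_direct, opt_order_direct)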