From c0d3431eb245b5e0f2dfd5e8dc1559e7b0c5be3c Mon Sep 17 00:00:00 2001 From: Yuqing Date: Fri, 12 Mar 2021 13:56:52 -0800 Subject: [PATCH 1/4] add rdp for subsample without replacement --- .../privacy/analysis/.gitignore | 2 + .../privacy/analysis/rdp_accountant.py | 190 +++++++++++++++++- .../privacy/analysis/rdp_accountant_test.py | 7 + 3 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 tensorflow_privacy/privacy/analysis/.gitignore diff --git a/tensorflow_privacy/privacy/analysis/.gitignore b/tensorflow_privacy/privacy/analysis/.gitignore new file mode 100644 index 0000000..a47fb1d --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/.gitignore @@ -0,0 +1,2 @@ +.idea +__pycache__ \ No newline at end of file diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index 59e33c0..4059532 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """RDP analysis of the Sampled Gaussian Mechanism. - Functionality for computing Renyi differential privacy (RDP) of an additive Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods: compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated @@ -46,6 +45,7 @@ import sys import numpy as np from scipy import special import six +import rdp_utils ######################## # LOG-SPACE ARITHMETIC # @@ -76,6 +76,21 @@ def _log_sub(logx, logy): except OverflowError: return logx +def _log_sub_sign(logx, logy): + # ensure that x > y + # this function returns the stable version of log(exp(logx)-exp(logy)) if logx > logy + if logx > logy: + s = True + mag = logx + np.log(1 - np.exp(logy - logx)) + elif logx < logy: + s = False + mag = logy + np.log(1 - np.exp(logx - logy)) + else: + s = True + mag = -np.inf + + return s, mag + def _log_print(logx): """Pretty print.""" @@ -268,6 +283,57 @@ def _compute_eps(orders, rdp, delta): idx_opt = np.argmin(eps_vec) return max(0, eps_vec[idx_opt]), orders_vec[idx_opt] +def _stable_inplace_diff_in_log(vec, signs, n=-1): + + """ This function replaces the first n-1 dimension of vec with the log of abs difference operator + Input: + - `vec` is a numpy array of floats with size larger than 'n' + - `signs` is a numpy array of bools with the same size as vec + - `n` is an optional argument in case one needs to compute partial differences + `vec` and `signs` jointly describe a vector of real numbers' sign and abs in log scale. + Output: + The first n-1 dimension of vec and signs will store the log-abs and sign of the difference. + """ + # + # And the first n-1 dimension of signs with the sign of the differences. + # the sign is assigned to True to break symmetry if the diff is 0 + # Input: + assert (vec.shape == signs.shape) + if n < 0: + n = np.max(vec.shape) - 1 + else: + assert (np.max(vec.shape) >= n + 1) + for j in range(0, n, 1): + if signs[j] == signs[j + 1]: # When the signs are the same + # if the signs are both positive, then we can just use the standard one + signs[j], vec[j] = _log_sub_sign(vec[j + 1],vec[j]) + # otherwise, we do that but toggle the sign + if signs[j + 1] == False: + signs[j] = ~signs[j] + else: # When the signs are different. + vec[j] = _log_add(vec[j], vec[j + 1]) + signs[j] = signs[j + 1] + + +def _get_forward_diffs(fun, n): + """ + This is the key function for computing up to nth order forward difference evaluated at 0, used for Subsample Gaussian mechanism + See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf + """ + # Pre-compute the finite difference operators + # Save them in log-scale + func_vec = np.zeros(n + 3) + signs_func_vec = np.ones(n + 3, dtype=bool) + deltas = np.zeros(n + 2) # ith coordinate of deltas stores log(abs(ith order discrete derivative)) + signs_deltas = np.zeros(n + 2, dtype=bool) + for i in range(1, n + 3, 1): + func_vec[i] = fun(1.0 * (i - 1)) + for i in range(0, n + 2, 1): + # Diff in log scale + _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i) + deltas[i] = func_vec[0] + signs_deltas[i] = signs_func_vec[0] + return deltas, signs_deltas def _compute_rdp(q, sigma, alpha): """Compute RDP of the Sampled Gaussian mechanism at order alpha. @@ -313,6 +379,128 @@ def compute_rdp(q, noise_multiplier, steps, orders): return rdp * steps +def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders): + + """Compute RDP of the Sampled Gaussian Mechanism using sampling without replacement. + This function applies to the following schemes: + 1. Sampling without replacement: Sample a uniformly random subset of size m = q*n. + 2. ``Replace one data point'' version of differential privacy, i.e., n is considered public + information. + Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened version applies subsampled-Gaussian mechanism) + - Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and Analytical Moments + Accountant." AISTATS'2019. + + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + noise_multiplier: The ratio of the standard deviation of the Gaussian noise + to the l2-sensitivity of the function to which it is added. + steps: The number of steps. + orders: An array (or a scalar) of RDP orders. + Returns: + The RDPs at all orders, can be np.inf. + """ + if np.isscalar(orders): + rdp = _compute_rdp_sample_without_replacement_scalar(q, noise_multiplier, orders) + else: + rdp = np.array([_compute_rdp_sample_without_replacement_scalar(q, noise_multiplier, order) + for order in orders]) + + return rdp * steps + +def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha): + """Compute RDP of the Sampled Gaussian mechanism at order alpha. + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + sigma: The std of the additive Gaussian noise. + alpha: The order at which RDP is computed. + Returns: + RDP at alpha, can be np.inf. + """ + + assert (q <= 1) and (q >= 0) and (alpha >= 1) + + if q == 0: + return 0 + + if q == 1.: + return alpha / (2 * sigma**2) + + if np.isinf(alpha): + return np.inf + + + + if isinstance(alpha, six.integer_types): + return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1) + else: + # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the + # CGF and obtain an upper bound + alpha_f = math.floor(alpha) + alpha_c = math.ceil(alpha) + + x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f) + y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c) + t = alpha - alpha_f + return ((1-t) * x + t * y) / (alpha-1) + +def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): + """Compute log(A_alpha) for integer alpha. 0 < q < 1, under subsampling without replacement. + when alpha is smaller than max_alpha, compute the bound Theorem 27 exactly, else compute the bound with stirling approximation + Args: + q: The sampling proportion = m / n. Assume m is an integer <= n. + sigma: The std of the additive Gaussian noise. + alpha: The order at which RDP is computed. + Returns: + RDP at alpha, can be np.inf. + """ + + max_alpha = 100 + assert isinstance(alpha, six.integer_types) + + if np.isinf(alpha): + return np.inf + elif alpha==1: + return 0 + + def cgf(x): + # Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2) + return x*1.0*(x+1)/(2.0*sigma**2) + + def func(x): + # Return the rdp of Gaussian mechanism + return 1.0*(x)/(2.0*sigma**2) + + # We need forward differences of exp(cgf) + # The following line is the numerically stable way of implementing it. + # The output is in polar form with logarithmic magnitude + deltas, signs_deltas = _get_forward_diffs(cgf, alpha) + + + # Initialize with 1 in the log space. + log_a = 0 + if alpha <= max_alpha: + # Compute the bound exactly requires book keeping of O(alpha**2) + + for i in range(2, alpha+1): + if i == 2: + s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))),func(2.0) + np.log(2)) + elif i > 2: + s = np.minimum(np.log(4) + 0.5*deltas[int(2*np.floor(i/2.0))-1]+ 0.5*deltas[int(2*np.ceil(i/2.0))-1],np.log(2)+ cgf(i - 1)) \ + + i * np.log(q) +_log_comb(alpha, i) + log_a = _log_add(log_a,s) + return float(log_a) + else: + # Compute the bound with stirling approximation. Everything is O(x) now. + for i in range(2, alpha + 1): + if i == 2: + s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum( + np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))), func(2.0) + np.log(2)) + else: + s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i) + log_a = _log_add(log_a, s) + + return log_a + def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers, steps_list, orders): diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py index eda62bb..ee8278b 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py @@ -102,6 +102,13 @@ class TestGaussianMoments(parameterized.TestCase): rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5) self.assertAlmostEqual(rdp_scalar, 0.07737, places=5) + def test_compute_rdp_sequence_without_replacement(self): + rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50, + [1.001, 1.5, 2.5, 5, 50, 100, np.inf]) + self.assertSequenceAlmostEqual( + rdp_vec, [0.003470,0.003470, 0.004638, 0.0087633, 0.09847, 167.766388, np.inf], + delta=1e-5) + def test_compute_rdp_sequence(self): rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50, [1.5, 2.5, 5, 50, 100, np.inf]) From 736520b0eb7d162509b142bbc7e86de4bd049886 Mon Sep 17 00:00:00 2001 From: Yuqing Date: Fri, 12 Mar 2021 14:00:53 -0800 Subject: [PATCH 2/4] remove unnecessary files --- tensorflow_privacy/privacy/analysis/rdp_accountant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index 4059532..721f105 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -45,7 +45,7 @@ import sys import numpy as np from scipy import special import six -import rdp_utils + ######################## # LOG-SPACE ARITHMETIC # From 09270afed61a4e9ef6295217b53731f7515f0350 Mon Sep 17 00:00:00 2001 From: Yuqing Date: Fri, 7 May 2021 00:16:59 -0700 Subject: [PATCH 3/4] Resolve comments and add more tests --- .../privacy/analysis/rdp_accountant.py | 58 ++++++++++--------- .../privacy/analysis/rdp_accountant_test.py | 4 +- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index 721f105..6889ad5 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -77,8 +77,7 @@ def _log_sub(logx, logy): return logx def _log_sub_sign(logx, logy): - # ensure that x > y - # this function returns the stable version of log(exp(logx)-exp(logy)) if logx > logy + """Returns log(exp(logx)-exp(logy)) and its sign.""" if logx > logy: s = True mag = logx + np.log(1 - np.exp(logy - logx)) @@ -286,18 +285,19 @@ def _compute_eps(orders, rdp, delta): def _stable_inplace_diff_in_log(vec, signs, n=-1): """ This function replaces the first n-1 dimension of vec with the log of abs difference operator - Input: - - `vec` is a numpy array of floats with size larger than 'n' - - `signs` is a numpy array of bools with the same size as vec - - `n` is an optional argument in case one needs to compute partial differences - `vec` and `signs` jointly describe a vector of real numbers' sign and abs in log scale. - Output: - The first n-1 dimension of vec and signs will store the log-abs and sign of the difference. - """ - # - # And the first n-1 dimension of signs with the sign of the differences. - # the sign is assigned to True to break symmetry if the diff is 0 - # Input: + + Args: + vec: is a numpy array of floats with size larger than 'n' + signs: is a numpy array of bools with the same size as vec is an optional argument in case one needs to compute partial differences + vec and signs jointly describe a vector of real numbers' sign and abs in log scale. + + Returns: + The first n-1 dimension of vec and signs will store the log-abs and sign of the difference. + + Raises: + ValueError: If input is malformed. + """ + assert (vec.shape == signs.shape) if n < 0: n = np.max(vec.shape) - 1 @@ -306,7 +306,7 @@ def _stable_inplace_diff_in_log(vec, signs, n=-1): for j in range(0, n, 1): if signs[j] == signs[j + 1]: # When the signs are the same # if the signs are both positive, then we can just use the standard one - signs[j], vec[j] = _log_sub_sign(vec[j + 1],vec[j]) + signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j]) # otherwise, we do that but toggle the sign if signs[j + 1] == False: signs[j] = ~signs[j] @@ -428,9 +428,7 @@ def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha): if np.isinf(alpha): return np.inf - - - if isinstance(alpha, six.integer_types): + if float(alpha).is_integer(): return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1) else: # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the @@ -454,7 +452,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): RDP at alpha, can be np.inf. """ - max_alpha = 100 + max_alpha = 256 assert isinstance(alpha, six.integer_types) if np.isinf(alpha): @@ -470,23 +468,28 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): # Return the rdp of Gaussian mechanism return 1.0*(x)/(2.0*sigma**2) - # We need forward differences of exp(cgf) - # The following line is the numerically stable way of implementing it. - # The output is in polar form with logarithmic magnitude - deltas, signs_deltas = _get_forward_diffs(cgf, alpha) # Initialize with 1 in the log space. log_a = 0 + # Calculates the log term when alpha = 2 + log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0))) if alpha <= max_alpha: + # We need forward differences of exp(cgf) + # The following line is the numerically stable way of implementing it. + # The output is in polar form with logarithmic magnitude + deltas, signs_deltas = _get_forward_diffs(cgf, alpha) # Compute the bound exactly requires book keeping of O(alpha**2) for i in range(2, alpha+1): if i == 2: - s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))),func(2.0) + np.log(2)) + s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + log_f2m1, func(2.0) + np.log(2)) elif i > 2: - s = np.minimum(np.log(4) + 0.5*deltas[int(2*np.floor(i/2.0))-1]+ 0.5*deltas[int(2*np.ceil(i/2.0))-1],np.log(2)+ cgf(i - 1)) \ - + i * np.log(q) +_log_comb(alpha, i) + delta_lo = deltas[int(2*np.floor(i/2.0))-1] + delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1] + s = np.log(4) + 0.5 * (delta_lo + delta_hi) + s = np.minimum(s, np.log(2) + cgf(i - 1)) + s += i * np.log(q) + _log_comb(alpha, i) log_a = _log_add(log_a,s) return float(log_a) else: @@ -494,7 +497,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): for i in range(2, alpha + 1): if i == 2: s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum( - np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))), func(2.0) + np.log(2)) + np.log(4) + log_f2m1, func(2.0) + np.log(2)) else: s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i) log_a = _log_add(log_a, s) @@ -580,3 +583,4 @@ def compute_rdp_from_ledger(ledger, orders): total_rdp += compute_rdp( sample.selection_probability, effective_z, 1, orders) return total_rdp + diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py index ee8278b..6e2e9d2 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py @@ -104,9 +104,9 @@ class TestGaussianMoments(parameterized.TestCase): def test_compute_rdp_sequence_without_replacement(self): rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50, - [1.001, 1.5, 2.5, 5, 50, 100, np.inf]) + [1.001, 1.5, 2.5, 5, 50, 100, 256, 512, 1024, np.inf]) self.assertSequenceAlmostEqual( - rdp_vec, [0.003470,0.003470, 0.004638, 0.0087633, 0.09847, 167.766388, np.inf], + rdp_vec, [0.003470, 0.003470, 0.004638, 0.0087633, 0.098474, 167.766388, 792.838516, 1817.35871, 3865.55029, np.inf], delta=1e-5) def test_compute_rdp_sequence(self): From 9d133767076116447fdb6b364625aeea67a9301a Mon Sep 17 00:00:00 2001 From: Yuqing Date: Tue, 11 May 2021 00:19:52 -0700 Subject: [PATCH 4/4] resolve space issues --- .../privacy/analysis/rdp_accountant.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py index 6889ad5..03917bb 100644 --- a/tensorflow_privacy/privacy/analysis/rdp_accountant.py +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -273,7 +273,7 @@ def _compute_eps(orders, rdp, delta): # This bound is not numerically stable as alpha->1. # Thus we have a min value of alpha. # The bound is also not useful for small alpha, so doesn't matter. - eps = r + math.log1p(-1/a) - math.log(delta * a) / (a - 1) + eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) else: # In this case we can't do anything. E.g., asking for delta = 0. eps = np.inf @@ -330,7 +330,7 @@ def _get_forward_diffs(fun, n): func_vec[i] = fun(1.0 * (i - 1)) for i in range(0, n + 2, 1): # Diff in log scale - _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i) + _stable_inplace_diff_in_log(func_vec, signs_func_vec, n = n + 2 - i) deltas[i] = func_vec[0] signs_deltas[i] = signs_func_vec[0] return deltas, signs_deltas @@ -462,11 +462,11 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): def cgf(x): # Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2) - return x*1.0*(x+1)/(2.0*sigma**2) + return x * 1.0 * (x + 1) / (2.0 * sigma**2) def func(x): # Return the rdp of Gaussian mechanism - return 1.0*(x)/(2.0*sigma**2) + return 1.0 * x / (2.0 * sigma**2) @@ -485,7 +485,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): if i == 2: s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + log_f2m1, func(2.0) + np.log(2)) elif i > 2: - delta_lo = deltas[int(2*np.floor(i/2.0))-1] + delta_lo = deltas[int(2 * np.floor( i / 2.0))-1] delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1] s = np.log(4) + 0.5 * (delta_lo + delta_hi) s = np.minimum(s, np.log(2) + cgf(i - 1)) @@ -496,10 +496,10 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha): # Compute the bound with stirling approximation. Everything is O(x) now. for i in range(2, alpha + 1): if i == 2: - s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum( + s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum( np.log(4) + log_f2m1, func(2.0) + np.log(2)) else: - s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i) + s = np.log(2) + cgf(i - 1) + i*np.log(q) + _log_comb(alpha, i) log_a = _log_add(log_a, s) return log_a