add rdp for subsample without replacement
This commit is contained in:
parent
5524409cbd
commit
c0d3431eb2
3 changed files with 198 additions and 1 deletions
2
tensorflow_privacy/privacy/analysis/.gitignore
vendored
Normal file
2
tensorflow_privacy/privacy/analysis/.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
.idea
|
||||||
|
__pycache__
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""RDP analysis of the Sampled Gaussian Mechanism.
|
"""RDP analysis of the Sampled Gaussian Mechanism.
|
||||||
|
|
||||||
Functionality for computing Renyi differential privacy (RDP) of an additive
|
Functionality for computing Renyi differential privacy (RDP) of an additive
|
||||||
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
|
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
|
||||||
compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
|
compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
|
||||||
|
@ -46,6 +45,7 @@ import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
import six
|
import six
|
||||||
|
import rdp_utils
|
||||||
|
|
||||||
########################
|
########################
|
||||||
# LOG-SPACE ARITHMETIC #
|
# LOG-SPACE ARITHMETIC #
|
||||||
|
@ -76,6 +76,21 @@ def _log_sub(logx, logy):
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
return logx
|
return logx
|
||||||
|
|
||||||
|
def _log_sub_sign(logx, logy):
|
||||||
|
# ensure that x > y
|
||||||
|
# this function returns the stable version of log(exp(logx)-exp(logy)) if logx > logy
|
||||||
|
if logx > logy:
|
||||||
|
s = True
|
||||||
|
mag = logx + np.log(1 - np.exp(logy - logx))
|
||||||
|
elif logx < logy:
|
||||||
|
s = False
|
||||||
|
mag = logy + np.log(1 - np.exp(logx - logy))
|
||||||
|
else:
|
||||||
|
s = True
|
||||||
|
mag = -np.inf
|
||||||
|
|
||||||
|
return s, mag
|
||||||
|
|
||||||
|
|
||||||
def _log_print(logx):
|
def _log_print(logx):
|
||||||
"""Pretty print."""
|
"""Pretty print."""
|
||||||
|
@ -268,6 +283,57 @@ def _compute_eps(orders, rdp, delta):
|
||||||
idx_opt = np.argmin(eps_vec)
|
idx_opt = np.argmin(eps_vec)
|
||||||
return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
|
return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
|
||||||
|
|
||||||
|
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
||||||
|
|
||||||
|
""" This function replaces the first n-1 dimension of vec with the log of abs difference operator
|
||||||
|
Input:
|
||||||
|
- `vec` is a numpy array of floats with size larger than 'n'
|
||||||
|
- `signs` is a numpy array of bools with the same size as vec
|
||||||
|
- `n` is an optional argument in case one needs to compute partial differences
|
||||||
|
`vec` and `signs` jointly describe a vector of real numbers' sign and abs in log scale.
|
||||||
|
Output:
|
||||||
|
The first n-1 dimension of vec and signs will store the log-abs and sign of the difference.
|
||||||
|
"""
|
||||||
|
#
|
||||||
|
# And the first n-1 dimension of signs with the sign of the differences.
|
||||||
|
# the sign is assigned to True to break symmetry if the diff is 0
|
||||||
|
# Input:
|
||||||
|
assert (vec.shape == signs.shape)
|
||||||
|
if n < 0:
|
||||||
|
n = np.max(vec.shape) - 1
|
||||||
|
else:
|
||||||
|
assert (np.max(vec.shape) >= n + 1)
|
||||||
|
for j in range(0, n, 1):
|
||||||
|
if signs[j] == signs[j + 1]: # When the signs are the same
|
||||||
|
# if the signs are both positive, then we can just use the standard one
|
||||||
|
signs[j], vec[j] = _log_sub_sign(vec[j + 1],vec[j])
|
||||||
|
# otherwise, we do that but toggle the sign
|
||||||
|
if signs[j + 1] == False:
|
||||||
|
signs[j] = ~signs[j]
|
||||||
|
else: # When the signs are different.
|
||||||
|
vec[j] = _log_add(vec[j], vec[j + 1])
|
||||||
|
signs[j] = signs[j + 1]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_forward_diffs(fun, n):
|
||||||
|
"""
|
||||||
|
This is the key function for computing up to nth order forward difference evaluated at 0, used for Subsample Gaussian mechanism
|
||||||
|
See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
|
||||||
|
"""
|
||||||
|
# Pre-compute the finite difference operators
|
||||||
|
# Save them in log-scale
|
||||||
|
func_vec = np.zeros(n + 3)
|
||||||
|
signs_func_vec = np.ones(n + 3, dtype=bool)
|
||||||
|
deltas = np.zeros(n + 2) # ith coordinate of deltas stores log(abs(ith order discrete derivative))
|
||||||
|
signs_deltas = np.zeros(n + 2, dtype=bool)
|
||||||
|
for i in range(1, n + 3, 1):
|
||||||
|
func_vec[i] = fun(1.0 * (i - 1))
|
||||||
|
for i in range(0, n + 2, 1):
|
||||||
|
# Diff in log scale
|
||||||
|
_stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
|
||||||
|
deltas[i] = func_vec[0]
|
||||||
|
signs_deltas[i] = signs_func_vec[0]
|
||||||
|
return deltas, signs_deltas
|
||||||
|
|
||||||
def _compute_rdp(q, sigma, alpha):
|
def _compute_rdp(q, sigma, alpha):
|
||||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||||
|
@ -313,6 +379,128 @@ def compute_rdp(q, noise_multiplier, steps, orders):
|
||||||
|
|
||||||
return rdp * steps
|
return rdp * steps
|
||||||
|
|
||||||
|
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
||||||
|
|
||||||
|
"""Compute RDP of the Sampled Gaussian Mechanism using sampling without replacement.
|
||||||
|
This function applies to the following schemes:
|
||||||
|
1. Sampling without replacement: Sample a uniformly random subset of size m = q*n.
|
||||||
|
2. ``Replace one data point'' version of differential privacy, i.e., n is considered public
|
||||||
|
information.
|
||||||
|
Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened version applies subsampled-Gaussian mechanism)
|
||||||
|
- Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and Analytical Moments
|
||||||
|
Accountant." AISTATS'2019.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||||
|
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
|
||||||
|
to the l2-sensitivity of the function to which it is added.
|
||||||
|
steps: The number of steps.
|
||||||
|
orders: An array (or a scalar) of RDP orders.
|
||||||
|
Returns:
|
||||||
|
The RDPs at all orders, can be np.inf.
|
||||||
|
"""
|
||||||
|
if np.isscalar(orders):
|
||||||
|
rdp = _compute_rdp_sample_without_replacement_scalar(q, noise_multiplier, orders)
|
||||||
|
else:
|
||||||
|
rdp = np.array([_compute_rdp_sample_without_replacement_scalar(q, noise_multiplier, order)
|
||||||
|
for order in orders])
|
||||||
|
|
||||||
|
return rdp * steps
|
||||||
|
|
||||||
|
def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
|
||||||
|
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||||
|
Args:
|
||||||
|
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||||
|
sigma: The std of the additive Gaussian noise.
|
||||||
|
alpha: The order at which RDP is computed.
|
||||||
|
Returns:
|
||||||
|
RDP at alpha, can be np.inf.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert (q <= 1) and (q >= 0) and (alpha >= 1)
|
||||||
|
|
||||||
|
if q == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if q == 1.:
|
||||||
|
return alpha / (2 * sigma**2)
|
||||||
|
|
||||||
|
if np.isinf(alpha):
|
||||||
|
return np.inf
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if isinstance(alpha, six.integer_types):
|
||||||
|
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1)
|
||||||
|
else:
|
||||||
|
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the
|
||||||
|
# CGF and obtain an upper bound
|
||||||
|
alpha_f = math.floor(alpha)
|
||||||
|
alpha_c = math.ceil(alpha)
|
||||||
|
|
||||||
|
x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f)
|
||||||
|
y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c)
|
||||||
|
t = alpha - alpha_f
|
||||||
|
return ((1-t) * x + t * y) / (alpha-1)
|
||||||
|
|
||||||
|
def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||||
|
"""Compute log(A_alpha) for integer alpha. 0 < q < 1, under subsampling without replacement.
|
||||||
|
when alpha is smaller than max_alpha, compute the bound Theorem 27 exactly, else compute the bound with stirling approximation
|
||||||
|
Args:
|
||||||
|
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||||
|
sigma: The std of the additive Gaussian noise.
|
||||||
|
alpha: The order at which RDP is computed.
|
||||||
|
Returns:
|
||||||
|
RDP at alpha, can be np.inf.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_alpha = 100
|
||||||
|
assert isinstance(alpha, six.integer_types)
|
||||||
|
|
||||||
|
if np.isinf(alpha):
|
||||||
|
return np.inf
|
||||||
|
elif alpha==1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def cgf(x):
|
||||||
|
# Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2)
|
||||||
|
return x*1.0*(x+1)/(2.0*sigma**2)
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
# Return the rdp of Gaussian mechanism
|
||||||
|
return 1.0*(x)/(2.0*sigma**2)
|
||||||
|
|
||||||
|
# We need forward differences of exp(cgf)
|
||||||
|
# The following line is the numerically stable way of implementing it.
|
||||||
|
# The output is in polar form with logarithmic magnitude
|
||||||
|
deltas, signs_deltas = _get_forward_diffs(cgf, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize with 1 in the log space.
|
||||||
|
log_a = 0
|
||||||
|
if alpha <= max_alpha:
|
||||||
|
# Compute the bound exactly requires book keeping of O(alpha**2)
|
||||||
|
|
||||||
|
for i in range(2, alpha+1):
|
||||||
|
if i == 2:
|
||||||
|
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))),func(2.0) + np.log(2))
|
||||||
|
elif i > 2:
|
||||||
|
s = np.minimum(np.log(4) + 0.5*deltas[int(2*np.floor(i/2.0))-1]+ 0.5*deltas[int(2*np.ceil(i/2.0))-1],np.log(2)+ cgf(i - 1)) \
|
||||||
|
+ i * np.log(q) +_log_comb(alpha, i)
|
||||||
|
log_a = _log_add(log_a,s)
|
||||||
|
return float(log_a)
|
||||||
|
else:
|
||||||
|
# Compute the bound with stirling approximation. Everything is O(x) now.
|
||||||
|
for i in range(2, alpha + 1):
|
||||||
|
if i == 2:
|
||||||
|
s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum(
|
||||||
|
np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))), func(2.0) + np.log(2))
|
||||||
|
else:
|
||||||
|
s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i)
|
||||||
|
log_a = _log_add(log_a, s)
|
||||||
|
|
||||||
|
return log_a
|
||||||
|
|
||||||
|
|
||||||
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
|
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
|
||||||
steps_list, orders):
|
steps_list, orders):
|
||||||
|
|
|
@ -102,6 +102,13 @@ class TestGaussianMoments(parameterized.TestCase):
|
||||||
rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
|
rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
|
||||||
self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
|
self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
|
||||||
|
|
||||||
|
def test_compute_rdp_sequence_without_replacement(self):
|
||||||
|
rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50,
|
||||||
|
[1.001, 1.5, 2.5, 5, 50, 100, np.inf])
|
||||||
|
self.assertSequenceAlmostEqual(
|
||||||
|
rdp_vec, [0.003470,0.003470, 0.004638, 0.0087633, 0.09847, 167.766388, np.inf],
|
||||||
|
delta=1e-5)
|
||||||
|
|
||||||
def test_compute_rdp_sequence(self):
|
def test_compute_rdp_sequence(self):
|
||||||
rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
|
rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
|
||||||
[1.5, 2.5, 5, 50, 100, np.inf])
|
[1.5, 2.5, 5, 50, 100, np.inf])
|
||||||
|
|
Loading…
Reference in a new issue