Merge pull request #158 from jeremy43:improved_gaussian_subsample
PiperOrigin-RevId: 377344012
This commit is contained in:
commit
385fefc85e
2 changed files with 240 additions and 5 deletions
|
@ -47,6 +47,7 @@ import numpy as np
|
|||
from scipy import special
|
||||
import six
|
||||
|
||||
|
||||
########################
|
||||
# LOG-SPACE ARITHMETIC #
|
||||
########################
|
||||
|
@ -77,6 +78,21 @@ def _log_sub(logx, logy):
|
|||
return logx
|
||||
|
||||
|
||||
def _log_sub_sign(logx, logy):
|
||||
"""Returns log(exp(logx)-exp(logy)) and its sign."""
|
||||
if logx > logy:
|
||||
s = True
|
||||
mag = logx + np.log(1 - np.exp(logy - logx))
|
||||
elif logx < logy:
|
||||
s = False
|
||||
mag = logy + np.log(1 - np.exp(logx - logy))
|
||||
else:
|
||||
s = True
|
||||
mag = -np.inf
|
||||
|
||||
return s, mag
|
||||
|
||||
|
||||
def _log_print(logx):
|
||||
"""Pretty print."""
|
||||
if logx < math.log(sys.float_info.max):
|
||||
|
@ -259,7 +275,7 @@ def _compute_eps(orders, rdp, delta):
|
|||
# This bound is not numerically stable as alpha->1.
|
||||
# Thus we have a min value of alpha.
|
||||
# The bound is also not useful for small alpha, so doesn't matter.
|
||||
eps = r + math.log1p(-1/a) - math.log(delta * a) / (a - 1)
|
||||
eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1)
|
||||
else:
|
||||
# In this case we can't do anything. E.g., asking for delta = 0.
|
||||
eps = np.inf
|
||||
|
@ -269,6 +285,70 @@ def _compute_eps(orders, rdp, delta):
|
|||
return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
|
||||
|
||||
|
||||
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
||||
"""Replaces the first n-1 dims of vec with the log of abs difference operator.
|
||||
|
||||
Args:
|
||||
vec: numpy array of floats with size larger than 'n'
|
||||
signs: Optional numpy array of bools with the same size as vec in case one
|
||||
needs to compute partial differences vec and signs jointly describe a
|
||||
vector of real numbers' sign and abs in log scale.
|
||||
n: Optonal upper bound on number of differences to compute. If negative, all
|
||||
differences are computed.
|
||||
|
||||
Returns:
|
||||
The first n-1 dimension of vec and signs will store the log-abs and sign of
|
||||
the difference.
|
||||
|
||||
Raises:
|
||||
ValueError: If input is malformed.
|
||||
"""
|
||||
|
||||
assert vec.shape == signs.shape
|
||||
if n < 0:
|
||||
n = np.max(vec.shape) - 1
|
||||
else:
|
||||
assert np.max(vec.shape) >= n + 1
|
||||
for j in range(0, n, 1):
|
||||
if signs[j] == signs[j + 1]: # When the signs are the same
|
||||
# if the signs are both positive, then we can just use the standard one
|
||||
signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
|
||||
# otherwise, we do that but toggle the sign
|
||||
if not signs[j + 1]:
|
||||
signs[j] = ~signs[j]
|
||||
else: # When the signs are different.
|
||||
vec[j] = _log_add(vec[j], vec[j + 1])
|
||||
signs[j] = signs[j + 1]
|
||||
|
||||
|
||||
def _get_forward_diffs(fun, n):
|
||||
"""Computes up to nth order forward difference evaluated at 0.
|
||||
|
||||
See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
|
||||
|
||||
Args:
|
||||
fun: Function to compute forward differences of.
|
||||
n: Number of differences to compute.
|
||||
|
||||
Returns:
|
||||
Pair (deltas, signs_deltas) of the log deltas and their signs.
|
||||
"""
|
||||
func_vec = np.zeros(n + 3)
|
||||
signs_func_vec = np.ones(n + 3, dtype=bool)
|
||||
|
||||
# ith coordinate of deltas stores log(abs(ith order discrete derivative))
|
||||
deltas = np.zeros(n + 2)
|
||||
signs_deltas = np.zeros(n + 2, dtype=bool)
|
||||
for i in range(1, n + 3, 1):
|
||||
func_vec[i] = fun(1.0 * (i - 1))
|
||||
for i in range(0, n + 2, 1):
|
||||
# Diff in log scale
|
||||
_stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
|
||||
deltas[i] = func_vec[0]
|
||||
signs_deltas[i] = signs_func_vec[0]
|
||||
return deltas, signs_deltas
|
||||
|
||||
|
||||
def _compute_rdp(q, sigma, alpha):
|
||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||
|
||||
|
@ -314,6 +394,149 @@ def compute_rdp(q, noise_multiplier, steps, orders):
|
|||
return rdp * steps
|
||||
|
||||
|
||||
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
||||
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
||||
|
||||
This function applies to the following schemes:
|
||||
1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n.
|
||||
2. ``Replace one data point'' version of differential privacy, i.e., n is
|
||||
considered public information.
|
||||
|
||||
Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened
|
||||
version applies subsampled-Gaussian mechanism)
|
||||
- Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and
|
||||
Analytical Moments Accountant." AISTATS'2019.
|
||||
|
||||
Args:
|
||||
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
|
||||
to the l2-sensitivity of the function to which it is added.
|
||||
steps: The number of steps.
|
||||
orders: An array (or a scalar) of RDP orders.
|
||||
|
||||
Returns:
|
||||
The RDPs at all orders, can be np.inf.
|
||||
"""
|
||||
if np.isscalar(orders):
|
||||
rdp = _compute_rdp_sample_without_replacement_scalar(
|
||||
q, noise_multiplier, orders)
|
||||
else:
|
||||
rdp = np.array([
|
||||
_compute_rdp_sample_without_replacement_scalar(q, noise_multiplier,
|
||||
order)
|
||||
for order in orders
|
||||
])
|
||||
|
||||
return rdp * steps
|
||||
|
||||
|
||||
def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
|
||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||
|
||||
Args:
|
||||
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||
sigma: The std of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at alpha, can be np.inf.
|
||||
"""
|
||||
|
||||
assert (q <= 1) and (q >= 0) and (alpha >= 1)
|
||||
|
||||
if q == 0:
|
||||
return 0
|
||||
|
||||
if q == 1.:
|
||||
return alpha / (2 * sigma**2)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
|
||||
if float(alpha).is_integer():
|
||||
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (
|
||||
alpha - 1)
|
||||
else:
|
||||
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate
|
||||
# the CGF and obtain an upper bound
|
||||
alpha_f = math.floor(alpha)
|
||||
alpha_c = math.ceil(alpha)
|
||||
|
||||
x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f)
|
||||
y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c)
|
||||
t = alpha - alpha_f
|
||||
return ((1 - t) * x + t * y) / (alpha - 1)
|
||||
|
||||
|
||||
def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for integer alpha, subsampling without replacement.
|
||||
|
||||
When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly,
|
||||
otherwise compute the bound with Stirling approximation.
|
||||
|
||||
Args:
|
||||
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||
sigma: The std of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at alpha, can be np.inf.
|
||||
"""
|
||||
|
||||
max_alpha = 256
|
||||
assert isinstance(alpha, six.integer_types)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
elif alpha == 1:
|
||||
return 0
|
||||
|
||||
def cgf(x):
|
||||
# Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2)
|
||||
return x * 1.0 * (x + 1) / (2.0 * sigma**2)
|
||||
|
||||
def func(x):
|
||||
# Return the rdp of Gaussian mechanism
|
||||
return 1.0 * x / (2.0 * sigma**2)
|
||||
|
||||
# Initialize with 1 in the log space.
|
||||
log_a = 0
|
||||
# Calculates the log term when alpha = 2
|
||||
log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
|
||||
if alpha <= max_alpha:
|
||||
# We need forward differences of exp(cgf)
|
||||
# The following line is the numerically stable way of implementing it.
|
||||
# The output is in polar form with logarithmic magnitude
|
||||
deltas, _ = _get_forward_diffs(cgf, alpha)
|
||||
# Compute the bound exactly requires book keeping of O(alpha**2)
|
||||
|
||||
for i in range(2, alpha + 1):
|
||||
if i == 2:
|
||||
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
|
||||
np.log(4) + log_f2m1,
|
||||
func(2.0) + np.log(2))
|
||||
elif i > 2:
|
||||
delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
|
||||
delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
|
||||
s = np.log(4) + 0.5 * (delta_lo + delta_hi)
|
||||
s = np.minimum(s, np.log(2) + cgf(i - 1))
|
||||
s += i * np.log(q) + _log_comb(alpha, i)
|
||||
log_a = _log_add(log_a, s)
|
||||
return float(log_a)
|
||||
else:
|
||||
# Compute the bound with stirling approximation. Everything is O(x) now.
|
||||
for i in range(2, alpha + 1):
|
||||
if i == 2:
|
||||
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
|
||||
np.log(4) + log_f2m1,
|
||||
func(2.0) + np.log(2))
|
||||
else:
|
||||
s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
|
||||
log_a = _log_add(log_a, s)
|
||||
|
||||
return log_a
|
||||
|
||||
|
||||
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
|
||||
steps_list, orders):
|
||||
"""Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
|
||||
|
|
|
@ -29,12 +29,13 @@ from mpmath import log
|
|||
from mpmath import npdf
|
||||
from mpmath import quad
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import privacy_ledger
|
||||
from tensorflow_privacy.privacy.analysis import rdp_accountant
|
||||
|
||||
|
||||
class TestGaussianMoments(parameterized.TestCase):
|
||||
class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
|
||||
#################################
|
||||
# HELPER FUNCTIONS: #
|
||||
# Exact computations using #
|
||||
|
@ -102,12 +103,23 @@ class TestGaussianMoments(parameterized.TestCase):
|
|||
rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
|
||||
self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
|
||||
|
||||
def test_compute_rdp_sequence_without_replacement(self):
|
||||
rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(
|
||||
0.01, 2.5, 50, [1.001, 1.5, 2.5, 5, 50, 100, 256, 512, 1024, np.inf])
|
||||
self.assertAllClose(
|
||||
rdp_vec, [
|
||||
3.4701e-3, 3.4701e-3, 4.6386e-3, 8.7634e-3, 9.8474e-2, 1.6776e2,
|
||||
7.9297e2, 1.8174e3, 3.8656e3, np.inf
|
||||
],
|
||||
rtol=1e-4)
|
||||
|
||||
def test_compute_rdp_sequence(self):
|
||||
rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
|
||||
[1.5, 2.5, 5, 50, 100, np.inf])
|
||||
self.assertSequenceAlmostEqual(
|
||||
rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
|
||||
delta=1e-5)
|
||||
self.assertAllClose(
|
||||
rdp_vec,
|
||||
[6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, np.inf],
|
||||
rtol=1e-4)
|
||||
|
||||
params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
|
||||
{'q': 1e-6, 'sigma': .1, 'order': 256},
|
||||
|
|
Loading…
Reference in a new issue