Merge pull request #158 from jeremy43:improved_gaussian_subsample

PiperOrigin-RevId: 377344012
A. Unique TensorFlower 2021-06-03 12:13:28 -07:00
commit 385fefc85e
2 changed files with 240 additions and 5 deletions


@@ -47,6 +47,7 @@ import numpy as np
from scipy import special
import six

########################
# LOG-SPACE ARITHMETIC #
########################
@@ -77,6 +78,21 @@ def _log_sub(logx, logy):
  return logx

def _log_sub_sign(logx, logy):
  """Returns log(exp(logx)-exp(logy)) and its sign."""
  if logx > logy:
    s = True
    mag = logx + np.log(1 - np.exp(logy - logx))
  elif logx < logy:
    s = False
    mag = logy + np.log(1 - np.exp(logx - logy))
  else:
    s = True
    mag = -np.inf
  return s, mag
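
For intuition, a minimal illustrative sketch of the helper's behavior (not part of the change; it assumes the module is imported as rdp_accountant, as in the tests):

import numpy as np
from tensorflow_privacy.privacy.analysis import rdp_accountant

# Log-space subtraction of 5 from 3: the result is negative with magnitude 2.
sign, mag = rdp_accountant._log_sub_sign(np.log(3.0), np.log(5.0))
# sign is False (3 - 5 < 0) and np.exp(mag) is approximately 2.0.
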
def _log_print(logx):
  """Pretty print."""
  if logx < math.log(sys.float_info.max):

@@ -269,6 +285,70 @@ def _compute_eps(orders, rdp, delta):
  return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]

def _stable_inplace_diff_in_log(vec, signs, n=-1):
  """Replaces the first n entries of vec with log-space forward differences.

  Args:
    vec: numpy array of floats with size larger than 'n'.
    signs: Optional numpy array of bools with the same size as vec, in case one
      needs to compute partial differences. vec and signs jointly describe a
      vector of real numbers' signs and magnitudes in log scale.
    n: Optional upper bound on the number of differences to compute. If
      negative, all differences are computed.

  Returns:
    The first n entries of vec and signs store the log-abs values and signs of
    the differences.

  Raises:
    ValueError: If the input is malformed.
  """
  assert vec.shape == signs.shape
  if n < 0:
    n = np.max(vec.shape) - 1
  else:
    assert np.max(vec.shape) >= n + 1
  for j in range(0, n, 1):
    if signs[j] == signs[j + 1]:  # When the signs are the same.
      # If the signs are both positive, we can use standard log-space
      # subtraction directly.
      signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
      # Otherwise we do the same but toggle the resulting sign.
      if not signs[j + 1]:
        signs[j] = ~signs[j]
    else:  # When the signs are different.
      vec[j] = _log_add(vec[j], vec[j + 1])
      signs[j] = signs[j + 1]
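
A minimal illustrative sketch of the in-place behavior (toy values, again assuming the module is imported as rdp_accountant):

import numpy as np
from tensorflow_privacy.privacy.analysis import rdp_accountant

# One pass turns log([1, 2, 4]) into the log of the first forward differences.
vec = np.log(np.array([1., 2., 4.]))
signs = np.ones(3, dtype=bool)
rdp_accountant._stable_inplace_diff_in_log(vec, signs)
# np.exp(vec[:2]) is now approximately [1., 2.], i.e. [2 - 1, 4 - 2], and
# signs[:2] remain True because both differences are positive.
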
def _get_forward_diffs(fun, n):
  """Computes up to nth order forward difference evaluated at 0.

  See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf.

  Args:
    fun: Function to compute forward differences of.
    n: Number of differences to compute.

  Returns:
    Pair (deltas, signs_deltas) of the log deltas and their signs.
  """
  func_vec = np.zeros(n + 3)
  signs_func_vec = np.ones(n + 3, dtype=bool)

  # The ith coordinate of deltas stores log(abs(ith order discrete derivative)).
  deltas = np.zeros(n + 2)
  signs_deltas = np.zeros(n + 2, dtype=bool)
  for i in range(1, n + 3, 1):
    func_vec[i] = fun(1.0 * (i - 1))
  for i in range(0, n + 2, 1):
    # Take the difference in log scale.
    _stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
    deltas[i] = func_vec[0]
    signs_deltas[i] = signs_func_vec[0]
  return deltas, signs_deltas
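
As an illustrative sanity check (toy function and tolerances are arbitrary choices), the log-space output can be compared against a direct np.diff computation, under the reading that the k-th entry of deltas is the log of the (k + 1)-th order forward difference of the sequence [1, exp(fun(0)), ..., exp(fun(n + 1))] taken at its left end:

import numpy as np
from tensorflow_privacy.privacy.analysis import rdp_accountant

fun = lambda x: x * (x + 1) / 8.0  # CGF-shaped toy function (sigma = 2)
n = 4
deltas, signs_deltas = rdp_accountant._get_forward_diffs(fun, n)
seq = np.exp([0.] + [fun(x) for x in range(n + 2)])
for k in range(n + 2):
  ref = np.diff(seq, k + 1)[0]
  got = np.exp(deltas[k]) * (1.0 if signs_deltas[k] else -1.0)
  np.testing.assert_allclose(got, ref, rtol=1e-6, atol=1e-12)
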
def _compute_rdp(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.

@@ -314,6 +394,149 @@ def compute_rdp(q, noise_multiplier, steps, orders):
  return rdp * steps

def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
  """Compute RDP of Gaussian Mechanism using sampling without replacement.

  This function applies to the following schemes:
  1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n.
  2. ``Replace one data point'' version of differential privacy, i.e., n is
     considered public information.

  Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (a strengthened
  version of the bound applies to the subsampled Gaussian mechanism):
  - Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and
    Analytical Moments Accountant." AISTATS'2019.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    noise_multiplier: The ratio of the standard deviation of the Gaussian noise
      to the l2-sensitivity of the function to which it is added.
    steps: The number of steps.
    orders: An array (or a scalar) of RDP orders.

  Returns:
    The RDPs at all orders; can be np.inf.
  """
  if np.isscalar(orders):
    rdp = _compute_rdp_sample_without_replacement_scalar(
        q, noise_multiplier, orders)
  else:
    rdp = np.array([
        _compute_rdp_sample_without_replacement_scalar(q, noise_multiplier,
                                                       order)
        for order in orders
    ])

  return rdp * steps
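
A usage sketch (parameter values are illustrative; get_privacy_spent refers to the accountant's existing RDP-to-(epsilon, delta) conversion helper):

from tensorflow_privacy.privacy.analysis import rdp_accountant

orders = [1.5, 2.5, 5, 50, 100]
rdp = rdp_accountant.compute_rdp_sample_without_replacement(
    q=0.01, noise_multiplier=2.5, steps=50, orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(
    orders, rdp, target_delta=1e-5)
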
def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha; can be np.inf.
  """
  assert (q <= 1) and (q >= 0) and (alpha >= 1)

  if q == 0:
    return 0

  if q == 1.:
    return alpha / (2 * sigma**2)

  if np.isinf(alpha):
    return np.inf

  if float(alpha).is_integer():
    return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (
        alpha - 1)
  else:
    # When alpha is not an integer, apply Corollary 10 of [WBK19] to
    # interpolate the CGF and obtain an upper bound.
    alpha_f = math.floor(alpha)
    alpha_c = math.ceil(alpha)

    x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f)
    y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c)
    t = alpha - alpha_f
    return ((1 - t) * x + t * y) / (alpha - 1)
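
For concreteness, a worked sketch of the interpolation step (illustrative values; at these settings the per-step RDP is about 4.6386e-3 / 50, consistent with the new test added below):

import math
from tensorflow_privacy.privacy.analysis import rdp_accountant

q, sigma, alpha = 0.01, 2.5, 2.5  # non-integer order
x = rdp_accountant._compute_rdp_sample_without_replacement_int(q, sigma, 2)
y = rdp_accountant._compute_rdp_sample_without_replacement_int(q, sigma, 3)
t = alpha - math.floor(alpha)  # 0.5
rdp_per_step = ((1 - t) * x + t * y) / (alpha - 1)
# Identical to _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha).
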
def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
  """Compute log(A_alpha) for integer alpha, subsampling without replacement.

  When alpha is smaller than max_alpha, compute the bound of Theorem 27
  exactly; otherwise, compute the bound with the Stirling approximation.

  Args:
    q: The sampling proportion = m / n. Assume m is an integer <= n.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.

  Returns:
    RDP at alpha; can be np.inf.
  """
  max_alpha = 256
  assert isinstance(alpha, six.integer_types)

  if np.isinf(alpha):
    return np.inf
  elif alpha == 1:
    return 0

  def cgf(x):
    # Returns rdp(x + 1) * x; the RDP of the Gaussian mechanism at order alpha
    # is alpha / (2 * sigma**2).
    return x * 1.0 * (x + 1) / (2.0 * sigma**2)

  def func(x):
    # Returns the RDP of the Gaussian mechanism at order x.
    return 1.0 * x / (2.0 * sigma**2)

  # Initialize with 1 in the log space.
  log_a = 0
  # Calculates the log term when alpha = 2.
  log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
  if alpha <= max_alpha:
    # We need forward differences of exp(cgf).
    # The following line is the numerically stable way of implementing it.
    # The output is in polar form with logarithmic magnitude.
    deltas, _ = _get_forward_diffs(cgf, alpha)
    # Computing the bound exactly requires O(alpha**2) bookkeeping.
    for i in range(2, alpha + 1):
      if i == 2:
        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
            np.log(4) + log_f2m1,
            func(2.0) + np.log(2))
      elif i > 2:
        delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
        delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
        s = np.log(4) + 0.5 * (delta_lo + delta_hi)
        s = np.minimum(s, np.log(2) + cgf(i - 1))
        s += i * np.log(q) + _log_comb(alpha, i)
      log_a = _log_add(log_a, s)
    return float(log_a)
  else:
    # Compute the bound with the Stirling approximation. Everything is O(alpha)
    # now.
    for i in range(2, alpha + 1):
      if i == 2:
        s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
            np.log(4) + log_f2m1,
            func(2.0) + np.log(2))
      else:
        s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
      log_a = _log_add(log_a, s)
    return log_a
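
A quick spot check of the public wrapper (illustrative; the expected value is the one asserted by the new unit test added in this change):

from tensorflow_privacy.privacy.analysis import rdp_accountant

rdp = rdp_accountant.compute_rdp_sample_without_replacement(
    q=0.01, noise_multiplier=2.5, steps=50, orders=5)
# rdp is approximately 8.7634e-3 (order 5, 50 steps), matching
# test_compute_rdp_sequence_without_replacement below.
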
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
                             steps_list, orders):
  """Computes RDP of Heterogeneous Applications of Sampled Gaussian Mechanisms.


@@ -29,12 +29,13 @@ from mpmath import log
from mpmath import npdf
from mpmath import quad
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis import rdp_accountant


-class TestGaussianMoments(parameterized.TestCase):
+class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
  #################################
  # HELPER FUNCTIONS: #
  # Exact computations using #
@@ -102,12 +103,23 @@ class TestGaussianMoments(parameterized.TestCase):
    rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
    self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)

  def test_compute_rdp_sequence_without_replacement(self):
    rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(
        0.01, 2.5, 50, [1.001, 1.5, 2.5, 5, 50, 100, 256, 512, 1024, np.inf])
    self.assertAllClose(
        rdp_vec, [
            3.4701e-3, 3.4701e-3, 4.6386e-3, 8.7634e-3, 9.8474e-2, 1.6776e2,
            7.9297e2, 1.8174e3, 3.8656e3, np.inf
        ],
        rtol=1e-4)

  def test_compute_rdp_sequence(self):
    rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
                                         [1.5, 2.5, 5, 50, 100, np.inf])
-    self.assertSequenceAlmostEqual(
-        rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
-        delta=1e-5)
+    self.assertAllClose(
+        rdp_vec,
+        [6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, np.inf],
+        rtol=1e-4)

  params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
            {'q': 1e-6, 'sigma': .1, 'order': 256},