forked from 626_privacy/tensorflow_privacy
Resolve comments and add more tests
This commit is contained in:
parent
736520b0eb
commit
09270afed6
2 changed files with 33 additions and 29 deletions
|
@ -77,8 +77,7 @@ def _log_sub(logx, logy):
|
||||||
return logx
|
return logx
|
||||||
|
|
||||||
def _log_sub_sign(logx, logy):
|
def _log_sub_sign(logx, logy):
|
||||||
# ensure that x > y
|
"""Returns log(exp(logx)-exp(logy)) and its sign."""
|
||||||
# this function returns the stable version of log(exp(logx)-exp(logy)) if logx > logy
|
|
||||||
if logx > logy:
|
if logx > logy:
|
||||||
s = True
|
s = True
|
||||||
mag = logx + np.log(1 - np.exp(logy - logx))
|
mag = logx + np.log(1 - np.exp(logy - logx))
|
||||||
|
@ -286,18 +285,19 @@ def _compute_eps(orders, rdp, delta):
|
||||||
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
||||||
|
|
||||||
""" This function replaces the first n-1 dimension of vec with the log of abs difference operator
|
""" This function replaces the first n-1 dimension of vec with the log of abs difference operator
|
||||||
Input:
|
|
||||||
- `vec` is a numpy array of floats with size larger than 'n'
|
Args:
|
||||||
- `signs` is a numpy array of bools with the same size as vec
|
vec: is a numpy array of floats with size larger than 'n'
|
||||||
- `n` is an optional argument in case one needs to compute partial differences
|
signs: is a numpy array of bools with the same size as vec is an optional argument in case one needs to compute partial differences
|
||||||
`vec` and `signs` jointly describe a vector of real numbers' sign and abs in log scale.
|
vec and signs jointly describe a vector of real numbers' sign and abs in log scale.
|
||||||
Output:
|
|
||||||
|
Returns:
|
||||||
The first n-1 dimension of vec and signs will store the log-abs and sign of the difference.
|
The first n-1 dimension of vec and signs will store the log-abs and sign of the difference.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input is malformed.
|
||||||
"""
|
"""
|
||||||
#
|
|
||||||
# And the first n-1 dimension of signs with the sign of the differences.
|
|
||||||
# the sign is assigned to True to break symmetry if the diff is 0
|
|
||||||
# Input:
|
|
||||||
assert (vec.shape == signs.shape)
|
assert (vec.shape == signs.shape)
|
||||||
if n < 0:
|
if n < 0:
|
||||||
n = np.max(vec.shape) - 1
|
n = np.max(vec.shape) - 1
|
||||||
|
@ -428,9 +428,7 @@ def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
|
||||||
if np.isinf(alpha):
|
if np.isinf(alpha):
|
||||||
return np.inf
|
return np.inf
|
||||||
|
|
||||||
|
if float(alpha).is_integer():
|
||||||
|
|
||||||
if isinstance(alpha, six.integer_types):
|
|
||||||
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1)
|
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1)
|
||||||
else:
|
else:
|
||||||
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the
|
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the
|
||||||
|
@ -454,7 +452,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||||
RDP at alpha, can be np.inf.
|
RDP at alpha, can be np.inf.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_alpha = 100
|
max_alpha = 256
|
||||||
assert isinstance(alpha, six.integer_types)
|
assert isinstance(alpha, six.integer_types)
|
||||||
|
|
||||||
if np.isinf(alpha):
|
if np.isinf(alpha):
|
||||||
|
@ -470,23 +468,28 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||||
# Return the rdp of Gaussian mechanism
|
# Return the rdp of Gaussian mechanism
|
||||||
return 1.0*(x)/(2.0*sigma**2)
|
return 1.0*(x)/(2.0*sigma**2)
|
||||||
|
|
||||||
# We need forward differences of exp(cgf)
|
|
||||||
# The following line is the numerically stable way of implementing it.
|
|
||||||
# The output is in polar form with logarithmic magnitude
|
|
||||||
deltas, signs_deltas = _get_forward_diffs(cgf, alpha)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize with 1 in the log space.
|
# Initialize with 1 in the log space.
|
||||||
log_a = 0
|
log_a = 0
|
||||||
|
# Calculates the log term when alpha = 2
|
||||||
|
log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
|
||||||
if alpha <= max_alpha:
|
if alpha <= max_alpha:
|
||||||
|
# We need forward differences of exp(cgf)
|
||||||
|
# The following line is the numerically stable way of implementing it.
|
||||||
|
# The output is in polar form with logarithmic magnitude
|
||||||
|
deltas, signs_deltas = _get_forward_diffs(cgf, alpha)
|
||||||
# Compute the bound exactly requires book keeping of O(alpha**2)
|
# Compute the bound exactly requires book keeping of O(alpha**2)
|
||||||
|
|
||||||
for i in range(2, alpha+1):
|
for i in range(2, alpha+1):
|
||||||
if i == 2:
|
if i == 2:
|
||||||
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))),func(2.0) + np.log(2))
|
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + log_f2m1, func(2.0) + np.log(2))
|
||||||
elif i > 2:
|
elif i > 2:
|
||||||
s = np.minimum(np.log(4) + 0.5*deltas[int(2*np.floor(i/2.0))-1]+ 0.5*deltas[int(2*np.ceil(i/2.0))-1],np.log(2)+ cgf(i - 1)) \
|
delta_lo = deltas[int(2*np.floor(i/2.0))-1]
|
||||||
+ i * np.log(q) +_log_comb(alpha, i)
|
delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
|
||||||
|
s = np.log(4) + 0.5 * (delta_lo + delta_hi)
|
||||||
|
s = np.minimum(s, np.log(2) + cgf(i - 1))
|
||||||
|
s += i * np.log(q) + _log_comb(alpha, i)
|
||||||
log_a = _log_add(log_a,s)
|
log_a = _log_add(log_a,s)
|
||||||
return float(log_a)
|
return float(log_a)
|
||||||
else:
|
else:
|
||||||
|
@ -494,7 +497,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||||
for i in range(2, alpha + 1):
|
for i in range(2, alpha + 1):
|
||||||
if i == 2:
|
if i == 2:
|
||||||
s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum(
|
s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum(
|
||||||
np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))), func(2.0) + np.log(2))
|
np.log(4) + log_f2m1, func(2.0) + np.log(2))
|
||||||
else:
|
else:
|
||||||
s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i)
|
s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i)
|
||||||
log_a = _log_add(log_a, s)
|
log_a = _log_add(log_a, s)
|
||||||
|
@ -580,3 +583,4 @@ def compute_rdp_from_ledger(ledger, orders):
|
||||||
total_rdp += compute_rdp(
|
total_rdp += compute_rdp(
|
||||||
sample.selection_probability, effective_z, 1, orders)
|
sample.selection_probability, effective_z, 1, orders)
|
||||||
return total_rdp
|
return total_rdp
|
||||||
|
|
||||||
|
|
|
@ -104,9 +104,9 @@ class TestGaussianMoments(parameterized.TestCase):
|
||||||
|
|
||||||
def test_compute_rdp_sequence_without_replacement(self):
|
def test_compute_rdp_sequence_without_replacement(self):
|
||||||
rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50,
|
rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50,
|
||||||
[1.001, 1.5, 2.5, 5, 50, 100, np.inf])
|
[1.001, 1.5, 2.5, 5, 50, 100, 256, 512, 1024, np.inf])
|
||||||
self.assertSequenceAlmostEqual(
|
self.assertSequenceAlmostEqual(
|
||||||
rdp_vec, [0.003470,0.003470, 0.004638, 0.0087633, 0.09847, 167.766388, np.inf],
|
rdp_vec, [0.003470, 0.003470, 0.004638, 0.0087633, 0.098474, 167.766388, 792.838516, 1817.35871, 3865.55029, np.inf],
|
||||||
delta=1e-5)
|
delta=1e-5)
|
||||||
|
|
||||||
def test_compute_rdp_sequence(self):
|
def test_compute_rdp_sequence(self):
|
||||||
|
|
Loading…
Reference in a new issue