Resolve comments and add more tests

This commit is contained in:
Yuqing 2021-05-07 00:16:59 -07:00
parent 736520b0eb
commit 09270afed6
2 changed files with 33 additions and 29 deletions

View file

@ -77,8 +77,7 @@ def _log_sub(logx, logy):
return logx return logx
def _log_sub_sign(logx, logy): def _log_sub_sign(logx, logy):
# ensure that x > y """Returns log(exp(logx)-exp(logy)) and its sign."""
# this function returns the stable version of log(exp(logx)-exp(logy)) if logx > logy
if logx > logy: if logx > logy:
s = True s = True
mag = logx + np.log(1 - np.exp(logy - logx)) mag = logx + np.log(1 - np.exp(logy - logx))
@ -286,18 +285,19 @@ def _compute_eps(orders, rdp, delta):
def _stable_inplace_diff_in_log(vec, signs, n=-1): def _stable_inplace_diff_in_log(vec, signs, n=-1):
""" This function replaces the first n-1 dimension of vec with the log of abs difference operator """ This function replaces the first n-1 dimension of vec with the log of abs difference operator
Input:
- `vec` is a numpy array of floats with size larger than 'n' Args:
- `signs` is a numpy array of bools with the same size as vec vec: is a numpy array of floats with size larger than 'n'
- `n` is an optional argument in case one needs to compute partial differences signs: is a numpy array of bools with the same size as vec is an optional argument in case one needs to compute partial differences
`vec` and `signs` jointly describe a vector of real numbers' sign and abs in log scale. vec and signs jointly describe a vector of real numbers' sign and abs in log scale.
Output:
Returns:
The first n-1 dimension of vec and signs will store the log-abs and sign of the difference. The first n-1 dimension of vec and signs will store the log-abs and sign of the difference.
Raises:
ValueError: If input is malformed.
""" """
#
# And the first n-1 dimension of signs with the sign of the differences.
# the sign is assigned to True to break symmetry if the diff is 0
# Input:
assert (vec.shape == signs.shape) assert (vec.shape == signs.shape)
if n < 0: if n < 0:
n = np.max(vec.shape) - 1 n = np.max(vec.shape) - 1
@ -306,7 +306,7 @@ def _stable_inplace_diff_in_log(vec, signs, n=-1):
for j in range(0, n, 1): for j in range(0, n, 1):
if signs[j] == signs[j + 1]: # When the signs are the same if signs[j] == signs[j + 1]: # When the signs are the same
# if the signs are both positive, then we can just use the standard one # if the signs are both positive, then we can just use the standard one
signs[j], vec[j] = _log_sub_sign(vec[j + 1],vec[j]) signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
# otherwise, we do that but toggle the sign # otherwise, we do that but toggle the sign
if signs[j + 1] == False: if signs[j + 1] == False:
signs[j] = ~signs[j] signs[j] = ~signs[j]
@ -428,9 +428,7 @@ def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
if np.isinf(alpha): if np.isinf(alpha):
return np.inf return np.inf
if float(alpha).is_integer():
if isinstance(alpha, six.integer_types):
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1) return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (alpha - 1)
else: else:
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the # When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate the
@ -454,7 +452,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
RDP at alpha, can be np.inf. RDP at alpha, can be np.inf.
""" """
max_alpha = 100 max_alpha = 256
assert isinstance(alpha, six.integer_types) assert isinstance(alpha, six.integer_types)
if np.isinf(alpha): if np.isinf(alpha):
@ -470,23 +468,28 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
# Return the rdp of Gaussian mechanism # Return the rdp of Gaussian mechanism
return 1.0*(x)/(2.0*sigma**2) return 1.0*(x)/(2.0*sigma**2)
# We need forward differences of exp(cgf)
# The following line is the numerically stable way of implementing it.
# The output is in polar form with logarithmic magnitude
deltas, signs_deltas = _get_forward_diffs(cgf, alpha)
# Initialize with 1 in the log space. # Initialize with 1 in the log space.
log_a = 0 log_a = 0
# Calculates the log term when alpha = 2
log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
if alpha <= max_alpha: if alpha <= max_alpha:
# We need forward differences of exp(cgf)
# The following line is the numerically stable way of implementing it.
# The output is in polar form with logarithmic magnitude
deltas, signs_deltas = _get_forward_diffs(cgf, alpha)
# Compute the bound exactly requires book keeping of O(alpha**2) # Compute the bound exactly requires book keeping of O(alpha**2)
for i in range(2, alpha+1): for i in range(2, alpha+1):
if i == 2: if i == 2:
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))),func(2.0) + np.log(2)) s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(np.log(4) + log_f2m1, func(2.0) + np.log(2))
elif i > 2: elif i > 2:
s = np.minimum(np.log(4) + 0.5*deltas[int(2*np.floor(i/2.0))-1]+ 0.5*deltas[int(2*np.ceil(i/2.0))-1],np.log(2)+ cgf(i - 1)) \ delta_lo = deltas[int(2*np.floor(i/2.0))-1]
+ i * np.log(q) +_log_comb(alpha, i) delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
s = np.log(4) + 0.5 * (delta_lo + delta_hi)
s = np.minimum(s, np.log(2) + cgf(i - 1))
s += i * np.log(q) + _log_comb(alpha, i)
log_a = _log_add(log_a,s) log_a = _log_add(log_a,s)
return float(log_a) return float(log_a)
else: else:
@ -494,7 +497,7 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
for i in range(2, alpha + 1): for i in range(2, alpha + 1):
if i == 2: if i == 2:
s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum( s = 2 * np.log(q) + _log_comb(alpha,2) + np.minimum(
np.log(4) + func(2.0) + np.log(1 - np.exp(-func(2.0))), func(2.0) + np.log(2)) np.log(4) + log_f2m1, func(2.0) + np.log(2))
else: else:
s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i) s = np.log(2) + cgf(i-1) + i*np.log(q) + _log_comb(alpha, i)
log_a = _log_add(log_a, s) log_a = _log_add(log_a, s)
@ -580,3 +583,4 @@ def compute_rdp_from_ledger(ledger, orders):
total_rdp += compute_rdp( total_rdp += compute_rdp(
sample.selection_probability, effective_z, 1, orders) sample.selection_probability, effective_z, 1, orders)
return total_rdp return total_rdp

View file

@ -104,9 +104,9 @@ class TestGaussianMoments(parameterized.TestCase):
def test_compute_rdp_sequence_without_replacement(self): def test_compute_rdp_sequence_without_replacement(self):
rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50, rdp_vec = rdp_accountant.compute_rdp_sample_without_replacement(0.01, 2.5, 50,
[1.001, 1.5, 2.5, 5, 50, 100, np.inf]) [1.001, 1.5, 2.5, 5, 50, 100, 256, 512, 1024, np.inf])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(
rdp_vec, [0.003470,0.003470, 0.004638, 0.0087633, 0.09847, 167.766388, np.inf], rdp_vec, [0.003470, 0.003470, 0.004638, 0.0087633, 0.098474, 167.766388, 792.838516, 1817.35871, 3865.55029, np.inf],
delta=1e-5) delta=1e-5)
def test_compute_rdp_sequence(self): def test_compute_rdp_sequence(self):