Deprecates implementations of RDP accounting from tensorflow_privacy in favor of differential_privacy.
PiperOrigin-RevId: 443177278
This commit is contained in:
parent
ee35642b90
commit
868cf54470
3 changed files with 113 additions and 500 deletions
|
@ -61,6 +61,11 @@ py_library(
|
|||
srcs = ["rdp_accountant.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_google_differential_py//python/dp_accounting:dp_event",
|
||||
"@com_google_differential_py//python/dp_accounting:privacy_accountant",
|
||||
"@com_google_differential_py//python/dp_accounting/rdp:rdp_privacy_accountant",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -12,7 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""RDP analysis of the Sampled Gaussian Mechanism.
|
||||
"""(Deprecated) RDP analysis of the Sampled Gaussian Mechanism.
|
||||
|
||||
The functions in this package have been superseded by more general accounting
|
||||
mechanisms in Google's `differential_privacy` package. These functions may at
|
||||
some future date be removed.
|
||||
|
||||
Functionality for computing Renyi differential privacy (RDP) of an additive
|
||||
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
|
||||
|
@ -37,342 +41,50 @@ The example code would be:
|
|||
eps, _, opt_order = rdp_accountant.get_privacy_spent(rdp, target_delta=delta)
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
|
||||
########################
|
||||
# LOG-SPACE ARITHMETIC #
|
||||
########################
|
||||
from com_google_differential_py.python.dp_accounting import dp_event
|
||||
from com_google_differential_py.python.dp_accounting import privacy_accountant
|
||||
from com_google_differential_py.python.dp_accounting.rdp import rdp_privacy_accountant
|
||||
|
||||
|
||||
def _log_add(logx, logy):
|
||||
"""Add two numbers in the log space."""
|
||||
a, b = min(logx, logy), max(logx, logy)
|
||||
if a == -np.inf: # adding 0
|
||||
return b
|
||||
# Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
|
||||
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
|
||||
|
||||
|
||||
def _log_sub(logx, logy):
|
||||
"""Subtract two numbers in the log space. Answer must be non-negative."""
|
||||
if logx < logy:
|
||||
raise ValueError("The result of subtraction must be non-negative.")
|
||||
if logy == -np.inf: # subtracting 0
|
||||
return logx
|
||||
if logx == logy:
|
||||
return -np.inf # 0 is represented as -np.inf in the log space.
|
||||
|
||||
try:
|
||||
# Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
|
||||
return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
|
||||
except OverflowError:
|
||||
return logx
|
||||
|
||||
|
||||
def _log_sub_sign(logx, logy):
|
||||
"""Returns log(exp(logx)-exp(logy)) and its sign."""
|
||||
if logx > logy:
|
||||
s = True
|
||||
mag = logx + np.log(1 - np.exp(logy - logx))
|
||||
elif logx < logy:
|
||||
s = False
|
||||
mag = logy + np.log(1 - np.exp(logx - logy))
|
||||
else:
|
||||
s = True
|
||||
mag = -np.inf
|
||||
|
||||
return s, mag
|
||||
|
||||
|
||||
def _log_print(logx):
|
||||
"""Pretty print."""
|
||||
if logx < math.log(sys.float_info.max):
|
||||
return "{}".format(math.exp(logx))
|
||||
else:
|
||||
return "exp({})".format(logx)
|
||||
|
||||
|
||||
def _log_comb(n, k):
|
||||
return (special.gammaln(n + 1) - special.gammaln(k + 1) -
|
||||
special.gammaln(n - k + 1))
|
||||
|
||||
|
||||
def _compute_log_a_int(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for integer alpha. 0 < q < 1."""
|
||||
assert isinstance(alpha, int)
|
||||
|
||||
# Initialize with 0 in the log space.
|
||||
log_a = -np.inf
|
||||
|
||||
for i in range(alpha + 1):
|
||||
log_coef_i = (
|
||||
_log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
|
||||
|
||||
s = log_coef_i + (i * i - i) / (2 * (sigma**2))
|
||||
log_a = _log_add(log_a, s)
|
||||
|
||||
return float(log_a)
|
||||
|
||||
|
||||
def _compute_log_a_frac(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
|
||||
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
|
||||
# initialized to 0 in the log space:
|
||||
log_a0, log_a1 = -np.inf, -np.inf
|
||||
i = 0
|
||||
|
||||
z0 = sigma**2 * math.log(1 / q - 1) + .5
|
||||
|
||||
while True: # do ... until loop
|
||||
coef = special.binom(alpha, i)
|
||||
log_coef = math.log(abs(coef))
|
||||
j = alpha - i
|
||||
|
||||
log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
|
||||
log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
|
||||
|
||||
log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
|
||||
log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
|
||||
|
||||
log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
|
||||
log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
|
||||
|
||||
if coef > 0:
|
||||
log_a0 = _log_add(log_a0, log_s0)
|
||||
log_a1 = _log_add(log_a1, log_s1)
|
||||
else:
|
||||
log_a0 = _log_sub(log_a0, log_s0)
|
||||
log_a1 = _log_sub(log_a1, log_s1)
|
||||
|
||||
i += 1
|
||||
if max(log_s0, log_s1) < -30:
|
||||
break
|
||||
|
||||
return _log_add(log_a0, log_a1)
|
||||
|
||||
|
||||
def _compute_log_a(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for any positive finite alpha."""
|
||||
if float(alpha).is_integer():
|
||||
return _compute_log_a_int(q, sigma, int(alpha))
|
||||
else:
|
||||
return _compute_log_a_frac(q, sigma, alpha)
|
||||
|
||||
|
||||
def _log_erfc(x):
|
||||
"""Compute log(erfc(x)) with high accuracy for large x."""
|
||||
try:
|
||||
return math.log(2) + special.log_ndtr(-x * 2**.5)
|
||||
except NameError:
|
||||
# If log_ndtr is not available, approximate as follows:
|
||||
r = special.erfc(x)
|
||||
if r == 0.0:
|
||||
# Using the Laurent series at infinity for the tail of the erfc function:
|
||||
# erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
|
||||
# To verify in Mathematica:
|
||||
# Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
|
||||
return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
|
||||
.625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
|
||||
else:
|
||||
return math.log(r)
|
||||
|
||||
|
||||
def _compute_delta(orders, rdp, eps):
|
||||
"""Compute delta given a list of RDP values and target epsilon.
|
||||
def _compute_rdp_from_event(orders, event, count):
|
||||
"""Computes RDP from a DpEvent using RdpAccountant.
|
||||
|
||||
Args:
|
||||
orders: An array (or a scalar) of orders.
|
||||
rdp: A list (or a scalar) of RDP guarantees.
|
||||
eps: The target epsilon.
|
||||
orders: An array (or a scalar) of RDP orders.
|
||||
event: A DpEvent to compute the RDP of.
|
||||
count: The number of self-compositions.
|
||||
|
||||
Returns:
|
||||
Pair of (delta, optimal_order).
|
||||
|
||||
Raises:
|
||||
ValueError: If input is malformed.
|
||||
|
||||
The RDP at all orders. Can be `np.inf`.
|
||||
"""
|
||||
orders_vec = np.atleast_1d(orders)
|
||||
rdp_vec = np.atleast_1d(rdp)
|
||||
|
||||
if eps < 0:
|
||||
raise ValueError("Value of privacy loss bound epsilon must be >=0.")
|
||||
if len(orders_vec) != len(rdp_vec):
|
||||
raise ValueError("Input lists must have the same length.")
|
||||
|
||||
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
|
||||
# delta = min( np.exp((rdp_vec - eps) * (orders_vec - 1)) )
|
||||
|
||||
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
|
||||
logdeltas = [] # work in log space to avoid overflows
|
||||
for (a, r) in zip(orders_vec, rdp_vec):
|
||||
if a < 1:
|
||||
raise ValueError("Renyi divergence order must be >=1.")
|
||||
if r < 0:
|
||||
raise ValueError("Renyi divergence must be >=0.")
|
||||
# For small alpha, we are better of with bound via KL divergence:
|
||||
# delta <= sqrt(1-exp(-KL)).
|
||||
# Take a min of the two bounds.
|
||||
logdelta = 0.5 * math.log1p(-math.exp(-r))
|
||||
if a > 1.01:
|
||||
# This bound is not numerically stable as alpha->1.
|
||||
# Thus we have a min value for alpha.
|
||||
# The bound is also not useful for small alpha, so doesn't matter.
|
||||
rdp_bound = (a - 1) * (r - eps + math.log1p(-1 / a)) - math.log(a)
|
||||
logdelta = min(logdelta, rdp_bound)
|
||||
|
||||
logdeltas.append(logdelta)
|
||||
|
||||
idx_opt = np.argmin(logdeltas)
|
||||
return min(math.exp(logdeltas[idx_opt]), 1.), orders_vec[idx_opt]
|
||||
|
||||
|
||||
def _compute_eps(orders, rdp, delta):
|
||||
"""Compute epsilon given a list of RDP values and target delta.
|
||||
|
||||
Args:
|
||||
orders: An array (or a scalar) of orders.
|
||||
rdp: A list (or a scalar) of RDP guarantees.
|
||||
delta: The target delta.
|
||||
|
||||
Returns:
|
||||
Pair of (eps, optimal_order).
|
||||
|
||||
Raises:
|
||||
ValueError: If input is malformed.
|
||||
|
||||
"""
|
||||
orders_vec = np.atleast_1d(orders)
|
||||
rdp_vec = np.atleast_1d(rdp)
|
||||
|
||||
if delta <= 0:
|
||||
raise ValueError("Privacy failure probability bound delta must be >0.")
|
||||
if len(orders_vec) != len(rdp_vec):
|
||||
raise ValueError("Input lists must have the same length.")
|
||||
|
||||
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
|
||||
# eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) )
|
||||
|
||||
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4).
|
||||
# Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
|
||||
eps_vec = []
|
||||
for (a, r) in zip(orders_vec, rdp_vec):
|
||||
if a < 1:
|
||||
raise ValueError("Renyi divergence order must be >=1.")
|
||||
if r < 0:
|
||||
raise ValueError("Renyi divergence must be >=0.")
|
||||
|
||||
if delta**2 + math.expm1(-r) >= 0:
|
||||
# In this case, we can simply bound via KL divergence:
|
||||
# delta <= sqrt(1-exp(-KL)).
|
||||
eps = 0 # No need to try further computation if we have eps = 0.
|
||||
elif a > 1.01:
|
||||
# This bound is not numerically stable as alpha->1.
|
||||
# Thus we have a min value of alpha.
|
||||
# The bound is also not useful for small alpha, so doesn't matter.
|
||||
eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1)
|
||||
else:
|
||||
# In this case we can't do anything. E.g., asking for delta = 0.
|
||||
eps = np.inf
|
||||
eps_vec.append(eps)
|
||||
|
||||
idx_opt = np.argmin(eps_vec)
|
||||
return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
|
||||
|
||||
|
||||
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
||||
"""Replaces the first n-1 dims of vec with the log of abs difference operator.
|
||||
|
||||
Args:
|
||||
vec: numpy array of floats with size larger than 'n'
|
||||
signs: Optional numpy array of bools with the same size as vec in case one
|
||||
needs to compute partial differences vec and signs jointly describe a
|
||||
vector of real numbers' sign and abs in log scale.
|
||||
n: Optonal upper bound on number of differences to compute. If negative, all
|
||||
differences are computed.
|
||||
|
||||
Returns:
|
||||
The first n-1 dimension of vec and signs will store the log-abs and sign of
|
||||
the difference.
|
||||
|
||||
Raises:
|
||||
ValueError: If input is malformed.
|
||||
"""
|
||||
|
||||
assert vec.shape == signs.shape
|
||||
if n < 0:
|
||||
n = np.max(vec.shape) - 1
|
||||
if isinstance(event, dp_event.SampledWithoutReplacementDpEvent):
|
||||
neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_ONE
|
||||
elif isinstance(event, dp_event.SingleEpochTreeAggregationDpEvent):
|
||||
neighboring_relation = privacy_accountant.NeighboringRelation.REPLACE_SPECIAL
|
||||
else:
|
||||
assert np.max(vec.shape) >= n + 1
|
||||
for j in range(0, n, 1):
|
||||
if signs[j] == signs[j + 1]: # When the signs are the same
|
||||
# if the signs are both positive, then we can just use the standard one
|
||||
signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
|
||||
# otherwise, we do that but toggle the sign
|
||||
if not signs[j + 1]:
|
||||
signs[j] = ~signs[j]
|
||||
else: # When the signs are different.
|
||||
vec[j] = _log_add(vec[j], vec[j + 1])
|
||||
signs[j] = signs[j + 1]
|
||||
neighboring_relation = privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE
|
||||
|
||||
accountant = rdp_privacy_accountant.RdpAccountant(orders_vec,
|
||||
neighboring_relation)
|
||||
accountant.compose(event, count)
|
||||
rdp = accountant._rdp # pylint: disable=protected-access
|
||||
|
||||
def _get_forward_diffs(fun, n):
|
||||
"""Computes up to nth order forward difference evaluated at 0.
|
||||
|
||||
See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
|
||||
|
||||
Args:
|
||||
fun: Function to compute forward differences of.
|
||||
n: Number of differences to compute.
|
||||
|
||||
Returns:
|
||||
Pair (deltas, signs_deltas) of the log deltas and their signs.
|
||||
"""
|
||||
func_vec = np.zeros(n + 3)
|
||||
signs_func_vec = np.ones(n + 3, dtype=bool)
|
||||
|
||||
# ith coordinate of deltas stores log(abs(ith order discrete derivative))
|
||||
deltas = np.zeros(n + 2)
|
||||
signs_deltas = np.zeros(n + 2, dtype=bool)
|
||||
for i in range(1, n + 3, 1):
|
||||
func_vec[i] = fun(1.0 * (i - 1))
|
||||
for i in range(0, n + 2, 1):
|
||||
# Diff in log scale
|
||||
_stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
|
||||
deltas[i] = func_vec[0]
|
||||
signs_deltas[i] = signs_func_vec[0]
|
||||
return deltas, signs_deltas
|
||||
|
||||
|
||||
def _compute_rdp(q, sigma, alpha):
|
||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||
|
||||
Args:
|
||||
q: The sampling rate.
|
||||
sigma: The std of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at alpha, can be np.inf.
|
||||
"""
|
||||
if q == 0:
|
||||
return 0
|
||||
|
||||
if q == 1.:
|
||||
return alpha / (2 * sigma**2)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
|
||||
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
|
||||
if np.isscalar(orders):
|
||||
return rdp[0]
|
||||
else:
|
||||
return rdp
|
||||
|
||||
|
||||
def compute_rdp(q, noise_multiplier, steps, orders):
|
||||
"""Computes RDP of the Sampled Gaussian Mechanism.
|
||||
"""(Deprecated) Computes RDP of the Sampled Gaussian Mechanism.
|
||||
|
||||
This function has been superseded by more general accounting mechanisms in
|
||||
Google's `differential_privacy` package. It may at some future date be
|
||||
removed.
|
||||
|
||||
Args:
|
||||
q: The sampling rate.
|
||||
|
@ -384,17 +96,18 @@ def compute_rdp(q, noise_multiplier, steps, orders):
|
|||
Returns:
|
||||
The RDPs at all orders. Can be `np.inf`.
|
||||
"""
|
||||
if np.isscalar(orders):
|
||||
rdp = _compute_rdp(q, noise_multiplier, orders)
|
||||
else:
|
||||
rdp = np.array(
|
||||
[_compute_rdp(q, noise_multiplier, order) for order in orders])
|
||||
event = dp_event.PoissonSampledDpEvent(
|
||||
q, dp_event.GaussianDpEvent(noise_multiplier))
|
||||
|
||||
return rdp * steps
|
||||
return _compute_rdp_from_event(orders, event, steps)
|
||||
|
||||
|
||||
def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
||||
"""Compute RDP of Gaussian Mechanism using sampling without replacement.
|
||||
"""(Deprecated) Compute RDP of Gaussian Mechanism sampling w/o replacement.
|
||||
|
||||
This function has been superseded by more general accounting mechanisms in
|
||||
Google's `differential_privacy` package. It may at some future date be
|
||||
removed.
|
||||
|
||||
This function applies to the following schemes:
|
||||
1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n.
|
||||
|
@ -416,129 +129,19 @@ def compute_rdp_sample_without_replacement(q, noise_multiplier, steps, orders):
|
|||
Returns:
|
||||
The RDPs at all orders, can be np.inf.
|
||||
"""
|
||||
if np.isscalar(orders):
|
||||
rdp = _compute_rdp_sample_without_replacement_scalar(
|
||||
q, noise_multiplier, orders)
|
||||
else:
|
||||
rdp = np.array([
|
||||
_compute_rdp_sample_without_replacement_scalar(q, noise_multiplier,
|
||||
order)
|
||||
for order in orders
|
||||
])
|
||||
event = dp_event.SampledWithoutReplacementDpEvent(
|
||||
1, q, dp_event.GaussianDpEvent(noise_multiplier))
|
||||
|
||||
return rdp * steps
|
||||
|
||||
|
||||
def _compute_rdp_sample_without_replacement_scalar(q, sigma, alpha):
|
||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||
|
||||
Args:
|
||||
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||
sigma: The std of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at alpha, can be np.inf.
|
||||
"""
|
||||
|
||||
assert (q <= 1) and (q >= 0) and (alpha >= 1)
|
||||
|
||||
if q == 0:
|
||||
return 0
|
||||
|
||||
if q == 1.:
|
||||
return alpha / (2 * sigma**2)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
|
||||
if float(alpha).is_integer():
|
||||
return _compute_rdp_sample_without_replacement_int(q, sigma, alpha) / (
|
||||
alpha - 1)
|
||||
else:
|
||||
# When alpha not an integer, we apply Corollary 10 of [WBK19] to interpolate
|
||||
# the CGF and obtain an upper bound
|
||||
alpha_f = math.floor(alpha)
|
||||
alpha_c = math.ceil(alpha)
|
||||
|
||||
x = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_f)
|
||||
y = _compute_rdp_sample_without_replacement_int(q, sigma, alpha_c)
|
||||
t = alpha - alpha_f
|
||||
return ((1 - t) * x + t * y) / (alpha - 1)
|
||||
|
||||
|
||||
def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for integer alpha, subsampling without replacement.
|
||||
|
||||
When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly,
|
||||
otherwise compute the bound with Stirling approximation.
|
||||
|
||||
Args:
|
||||
q: The sampling proportion = m / n. Assume m is an integer <= n.
|
||||
sigma: The std of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at alpha, can be np.inf.
|
||||
"""
|
||||
|
||||
max_alpha = 256
|
||||
assert isinstance(alpha, int)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
elif alpha == 1:
|
||||
return 0
|
||||
|
||||
def cgf(x):
|
||||
# Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2)
|
||||
return x * 1.0 * (x + 1) / (2.0 * sigma**2)
|
||||
|
||||
def func(x):
|
||||
# Return the rdp of Gaussian mechanism
|
||||
return 1.0 * x / (2.0 * sigma**2)
|
||||
|
||||
# Initialize with 1 in the log space.
|
||||
log_a = 0
|
||||
# Calculates the log term when alpha = 2
|
||||
log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
|
||||
if alpha <= max_alpha:
|
||||
# We need forward differences of exp(cgf)
|
||||
# The following line is the numerically stable way of implementing it.
|
||||
# The output is in polar form with logarithmic magnitude
|
||||
deltas, _ = _get_forward_diffs(cgf, alpha)
|
||||
# Compute the bound exactly requires book keeping of O(alpha**2)
|
||||
|
||||
for i in range(2, alpha + 1):
|
||||
if i == 2:
|
||||
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
|
||||
np.log(4) + log_f2m1,
|
||||
func(2.0) + np.log(2))
|
||||
elif i > 2:
|
||||
delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
|
||||
delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
|
||||
s = np.log(4) + 0.5 * (delta_lo + delta_hi)
|
||||
s = np.minimum(s, np.log(2) + cgf(i - 1))
|
||||
s += i * np.log(q) + _log_comb(alpha, i)
|
||||
log_a = _log_add(log_a, s)
|
||||
return float(log_a)
|
||||
else:
|
||||
# Compute the bound with stirling approximation. Everything is O(x) now.
|
||||
for i in range(2, alpha + 1):
|
||||
if i == 2:
|
||||
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
|
||||
np.log(4) + log_f2m1,
|
||||
func(2.0) + np.log(2))
|
||||
else:
|
||||
s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
|
||||
log_a = _log_add(log_a, s)
|
||||
|
||||
return log_a
|
||||
return _compute_rdp_from_event(orders, event, steps)
|
||||
|
||||
|
||||
def compute_heterogeneous_rdp(sampling_probabilities, noise_multipliers,
|
||||
steps_list, orders):
|
||||
"""Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
|
||||
"""(Deprecated) Computes RDP of Heteregoneous Sampled Gaussian Mechanisms.
|
||||
|
||||
This function has been superseded by more general accounting mechanisms in
|
||||
Google's `differential_privacy` package. It may at some future date be
|
||||
removed.
|
||||
|
||||
Args:
|
||||
sampling_probabilities: A list containing the sampling rates.
|
||||
|
@ -563,7 +166,11 @@ def compute_heterogeneous_rdp(sampling_probabilities, noise_multipliers,
|
|||
|
||||
|
||||
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
||||
"""Computes delta (or eps) for given eps (or delta) from RDP values.
|
||||
"""(Deprecated) Computes delta or eps from RDP values.
|
||||
|
||||
This function has been superseded by more general accounting mechanisms in
|
||||
Google's `differential_privacy` package. It may at some future date be
|
||||
removed.
|
||||
|
||||
Args:
|
||||
orders: An array (or a scalar) of RDP orders.
|
||||
|
@ -588,9 +195,12 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
|
|||
raise ValueError(
|
||||
"Exactly one out of eps and delta must be None. (None is).")
|
||||
|
||||
accountant = rdp_privacy_accountant.RdpAccountant(orders)
|
||||
accountant._rdp = rdp # pylint: disable=protected-access
|
||||
|
||||
if target_eps is not None:
|
||||
delta, opt_order = _compute_delta(orders, rdp, target_eps)
|
||||
delta, opt_order = accountant.get_delta_and_optimal_order(target_eps)
|
||||
return target_eps, delta, opt_order
|
||||
else:
|
||||
eps, opt_order = _compute_eps(orders, rdp, target_delta)
|
||||
eps, opt_order = accountant.get_epsilon_and_optimal_order(target_delta)
|
||||
return eps, target_delta, opt_order
|
||||
|
|
|
@ -23,51 +23,58 @@ import tensorflow as tf
|
|||
|
||||
from tensorflow_privacy.privacy.analysis import rdp_accountant
|
||||
|
||||
#################################
|
||||
# HELPER FUNCTIONS: #
|
||||
# Exact computations using #
|
||||
# multi-precision arithmetic. #
|
||||
#################################
|
||||
|
||||
|
||||
def _log_float_mp(x):
|
||||
# Convert multi-precision input to float log space.
|
||||
if x >= sys.float_info.min:
|
||||
return float(mpmath.log(x))
|
||||
else:
|
||||
return -np.inf
|
||||
|
||||
|
||||
def _integral_mp(fn, bounds=(-mpmath.inf, mpmath.inf)):
|
||||
integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
|
||||
return integral
|
||||
|
||||
|
||||
def _distributions_mp(sigma, q):
|
||||
|
||||
def _mu0(x):
|
||||
return mpmath.npdf(x, mu=0, sigma=sigma)
|
||||
|
||||
def _mu1(x):
|
||||
return mpmath.npdf(x, mu=1, sigma=sigma)
|
||||
|
||||
def _mu(x):
|
||||
return (1 - q) * _mu0(x) + q * _mu1(x)
|
||||
|
||||
return _mu0, _mu # Closure!
|
||||
|
||||
|
||||
def _mu1_over_mu0(x, sigma):
|
||||
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
|
||||
return mpmath.exp((2 * x - 1) / (2 * sigma**2))
|
||||
|
||||
|
||||
def _mu_over_mu0(x, q, sigma):
|
||||
return (1 - q) + q * _mu1_over_mu0(x, sigma)
|
||||
|
||||
|
||||
def _compute_a_mp(sigma, q, alpha):
|
||||
"""Compute A_alpha for arbitrary alpha by numerical integration."""
|
||||
mu0, _ = _distributions_mp(sigma, q)
|
||||
a_alpha_fn = lambda z: mu0(z) * _mu_over_mu0(z, q, sigma)**alpha
|
||||
a_alpha = _integral_mp(a_alpha_fn)
|
||||
return a_alpha
|
||||
|
||||
|
||||
class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
|
||||
#################################
|
||||
# HELPER FUNCTIONS: #
|
||||
# Exact computations using #
|
||||
# multi-precision arithmetic. #
|
||||
#################################
|
||||
|
||||
def _log_float_mp(self, x):
|
||||
# Convert multi-precision input to float log space.
|
||||
if x >= sys.float_info.min:
|
||||
return float(mpmath.log(x))
|
||||
else:
|
||||
return -np.inf
|
||||
|
||||
def _integral_mp(self, fn, bounds=(-mpmath.inf, mpmath.inf)):
|
||||
integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
|
||||
return integral
|
||||
|
||||
def _distributions_mp(self, sigma, q):
|
||||
|
||||
def _mu0(x):
|
||||
return mpmath.npdf(x, mu=0, sigma=sigma)
|
||||
|
||||
def _mu1(x):
|
||||
return mpmath.npdf(x, mu=1, sigma=sigma)
|
||||
|
||||
def _mu(x):
|
||||
return (1 - q) * _mu0(x) + q * _mu1(x)
|
||||
|
||||
return _mu0, _mu # Closure!
|
||||
|
||||
def _mu1_over_mu0(self, x, sigma):
|
||||
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
|
||||
return mpmath.exp((2 * x - 1) / (2 * sigma**2))
|
||||
|
||||
def _mu_over_mu0(self, x, q, sigma):
|
||||
return (1 - q) + q * self._mu1_over_mu0(x, sigma)
|
||||
|
||||
def _compute_a_mp(self, sigma, q, alpha):
|
||||
"""Compute A_alpha for arbitrary alpha by numerical integration."""
|
||||
mu0, _ = self._distributions_mp(sigma, q)
|
||||
a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
|
||||
a_alpha = self._integral_mp(a_alpha_fn)
|
||||
return a_alpha
|
||||
|
||||
# TEST ROUTINES
|
||||
def test_compute_heterogeneous_rdp_different_sampling_probabilities(self):
|
||||
|
@ -152,15 +159,6 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
|
|||
'order': 256.1
|
||||
})
|
||||
|
||||
# pylint:disable=undefined-variable
|
||||
@parameterized.parameters(p for p in params)
|
||||
def test_compute_log_a_equals_mp(self, q, sigma, order):
|
||||
# Compare the cheap computation of log(A) with an expensive, multi-precision
|
||||
# computation.
|
||||
log_a = rdp_accountant._compute_log_a(q, sigma, order)
|
||||
log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order))
|
||||
np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
|
||||
|
||||
def test_get_privacy_spent_check_target_delta(self):
|
||||
orders = range(2, 33)
|
||||
rdp = [1.1 for o in orders] # Constant corresponds to pure DP.
|
||||
|
|
Loading…
Reference in a new issue