[Py Accounting] Add typing annotations in RDP accounting.
PiperOrigin-RevId: 435703861
This commit is contained in:
parent
adde2064dd
commit
d21e492be6
1 changed files with 41 additions and 27 deletions
|
@ -15,7 +15,7 @@
|
|||
"""Privacy accountant that uses Renyi differential privacy."""
|
||||
|
||||
import math
|
||||
from typing import Collection, Optional, Union
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
|
@ -26,7 +26,7 @@ from tensorflow_privacy.privacy.analysis import privacy_accountant
|
|||
NeighborRel = privacy_accountant.NeighboringRelation
|
||||
|
||||
|
||||
def _log_add(logx, logy):
|
||||
def _log_add(logx: float, logy: float) -> float:
|
||||
"""Adds two numbers in the log space."""
|
||||
a, b = min(logx, logy), max(logx, logy)
|
||||
if a == -np.inf: # adding 0
|
||||
|
@ -35,7 +35,7 @@ def _log_add(logx, logy):
|
|||
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
|
||||
|
||||
|
||||
def _log_sub(logx, logy):
|
||||
def _log_sub(logx: float, logy: float) -> float:
|
||||
"""Subtracts two numbers in the log space. Answer must be non-negative."""
|
||||
if logx < logy:
|
||||
raise ValueError('The result of subtraction must be non-negative.')
|
||||
|
@ -51,7 +51,7 @@ def _log_sub(logx, logy):
|
|||
return logx
|
||||
|
||||
|
||||
def _log_sub_sign(logx, logy):
|
||||
def _log_sub_sign(logx: float, logy: float) -> Tuple[bool, float]:
|
||||
"""Returns log(exp(logx)-exp(logy)) and its sign."""
|
||||
if logx > logy:
|
||||
s = True
|
||||
|
@ -66,15 +66,14 @@ def _log_sub_sign(logx, logy):
|
|||
return s, mag
|
||||
|
||||
|
||||
def _log_comb(n, k):
|
||||
def _log_comb(n: int, k: int) -> float:
|
||||
"""Computes log of binomial coefficient."""
|
||||
return (special.gammaln(n + 1) - special.gammaln(k + 1) -
|
||||
special.gammaln(n - k + 1))
|
||||
|
||||
|
||||
def _compute_log_a_int(q, sigma, alpha):
|
||||
def _compute_log_a_int(q: float, sigma: float, alpha: int) -> float:
|
||||
"""Computes log(A_alpha) for integer alpha, 0 < q < 1."""
|
||||
assert isinstance(alpha, int)
|
||||
|
||||
# Initialize with 0 in the log space.
|
||||
log_a = -np.inf
|
||||
|
@ -89,7 +88,7 @@ def _compute_log_a_int(q, sigma, alpha):
|
|||
return float(log_a)
|
||||
|
||||
|
||||
def _compute_log_a_frac(q, sigma, alpha):
|
||||
def _compute_log_a_frac(q: float, sigma: float, alpha: float) -> float:
|
||||
"""Computes log(A_alpha) for fractional alpha, 0 < q < 1."""
|
||||
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
|
||||
# initialized to 0 in the log space:
|
||||
|
@ -126,7 +125,7 @@ def _compute_log_a_frac(q, sigma, alpha):
|
|||
return _log_add(log_a0, log_a1)
|
||||
|
||||
|
||||
def _log_erfc(x):
|
||||
def _log_erfc(x: float) -> float:
|
||||
"""Computes log(erfc(x)) with high accuracy for large x."""
|
||||
try:
|
||||
return math.log(2) + special.log_ndtr(-x * 2**.5)
|
||||
|
@ -144,7 +143,8 @@ def _log_erfc(x):
|
|||
return math.log(r)
|
||||
|
||||
|
||||
def _compute_delta(orders, rdp, epsilon):
|
||||
def _compute_delta(orders: Sequence[float], rdp: Sequence[float],
|
||||
epsilon: float) -> float:
|
||||
"""Compute delta given a list of RDP values and target epsilon.
|
||||
|
||||
Args:
|
||||
|
@ -193,7 +193,8 @@ def _compute_delta(orders, rdp, epsilon):
|
|||
return min(math.exp(np.min(logdeltas)), 1.)
|
||||
|
||||
|
||||
def _compute_epsilon(orders, rdp, delta):
|
||||
def _compute_epsilon(orders: Sequence[float], rdp: Sequence[float],
|
||||
delta: float) -> float:
|
||||
"""Compute epsilon given a list of RDP values and target delta.
|
||||
|
||||
Args:
|
||||
|
@ -249,7 +250,9 @@ def _compute_epsilon(orders, rdp, delta):
|
|||
return max(0, np.min(eps))
|
||||
|
||||
|
||||
def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
||||
def _stable_inplace_diff_in_log(vec: np.ndarray,
|
||||
signs: np.ndarray,
|
||||
n: Optional[int] = None):
|
||||
"""Replaces the first n-1 dims of vec with the log of abs difference operator.
|
||||
|
||||
Args:
|
||||
|
@ -257,7 +260,7 @@ def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
|||
signs: Optional numpy array of bools with the same size as vec in case one
|
||||
needs to compute partial differences vec and signs jointly describe a
|
||||
vector of real numbers' sign and abs in log scale.
|
||||
n: Optonal upper bound on number of differences to compute. If negative, all
|
||||
n: Optonal upper bound on number of differences to compute. If None, all
|
||||
differences are computed.
|
||||
|
||||
Returns:
|
||||
|
@ -268,8 +271,11 @@ def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
|||
ValueError: If input is malformed.
|
||||
"""
|
||||
|
||||
assert vec.shape == signs.shape
|
||||
if n < 0:
|
||||
if vec.shape != signs.shape:
|
||||
raise ValueError('Shape of vec and signs do not match.')
|
||||
if signs.dtype != bool:
|
||||
raise ValueError('signs must be of type bool')
|
||||
if n is None:
|
||||
n = np.max(vec.shape) - 1
|
||||
else:
|
||||
assert np.max(vec.shape) >= n + 1
|
||||
|
@ -285,7 +291,8 @@ def _stable_inplace_diff_in_log(vec, signs, n=-1):
|
|||
signs[j] = signs[j + 1]
|
||||
|
||||
|
||||
def _get_forward_diffs(fun, n):
|
||||
def _get_forward_diffs(fun: Callable[[float], float],
|
||||
n: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Computes up to nth order forward difference evaluated at 0.
|
||||
|
||||
See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
|
||||
|
@ -313,14 +320,17 @@ def _get_forward_diffs(fun, n):
|
|||
return deltas, signs_deltas
|
||||
|
||||
|
||||
def _compute_log_a(q, noise_multiplier, alpha):
|
||||
def _compute_log_a(q: float, noise_multiplier: float,
|
||||
alpha: Union[int, float]) -> float:
|
||||
if float(alpha).is_integer():
|
||||
return _compute_log_a_int(q, noise_multiplier, int(alpha))
|
||||
else:
|
||||
return _compute_log_a_frac(q, noise_multiplier, alpha)
|
||||
|
||||
|
||||
def _compute_rdp_poisson_subsampled_gaussian(q, noise_multiplier, orders):
|
||||
def _compute_rdp_poisson_subsampled_gaussian(
|
||||
q: float, noise_multiplier: float,
|
||||
orders: Sequence[float]) -> Union[float, np.ndarray]:
|
||||
"""Computes RDP of the Poisson sampled Gaussian mechanism.
|
||||
|
||||
Args:
|
||||
|
@ -348,7 +358,9 @@ def _compute_rdp_poisson_subsampled_gaussian(q, noise_multiplier, orders):
|
|||
return np.array([compute_one_order(q, order) for order in orders])
|
||||
|
||||
|
||||
def _compute_rdp_sample_wor_gaussian(q, noise_multiplier, orders):
|
||||
def _compute_rdp_sample_wor_gaussian(
|
||||
q: float, noise_multiplier: float,
|
||||
orders: Sequence[float]) -> Union[float, np.ndarray]:
|
||||
"""Computes RDP of Gaussian mechanism using sampling without replacement.
|
||||
|
||||
This function applies to the following schemes:
|
||||
|
@ -376,7 +388,8 @@ def _compute_rdp_sample_wor_gaussian(q, noise_multiplier, orders):
|
|||
])
|
||||
|
||||
|
||||
def _compute_rdp_sample_wor_gaussian_scalar(q, sigma, alpha):
|
||||
def _compute_rdp_sample_wor_gaussian_scalar(q: float, sigma: float,
|
||||
alpha: Union[float, int]) -> float:
|
||||
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
|
||||
|
||||
Args:
|
||||
|
@ -414,7 +427,8 @@ def _compute_rdp_sample_wor_gaussian_scalar(q, sigma, alpha):
|
|||
return ((1 - t) * x + t * y) / (alpha - 1)
|
||||
|
||||
|
||||
def _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha):
|
||||
def _compute_rdp_sample_wor_gaussian_int(q: float, sigma: float,
|
||||
alpha: int) -> float:
|
||||
"""Compute log(A_alpha) for integer alpha, subsampling without replacement.
|
||||
|
||||
When alpha is smaller than max_alpha, compute the bound Theorem 27 exactly,
|
||||
|
@ -430,7 +444,6 @@ def _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha):
|
|||
"""
|
||||
|
||||
max_alpha = 256
|
||||
assert isinstance(alpha, int)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
|
@ -483,7 +496,8 @@ def _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha):
|
|||
return log_a
|
||||
|
||||
|
||||
def _effective_gaussian_noise_multiplier(event: dp_event.DpEvent):
|
||||
def _effective_gaussian_noise_multiplier(
|
||||
event: dp_event.DpEvent) -> Optional[float]:
|
||||
"""Determines the effective noise multiplier of nested structure of Gaussians.
|
||||
|
||||
A series of Gaussian queries on the same data can be reexpressed as a single
|
||||
|
@ -520,8 +534,8 @@ def _effective_gaussian_noise_multiplier(event: dp_event.DpEvent):
|
|||
|
||||
|
||||
def _compute_rdp_single_epoch_tree_aggregation(
|
||||
noise_multiplier: float, step_counts: Union[int, Collection[int]],
|
||||
orders: Collection[float]) -> Union[float, np.ndarray]:
|
||||
noise_multiplier: float, step_counts: Union[int, Sequence[int]],
|
||||
orders: Sequence[float]) -> Union[float, np.ndarray]:
|
||||
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
|
||||
|
||||
This function implements the accounting when the tree is periodically
|
||||
|
@ -558,7 +572,7 @@ def _compute_rdp_single_epoch_tree_aggregation(
|
|||
if steps < 0:
|
||||
raise ValueError(f'Steps must be non-negative. Got {step_counts}')
|
||||
|
||||
max_depth = max(math.ceil(math.log2(steps + 1)) for steps in step_counts)
|
||||
max_depth = math.ceil(math.log2(max(step_counts) + 1))
|
||||
return np.array([a * max_depth / (2 * noise_multiplier**2) for a in orders])
|
||||
|
||||
|
||||
|
@ -567,7 +581,7 @@ class RdpAccountant(privacy_accountant.PrivacyAccountant):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
orders: Optional[Collection[float]] = None,
|
||||
orders: Optional[Sequence[float]] = None,
|
||||
neighboring_relation: NeighborRel = NeighborRel.ADD_OR_REMOVE_ONE,
|
||||
):
|
||||
super().__init__(neighboring_relation)
|
||||
|
|
Loading…
Reference in a new issue