forked from 626_privacy/tensorflow_privacy
Fix numerical instability in computing A(alpha) for very large integer alpha.
Tested that new implementation agrees with existing implementation on all small integers but also scales to 10^6. PiperOrigin-RevId: 348492489
This commit is contained in:
parent
276d2d74d5
commit
e4f9794542
1 changed files with 6 additions and 2 deletions
|
@ -85,6 +85,11 @@ def _log_print(logx):
|
|||
return "exp({})".format(logx)
|
||||
|
||||
|
||||
def _log_comb(n, k):
|
||||
return (special.gammaln(n + 1) -
|
||||
special.gammaln(k + 1) - special.gammaln(n - k + 1))
|
||||
|
||||
|
||||
def _compute_log_a_int(q, sigma, alpha):
|
||||
"""Compute log(A_alpha) for integer alpha. 0 < q < 1."""
|
||||
assert isinstance(alpha, six.integer_types)
|
||||
|
@ -94,8 +99,7 @@ def _compute_log_a_int(q, sigma, alpha):
|
|||
|
||||
for i in range(alpha + 1):
|
||||
log_coef_i = (
|
||||
math.log(special.binom(alpha, i)) + i * math.log(q) +
|
||||
(alpha - i) * math.log(1 - q))
|
||||
_log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
|
||||
|
||||
s = log_coef_i + (i * i - i) / (2 * (sigma**2))
|
||||
log_a = _log_add(log_a, s)
|
||||
|
|
Loading…
Reference in a new issue