Fix numerical instability in computing A(alpha) for very large integer alpha.

Tested that new implementation agrees with existing implementation on all small integers but also scales to 10^6.

PiperOrigin-RevId: 348492489
This commit is contained in:
Galen Andrew 2020-12-21 10:51:46 -08:00 committed by A. Unique TensorFlower
parent 276d2d74d5
commit e4f9794542

View file

@ -85,6 +85,11 @@ def _log_print(logx):
return "exp({})".format(logx) return "exp({})".format(logx)
def _log_comb(n, k):
return (special.gammaln(n + 1) -
special.gammaln(k + 1) - special.gammaln(n - k + 1))
def _compute_log_a_int(q, sigma, alpha): def _compute_log_a_int(q, sigma, alpha):
"""Compute log(A_alpha) for integer alpha. 0 < q < 1.""" """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
assert isinstance(alpha, six.integer_types) assert isinstance(alpha, six.integer_types)
@ -94,8 +99,7 @@ def _compute_log_a_int(q, sigma, alpha):
for i in range(alpha + 1): for i in range(alpha + 1):
log_coef_i = ( log_coef_i = (
math.log(special.binom(alpha, i)) + i * math.log(q) + _log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
(alpha - i) * math.log(1 - q))
s = log_coef_i + (i * i - i) / (2 * (sigma**2)) s = log_coef_i + (i * i - i) / (2 * (sigma**2))
log_a = _log_add(log_a, s) log_a = _log_add(log_a, s)