Normalize mpmath imports in TensorFlow Privacy to be more friendly with strict dependencies and lint.

PiperOrigin-RevId: 424681602
This commit is contained in:
Michael Reneer 2022-01-27 12:34:57 -08:00 committed by A. Unique TensorFlower
parent 81a11eb824
commit c36ce6d799

View file

@ -17,11 +17,7 @@ import math
import sys
from absl.testing import parameterized
from mpmath import exp
from mpmath import inf
from mpmath import log
from mpmath import npdf
from mpmath import quad
import mpmath
import numpy as np
import tensorflow as tf
@ -38,21 +34,21 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def _log_float_mp(self, x):
# Convert multi-precision input to float log space.
if x >= sys.float_info.min:
return float(log(x))
return float(mpmath.log(x))
else:
return -np.inf
def _integral_mp(self, fn, bounds=(-inf, inf)):
integral, _ = quad(fn, bounds, error=True, maxdegree=8)
def _integral_mp(self, fn, bounds=(-mpmath.inf, mpmath.inf)):
integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
return integral
def _distributions_mp(self, sigma, q):
def _mu0(x):
return npdf(x, mu=0, sigma=sigma)
return mpmath.npdf(x, mu=0, sigma=sigma)
def _mu1(x):
return npdf(x, mu=1, sigma=sigma)
return mpmath.npdf(x, mu=1, sigma=sigma)
def _mu(x):
return (1 - q) * _mu0(x) + q * _mu1(x)
@ -61,7 +57,7 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def _mu1_over_mu0(self, x, sigma):
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
return exp((2 * x - 1) / (2 * sigma**2))
return mpmath.exp((2 * x - 1) / (2 * sigma**2))
def _mu_over_mu0(self, x, q, sigma):
return (1 - q) + q * self._mu1_over_mu0(x, sigma)