Normalize mpmath imports in TensorFlow Privacy to be more friendly with strict dependencies and lint.

PiperOrigin-RevId: 424681602
This commit is contained in:
Michael Reneer 2022-01-27 12:34:57 -08:00 committed by A. Unique TensorFlower
parent 81a11eb824
commit c36ce6d799

View file

@ -17,11 +17,7 @@ import math
import sys import sys
from absl.testing import parameterized from absl.testing import parameterized
from mpmath import exp import mpmath
from mpmath import inf
from mpmath import log
from mpmath import npdf
from mpmath import quad
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -38,21 +34,21 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def _log_float_mp(self, x): def _log_float_mp(self, x):
# Convert multi-precision input to float log space. # Convert multi-precision input to float log space.
if x >= sys.float_info.min: if x >= sys.float_info.min:
return float(log(x)) return float(mpmath.log(x))
else: else:
return -np.inf return -np.inf
def _integral_mp(self, fn, bounds=(-inf, inf)): def _integral_mp(self, fn, bounds=(-mpmath.inf, mpmath.inf)):
integral, _ = quad(fn, bounds, error=True, maxdegree=8) integral, _ = mpmath.quad(fn, bounds, error=True, maxdegree=8)
return integral return integral
def _distributions_mp(self, sigma, q): def _distributions_mp(self, sigma, q):
def _mu0(x): def _mu0(x):
return npdf(x, mu=0, sigma=sigma) return mpmath.npdf(x, mu=0, sigma=sigma)
def _mu1(x): def _mu1(x):
return npdf(x, mu=1, sigma=sigma) return mpmath.npdf(x, mu=1, sigma=sigma)
def _mu(x): def _mu(x):
return (1 - q) * _mu0(x) + q * _mu1(x) return (1 - q) * _mu0(x) + q * _mu1(x)
@ -61,7 +57,7 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def _mu1_over_mu0(self, x, sigma): def _mu1_over_mu0(self, x, sigma):
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x. # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
return exp((2 * x - 1) / (2 * sigma**2)) return mpmath.exp((2 * x - 1) / (2 * sigma**2))
def _mu_over_mu0(self, x, q, sigma): def _mu_over_mu0(self, x, q, sigma):
return (1 - q) + q * self._mu1_over_mu0(x, sigma) return (1 - q) + q * self._mu1_over_mu0(x, sigma)