fix imports for v1 and make the versioning more explicit through LooseVersion

PiperOrigin-RevId: 249732562
This commit is contained in:
Nicolas Papernot 2019-05-23 15:56:48 -07:00 committed by A. Unique TensorFlower
parent 0efb23afcb
commit a06bc6c99b
4 changed files with 18 additions and 17 deletions

View file

@ -27,9 +27,9 @@ from privacy.dp_query import gaussian_query
def make_optimizer_class(cls):
"""Constructs a DP optimizer class from an existing one."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
parent_code = tf.train.Optimizer.compute_gradients.__code__
child_code = cls.compute_gradients.__code__
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
else:
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
@ -213,12 +213,11 @@ def make_gaussian_optimizer_class(cls):
return DPGaussianOptimizerClass
# Compatibility with tf 1 and 2 APIs
try:
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
except: # pylint: disable=bare-except
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
AdagradOptimizer = tf.optimizers.Adagrad
AdamOptimizer = tf.optimizers.Adam
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name

View file

@ -21,6 +21,8 @@ from __future__ import print_function
from absl import app
from absl import flags
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
@ -29,10 +31,9 @@ from privacy.analysis.rdp_accountant import compute_rdp_from_ledger
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer
# Compatibility with tf 1 and 2 APIs
try:
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
except: # pylint: disable=bare-except
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
FLAGS = flags.FLAGS

View file

@ -30,7 +30,7 @@ from privacy.dp_query.gaussian_query import GaussianAverageQuery
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
tf.compat.v1.enable_eager_execution()
else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name

View file

@ -20,6 +20,8 @@ from __future__ import print_function
from absl import app
from absl import flags
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
@ -28,10 +30,9 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.dp_query.gaussian_query import GaussianAverageQuery
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
# Compatibility with tf 1 and 2 APIs
try:
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
except: # pylint: disable=bare-except
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
flags.DEFINE_boolean(