From a06bc6c99b89392d2e1809040a40937cfadc8765 Mon Sep 17 00:00:00 2001 From: Nicolas Papernot Date: Thu, 23 May 2019 15:56:48 -0700 Subject: [PATCH] fix imports for v1 and make the versioning more explicit through LooseVersion PiperOrigin-RevId: 249732562 --- privacy/optimizers/dp_optimizer.py | 15 +++++++-------- tutorials/mnist_dpsgd_tutorial.py | 9 +++++---- tutorials/mnist_dpsgd_tutorial_eager.py | 2 +- tutorials/mnist_dpsgd_tutorial_keras.py | 9 +++++---- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 2b27191..9a36767 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -27,9 +27,9 @@ from privacy.dp_query import gaussian_query def make_optimizer_class(cls): """Constructs a DP optimizer class from an existing one.""" if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): - parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__ + parent_code = tf.train.Optimizer.compute_gradients.__code__ child_code = cls.compute_gradients.__code__ - GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name + GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name else: parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access @@ -213,12 +213,11 @@ def make_gaussian_optimizer_class(cls): return DPGaussianOptimizerClass -# Compatibility with tf 1 and 2 APIs -try: - AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer - AdamOptimizer = tf.compat.v1.train.AdamOptimizer - GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer -except: # pylint: disable=bare-except +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + AdagradOptimizer = tf.train.AdagradOptimizer + AdamOptimizer = tf.train.AdamOptimizer + GradientDescentOptimizer = tf.train.GradientDescentOptimizer +else: AdagradOptimizer = tf.optimizers.Adagrad AdamOptimizer = tf.optimizers.Adam GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py index 7f83672..f586445 100644 --- a/tutorials/mnist_dpsgd_tutorial.py +++ b/tutorials/mnist_dpsgd_tutorial.py @@ -21,6 +21,8 @@ from __future__ import print_function from absl import app from absl import flags +from distutils.version import LooseVersion + import numpy as np import tensorflow as tf @@ -29,10 +31,9 @@ from privacy.analysis.rdp_accountant import compute_rdp_from_ledger from privacy.analysis.rdp_accountant import get_privacy_spent from privacy.optimizers import dp_optimizer -# Compatibility with tf 1 and 2 APIs -try: - GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer -except: # pylint: disable=bare-except +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + GradientDescentOptimizer = tf.train.GradientDescentOptimizer +else: GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name FLAGS = flags.FLAGS diff --git a/tutorials/mnist_dpsgd_tutorial_eager.py b/tutorials/mnist_dpsgd_tutorial_eager.py index 1a7b111..92d55b9 100644 --- a/tutorials/mnist_dpsgd_tutorial_eager.py +++ b/tutorials/mnist_dpsgd_tutorial_eager.py @@ -30,7 +30,7 @@ from privacy.dp_query.gaussian_query import GaussianAverageQuery from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): - GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer + GradientDescentOptimizer = tf.train.GradientDescentOptimizer tf.compat.v1.enable_eager_execution() else: GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name diff --git a/tutorials/mnist_dpsgd_tutorial_keras.py b/tutorials/mnist_dpsgd_tutorial_keras.py index f05c709..acf6dc6 100644 --- a/tutorials/mnist_dpsgd_tutorial_keras.py +++ b/tutorials/mnist_dpsgd_tutorial_keras.py @@ -20,6 +20,8 @@ from __future__ import print_function from absl import app from absl import flags +from distutils.version import LooseVersion + import numpy as np import tensorflow as tf @@ -28,10 +30,9 @@ from privacy.analysis.rdp_accountant import get_privacy_spent from privacy.dp_query.gaussian_query import GaussianAverageQuery from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer -# Compatibility with tf 1 and 2 APIs -try: - GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer -except: # pylint: disable=bare-except +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + GradientDescentOptimizer = tf.train.GradientDescentOptimizer +else: GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name flags.DEFINE_boolean(