fix imports for v1 and make the versioning more explicit through LooseVersion

PiperOrigin-RevId: 249732562
This commit is contained in:
Nicolas Papernot 2019-05-23 15:56:48 -07:00 committed by A. Unique TensorFlower
parent 0efb23afcb
commit a06bc6c99b
4 changed files with 18 additions and 17 deletions

View file

@ -27,9 +27,9 @@ from privacy.dp_query import gaussian_query
def make_optimizer_class(cls): def make_optimizer_class(cls):
"""Constructs a DP optimizer class from an existing one.""" """Constructs a DP optimizer class from an existing one."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__ parent_code = tf.train.Optimizer.compute_gradients.__code__
child_code = cls.compute_gradients.__code__ child_code = cls.compute_gradients.__code__
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
else: else:
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
@ -213,12 +213,11 @@ def make_gaussian_optimizer_class(cls):
return DPGaussianOptimizerClass return DPGaussianOptimizerClass
# Compatibility with tf 1 and 2 APIs if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
try: AdagradOptimizer = tf.train.AdagradOptimizer
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer AdamOptimizer = tf.train.AdamOptimizer
AdamOptimizer = tf.compat.v1.train.AdamOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer else:
except: # pylint: disable=bare-except
AdagradOptimizer = tf.optimizers.Adagrad AdagradOptimizer = tf.optimizers.Adagrad
AdamOptimizer = tf.optimizers.Adam AdamOptimizer = tf.optimizers.Adam
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name

View file

@ -21,6 +21,8 @@ from __future__ import print_function
from absl import app from absl import app
from absl import flags from absl import flags
from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -29,10 +31,9 @@ from privacy.analysis.rdp_accountant import compute_rdp_from_ledger
from privacy.analysis.rdp_accountant import get_privacy_spent from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from privacy.optimizers import dp_optimizer
# Compatibility with tf 1 and 2 APIs if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
try: GradientDescentOptimizer = tf.train.GradientDescentOptimizer
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer else:
except: # pylint: disable=bare-except
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
FLAGS = flags.FLAGS FLAGS = flags.FLAGS

View file

@ -30,7 +30,7 @@ from privacy.dp_query.gaussian_query import GaussianAverageQuery
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
tf.compat.v1.enable_eager_execution() tf.compat.v1.enable_eager_execution()
else: else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name

View file

@ -20,6 +20,8 @@ from __future__ import print_function
from absl import app from absl import app
from absl import flags from absl import flags
from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -28,10 +30,9 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.dp_query.gaussian_query import GaussianAverageQuery from privacy.dp_query.gaussian_query import GaussianAverageQuery
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
# Compatibility with tf 1 and 2 APIs if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
try: GradientDescentOptimizer = tf.train.GradientDescentOptimizer
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer else:
except: # pylint: disable=bare-except
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
flags.DEFINE_boolean( flags.DEFINE_boolean(