forked from 626_privacy/tensorflow_privacy
fix imports for v1 and make the versioning more explicit through LooseVersion
PiperOrigin-RevId: 249732562
This commit is contained in:
parent
0efb23afcb
commit
a06bc6c99b
4 changed files with 18 additions and 17 deletions
|
@ -27,9 +27,9 @@ from privacy.dp_query import gaussian_query
|
|||
def make_optimizer_class(cls):
|
||||
"""Constructs a DP optimizer class from an existing one."""
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
|
||||
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
||||
child_code = cls.compute_gradients.__code__
|
||||
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
else:
|
||||
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
|
@ -213,12 +213,11 @@ def make_gaussian_optimizer_class(cls):
|
|||
|
||||
return DPGaussianOptimizerClass
|
||||
|
||||
# Compatibility with tf 1 and 2 APIs
|
||||
try:
|
||||
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
except: # pylint: disable=bare-except
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
AdagradOptimizer = tf.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
AdagradOptimizer = tf.optimizers.Adagrad
|
||||
AdamOptimizer = tf.optimizers.Adam
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
|
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -29,10 +31,9 @@ from privacy.analysis.rdp_accountant import compute_rdp_from_ledger
|
|||
from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from privacy.optimizers import dp_optimizer
|
||||
|
||||
# Compatibility with tf 1 and 2 APIs
|
||||
try:
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
except: # pylint: disable=bare-except
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
|
|
@ -30,7 +30,7 @@ from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
|||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
|
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -28,10 +30,9 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
|
|||
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||
|
||||
# Compatibility with tf 1 and 2 APIs
|
||||
try:
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
except: # pylint: disable=bare-except
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
|
|
Loading…
Reference in a new issue