fix imports for v1 and make the versioning more explicit through LooseVersion
PiperOrigin-RevId: 249732562
This commit is contained in:
parent
0efb23afcb
commit
a06bc6c99b
4 changed files with 18 additions and 17 deletions
|
@ -27,9 +27,9 @@ from privacy.dp_query import gaussian_query
|
||||||
def make_optimizer_class(cls):
|
def make_optimizer_class(cls):
|
||||||
"""Constructs a DP optimizer class from an existing one."""
|
"""Constructs a DP optimizer class from an existing one."""
|
||||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
|
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
||||||
child_code = cls.compute_gradients.__code__
|
child_code = cls.compute_gradients.__code__
|
||||||
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||||
else:
|
else:
|
||||||
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||||
|
@ -213,12 +213,11 @@ def make_gaussian_optimizer_class(cls):
|
||||||
|
|
||||||
return DPGaussianOptimizerClass
|
return DPGaussianOptimizerClass
|
||||||
|
|
||||||
# Compatibility with tf 1 and 2 APIs
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
try:
|
AdagradOptimizer = tf.train.AdagradOptimizer
|
||||||
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
|
AdamOptimizer = tf.train.AdamOptimizer
|
||||||
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
else:
|
||||||
except: # pylint: disable=bare-except
|
|
||||||
AdagradOptimizer = tf.optimizers.Adagrad
|
AdagradOptimizer = tf.optimizers.Adagrad
|
||||||
AdamOptimizer = tf.optimizers.Adam
|
AdamOptimizer = tf.optimizers.Adam
|
||||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||||
|
|
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -29,10 +31,9 @@ from privacy.analysis.rdp_accountant import compute_rdp_from_ledger
|
||||||
from privacy.analysis.rdp_accountant import get_privacy_spent
|
from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||||
from privacy.optimizers import dp_optimizer
|
from privacy.optimizers import dp_optimizer
|
||||||
|
|
||||||
# Compatibility with tf 1 and 2 APIs
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
try:
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
else:
|
||||||
except: # pylint: disable=bare-except
|
|
||||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
|
@ -30,7 +30,7 @@ from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||||
|
|
||||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
tf.compat.v1.enable_eager_execution()
|
tf.compat.v1.enable_eager_execution()
|
||||||
else:
|
else:
|
||||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||||
|
|
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -28,10 +30,9 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||||
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||||
|
|
||||||
# Compatibility with tf 1 and 2 APIs
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
try:
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
else:
|
||||||
except: # pylint: disable=bare-except
|
|
||||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean(
|
||||||
|
|
Loading…
Reference in a new issue