Minor fixes to improve TF 1 and TF 2 compatibility

PiperOrigin-RevId: 246008822
Nicolas Papernot 2019-04-30 13:22:35 -07:00 committed by A. Unique TensorFlower
parent febafd830d
commit c09ec4c22b
3 changed files with 7 additions and 9 deletions


@@ -31,7 +31,7 @@ def make_optimizer_class(cls):
     child_code = cls.compute_gradients.__code__
     GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
   else:
-    parent_code = tf.optimizers.Optimizer.compute_gradients.__code__
+    parent_code = tf.optimizers.Optimizer._compute_gradients.__code__  # pylint: disable=protected-access
     child_code = cls._compute_gradients.__code__  # pylint: disable=protected-access
     GATE_OP = None  # pylint: disable=invalid-name
   if child_code is not parent_code:
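A note on the pattern in this hunk: the wrapper decides whether the class being wrapped truly overrides the parent's gradient computation by comparing function __code__ objects. A minimal, TensorFlow-free sketch of that idiom (all class names here are hypothetical):

class Base(object):
  def compute_gradients(self):
    return 'base'

class Overriding(Base):
  def compute_gradients(self):
    return 'overriding'

class Inheriting(Base):
  pass  # keeps the parent implementation

# Two methods share a __code__ object only when they are literally the
# same function, so this distinguishes a real override from inheritance.
parent_code = Base.compute_gradients.__code__
print(Overriding.compute_gradients.__code__ is not parent_code)  # True
print(Inheriting.compute_gradients.__code__ is not parent_code)  # False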


@@ -47,7 +47,7 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory')
 FLAGS = tf.flags.FLAGS
-class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
+class EpsilonPrintingTrainingHook(tf.train.SessionRunHook):
   """Training hook to print current value of epsilon after an epoch."""
   def __init__(self, ledger):
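For context on the base-class swap above: a SessionRunHook receives callbacks around each session.run call during Estimator training, which is how EpsilonPrintingTrainingHook can report epsilon after an epoch. A sketch of a minimal hook under the tf.train spelling used after this change; the hook name and its step-counting behavior are hypothetical, assuming a TF 1.x runtime:

import tensorflow as tf

class StepCountingHook(tf.train.SessionRunHook):
  """Hypothetical hook; the real one queries a privacy ledger instead."""

  def __init__(self):
    self._steps = 0

  def after_run(self, run_context, run_values):
    # Called once after every session.run during training.
    self._steps += 1

  def end(self, session):
    print('Observed %d training steps.' % self._steps)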


@@ -18,6 +18,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
+from distutils.version import LooseVersion
 import numpy as np
 import tensorflow as tf
@@ -26,14 +27,12 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
 from privacy.dp_query.gaussian_query import GaussianAverageQuery
 from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
 # Compatibility with tf 1 and 2 APIs
-try:
+if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
   GradientDescentOptimizer = tf.train.GradientDescentOptimizer
-except:  # pylint: disable=bare-except
+  tf.enable_eager_execution()
+else:
   GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
-tf.enable_eager_execution()
 flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
                      'train with vanilla SGD.')
 flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
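Read together, the changes in this hunk replace the bare try/except probe, which could have swallowed unrelated errors, with an explicit version gate, and move tf.enable_eager_execution() under the TF 1.x branch (TF 2 is eager by default). Assembling the lines as they stand after the change gives this standalone sketch:

from distutils.version import LooseVersion

import tensorflow as tf

if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
  tf.enable_eager_execution()  # eager mode is opt-in under TF 1.x
else:
  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name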
@@ -133,8 +132,7 @@ def main(_):
       else:
         grads_and_vars = opt.compute_gradients(loss_fn, var_list)
-      global_step = tf.train.get_or_create_global_step()
-      opt.apply_gradients(grads_and_vars, global_step=global_step)
+      opt.apply_gradients(grads_and_vars)
     # Evaluate the model and print results
     for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
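The final hunk drops the global_step bookkeeping: tf.train optimizers treat global_step as an optional argument, while the TF 2 tf.optimizers classes do not accept it at all, so the plain apply_gradients(grads_and_vars) call is the one form valid under both. A hedged sketch of an eager step built around that call; train_step, loss_fn, and var_list are illustrative stand-ins, and tf.GradientTape is used here in place of the tutorial's compute_gradients:

import tensorflow as tf

def train_step(opt, loss_fn, var_list):
  # Record the forward pass so gradients can be taken eagerly.
  with tf.GradientTape() as tape:
    loss = loss_fn()
  grads = tape.gradient(loss, var_list)
  # No global_step: valid for tf.train optimizers (keyword optional)
  # and for tf.optimizers ones (keyword absent).
  opt.apply_gradients(list(zip(grads, var_list)))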