minor fixes to improve tf 1 and 2 compatibility

PiperOrigin-RevId: 246008822
This commit is contained in:
Nicolas Papernot 2019-04-30 13:22:35 -07:00 committed by A. Unique TensorFlower
parent febafd830d
commit c09ec4c22b
3 changed files with 7 additions and 9 deletions

View file

@@ -31,7 +31,7 @@ def make_optimizer_class(cls):
     child_code = cls.compute_gradients.__code__
     GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
   else:
-    parent_code = tf.optimizers.Optimizer.compute_gradients.__code__
+    parent_code = tf.optimizers.Optimizer._compute_gradients.__code__  # pylint: disable=protected-access
     child_code = cls._compute_gradients.__code__  # pylint: disable=protected-access
     GATE_OP = None  # pylint: disable=invalid-name
   if child_code is not parent_code:

View file

@@ -47,7 +47,7 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory')
 
 FLAGS = tf.flags.FLAGS
 
 
-class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
+class EpsilonPrintingTrainingHook(tf.train.SessionRunHook):
   """Training hook to print current value of epsilon after an epoch."""
 
   def __init__(self, ledger):

View file

@@ -18,6 +18,7 @@ from __future__ import print_function
 
 from absl import app
 from absl import flags
+from distutils.version import LooseVersion
 import numpy as np
 import tensorflow as tf
@@ -26,13 +27,11 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
 from privacy.dp_query.gaussian_query import GaussianAverageQuery
 from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
 
-# Compatibility with tf 1 and 2 APIs
-try:
-  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
-except:  # pylint: disable=bare-except
-  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
-
-tf.enable_eager_execution()
+if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
+  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
+  tf.enable_eager_execution()
+else:
+  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
 
 flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
                      'train with vanilla SGD.')
@@ -133,8 +132,7 @@ def main(_):
   else:
     grads_and_vars = opt.compute_gradients(loss_fn, var_list)
 
-  global_step = tf.train.get_or_create_global_step()
-  opt.apply_gradients(grads_and_vars, global_step=global_step)
+  opt.apply_gradients(grads_and_vars)
 
   # Evaluate the model and print results
   for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):