From c09ec4c22bbb4f590c99e47d190c126fbc70ff28 Mon Sep 17 00:00:00 2001 From: Nicolas Papernot Date: Tue, 30 Apr 2019 13:22:35 -0700 Subject: [PATCH] minor fixes to improve tf 1 and 2 compatibility PiperOrigin-RevId: 246008822 --- privacy/optimizers/dp_optimizer.py | 2 +- tutorials/mnist_dpsgd_tutorial.py | 2 +- tutorials/mnist_dpsgd_tutorial_eager.py | 12 +++++------- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/privacy/optimizers/dp_optimizer.py b/privacy/optimizers/dp_optimizer.py index 64a16c1..19d49ff 100644 --- a/privacy/optimizers/dp_optimizer.py +++ b/privacy/optimizers/dp_optimizer.py @@ -31,7 +31,7 @@ def make_optimizer_class(cls): child_code = cls.compute_gradients.__code__ GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name else: - parent_code = tf.optimizers.Optimizer.compute_gradients.__code__ + parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access GATE_OP = None # pylint: disable=invalid-name if child_code is not parent_code: diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py index 7ae12c1..951b1c7 100644 --- a/tutorials/mnist_dpsgd_tutorial.py +++ b/tutorials/mnist_dpsgd_tutorial.py @@ -47,7 +47,7 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory') FLAGS = tf.flags.FLAGS -class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook): +class EpsilonPrintingTrainingHook(tf.train.SessionRunHook): """Training hook to print current value of epsilon after an epoch.""" def __init__(self, ledger): diff --git a/tutorials/mnist_dpsgd_tutorial_eager.py b/tutorials/mnist_dpsgd_tutorial_eager.py index 356b276..bfd23be 100644 --- a/tutorials/mnist_dpsgd_tutorial_eager.py +++ b/tutorials/mnist_dpsgd_tutorial_eager.py @@ -18,6 +18,7 @@ from __future__ import print_function from absl import app from absl import flags +from distutils.version import LooseVersion import numpy as np import tensorflow as tf @@ -26,14 +27,12 @@ from privacy.analysis.rdp_accountant import get_privacy_spent from privacy.dp_query.gaussian_query import GaussianAverageQuery from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer -# Compatibility with tf 1 and 2 APIs -try: +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): GradientDescentOptimizer = tf.train.GradientDescentOptimizer -except: # pylint: disable=bare-except + tf.enable_eager_execution() +else: GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name -tf.enable_eager_execution() - flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, ' 'train with vanilla SGD.') flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training') @@ -133,8 +132,7 @@ def main(_): else: grads_and_vars = opt.compute_gradients(loss_fn, var_list) - global_step = tf.train.get_or_create_global_step() - opt.apply_gradients(grads_and_vars, global_step=global_step) + opt.apply_gradients(grads_and_vars) # Evaluate the model and print results for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):