minor fixes to improve tf 1 and 2 compatibility
PiperOrigin-RevId: 246008822
This commit is contained in:
parent
febafd830d
commit
c09ec4c22b
3 changed files with 7 additions and 9 deletions
|
@ -31,7 +31,7 @@ def make_optimizer_class(cls):
|
|||
child_code = cls.compute_gradients.__code__
|
||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
else:
|
||||
parent_code = tf.optimizers.Optimizer.compute_gradients.__code__
|
||||
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
GATE_OP = None # pylint: disable=invalid-name
|
||||
if child_code is not parent_code:
|
||||
|
|
|
@ -47,7 +47,7 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory')
|
|||
FLAGS = tf.flags.FLAGS
|
||||
|
||||
|
||||
class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
|
||||
class EpsilonPrintingTrainingHook(tf.train.SessionRunHook):
|
||||
"""Training hook to print current value of epsilon after an epoch."""
|
||||
|
||||
def __init__(self, ledger):
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
|||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from distutils.version import LooseVersion
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -26,13 +27,11 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
|
|||
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||
|
||||
# Compatibility with tf 1 and 2 APIs
|
||||
try:
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
except: # pylint: disable=bare-except
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
||||
tf.enable_eager_execution()
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
||||
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
'train with vanilla SGD.')
|
||||
|
@ -133,8 +132,7 @@ def main(_):
|
|||
else:
|
||||
grads_and_vars = opt.compute_gradients(loss_fn, var_list)
|
||||
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
opt.apply_gradients(grads_and_vars, global_step=global_step)
|
||||
opt.apply_gradients(grads_and_vars)
|
||||
|
||||
# Evaluate the model and print results
|
||||
for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
|
||||
|
|
Loading…
Reference in a new issue