forked from 626_privacy/tensorflow_privacy
minor fixes to improve tf 1 and 2 compatibility
PiperOrigin-RevId: 246008822
This commit is contained in:
parent
febafd830d
commit
c09ec4c22b
3 changed files with 7 additions and 9 deletions
|
@ -31,7 +31,7 @@ def make_optimizer_class(cls):
|
||||||
child_code = cls.compute_gradients.__code__
|
child_code = cls.compute_gradients.__code__
|
||||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||||
else:
|
else:
|
||||||
parent_code = tf.optimizers.Optimizer.compute_gradients.__code__
|
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||||
GATE_OP = None # pylint: disable=invalid-name
|
GATE_OP = None # pylint: disable=invalid-name
|
||||||
if child_code is not parent_code:
|
if child_code is not parent_code:
|
||||||
|
|
|
@ -47,7 +47,7 @@ tf.flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||||
FLAGS = tf.flags.FLAGS
|
FLAGS = tf.flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
|
class EpsilonPrintingTrainingHook(tf.train.SessionRunHook):
|
||||||
"""Training hook to print current value of epsilon after an epoch."""
|
"""Training hook to print current value of epsilon after an epoch."""
|
||||||
|
|
||||||
def __init__(self, ledger):
|
def __init__(self, ledger):
|
||||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
from distutils.version import LooseVersion
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -26,13 +27,11 @@ from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||||
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
from privacy.dp_query.gaussian_query import GaussianAverageQuery
|
||||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||||
|
|
||||||
# Compatibility with tf 1 and 2 APIs
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
try:
|
|
||||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
except: # pylint: disable=bare-except
|
|
||||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
else:
|
||||||
|
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||||
|
|
||||||
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
|
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||||
'train with vanilla SGD.')
|
'train with vanilla SGD.')
|
||||||
|
@ -133,8 +132,7 @@ def main(_):
|
||||||
else:
|
else:
|
||||||
grads_and_vars = opt.compute_gradients(loss_fn, var_list)
|
grads_and_vars = opt.compute_gradients(loss_fn, var_list)
|
||||||
|
|
||||||
global_step = tf.train.get_or_create_global_step()
|
opt.apply_gradients(grads_and_vars)
|
||||||
opt.apply_gradients(grads_and_vars, global_step=global_step)
|
|
||||||
|
|
||||||
# Evaluate the model and print results
|
# Evaluate the model and print results
|
||||||
for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
|
for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
|
||||||
|
|
Loading…
Reference in a new issue