Minor update to mnist_lr_tutorial.py to avoid (some) deprecated items.

PiperOrigin-RevId: 339327388
This commit is contained in:
Steve Chien 2020-10-27 14:15:49 -07:00 committed by A. Unique TensorFlower
parent 67f7f35383
commit f0daaf085f

View file

@ -30,8 +30,6 @@ import math
from absl import app from absl import app
from absl import flags from absl import flags
from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
@ -39,10 +37,7 @@ from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@ -62,14 +57,11 @@ flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
def lr_model_fn(features, labels, mode, nclasses, dim): def lr_model_fn(features, labels, mode, nclasses, dim):
"""Model function for logistic regression.""" """Model function for logistic regression."""
input_layer = tf.reshape(features['x'], tuple([-1]) + dim) input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
logits = tf.keras.layers.Dense(
logits = tf.layers.dense(
inputs=input_layer,
units=nclasses, units=nclasses,
kernel_regularizer=tf.contrib.layers.l2_regularizer( kernel_regularizer=tf.keras.regularizers.L2(l2=FLAGS.regularizer),
scale=FLAGS.regularizer), bias_regularizer=tf.keras.regularizers.L2(l2=FLAGS.regularizer)).apply(
bias_regularizer=tf.contrib.layers.l2_regularizer( input_layer)
scale=FLAGS.regularizer))
# Calculate loss as a vector (to support microbatches in DP-SGD). # Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(