Minor fix regarding tf.compat.v1 in mnist_dpsgd_tutorial_vectorized.

PiperOrigin-RevId: 303379200
This commit is contained in:
Steve Chien 2020-03-27 12:07:36 -07:00 committed by A. Unique TensorFlower
parent 0c2747462f
commit 7647c54a27

View file

@ -106,7 +106,7 @@ def cnn_model_fn(features, labels, mode):
else:
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
opt_loss = scalar_loss
global_step = tf.compat.get_global_step()
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
# In the following, we pass the mean of the loss (scalar_loss) rather than
# the vector_loss because tf.estimator requires a scalar loss. This is only