diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py index 961b334..9afe8df 100644 --- a/tutorials/mnist_dpsgd_tutorial.py +++ b/tutorials/mnist_dpsgd_tutorial.py @@ -79,11 +79,17 @@ def cnn_model_fn(features, labels, mode): noise_multiplier=FLAGS.noise_multiplier, l2_norm_clip=FLAGS.l2_norm_clip, num_microbatches=FLAGS.microbatches) + opt_loss = vector_loss else: optimizer = tf.train.GradientDescentOptimizer( learning_rate=FLAGS.learning_rate) + opt_loss = scalar_loss global_step = tf.train.get_global_step() - train_op = optimizer.minimize(loss=vector_loss, global_step=global_step) + train_op = optimizer.minimize(loss=opt_loss, global_step=global_step) + # In the following, we pass the mean of the loss (scalar_loss) rather than + # the vector_loss because tf.estimator requires a scalar loss. This is only + # used for evaluation and debugging by tf.estimator. The actual loss being + # minimized is opt_loss defined above and passed to optimizer.minimize(). return tf.estimator.EstimatorSpec(mode=mode, loss=scalar_loss, train_op=train_op)