forked from 626_privacy/tensorflow_privacy
Modify loss passed to optimizer when dpsgd is False in MNIST tutorial
PiperOrigin-RevId: 229233829
This commit is contained in:
parent
89ca3f2a06
commit
4c1f3c07f4
1 changed files with 7 additions and 1 deletions
|
@ -79,11 +79,17 @@ def cnn_model_fn(features, labels, mode):
|
||||||
noise_multiplier=FLAGS.noise_multiplier,
|
noise_multiplier=FLAGS.noise_multiplier,
|
||||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||||
num_microbatches=FLAGS.microbatches)
|
num_microbatches=FLAGS.microbatches)
|
||||||
|
opt_loss = vector_loss
|
||||||
else:
|
else:
|
||||||
optimizer = tf.train.GradientDescentOptimizer(
|
optimizer = tf.train.GradientDescentOptimizer(
|
||||||
learning_rate=FLAGS.learning_rate)
|
learning_rate=FLAGS.learning_rate)
|
||||||
|
opt_loss = scalar_loss
|
||||||
global_step = tf.train.get_global_step()
|
global_step = tf.train.get_global_step()
|
||||||
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
|
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||||
|
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
||||||
|
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
||||||
|
# used for evaluation and debugging by tf.estimator. The actual loss being
|
||||||
|
# minimized is opt_loss defined above and passed to optimizer.minimize().
|
||||||
return tf.estimator.EstimatorSpec(mode=mode,
|
return tf.estimator.EstimatorSpec(mode=mode,
|
||||||
loss=scalar_loss,
|
loss=scalar_loss,
|
||||||
train_op=train_op)
|
train_op=train_op)
|
||||||
|
|
Loading…
Reference in a new issue