fix keras typo
PiperOrigin-RevId: 258434656
This commit is contained in:
parent
973a1759aa
commit
bb7956ed7e
1 changed files with 4 additions and 5 deletions
|
@ -41,10 +41,10 @@ flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
|
||||||
flags.DEFINE_float('noise_multiplier', 1.1,
|
flags.DEFINE_float('noise_multiplier', 1.1,
|
||||||
'Ratio of the standard deviation to the clipping norm')
|
'Ratio of the standard deviation to the clipping norm')
|
||||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||||
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||||
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
'microbatches', 256, 'Number of microbatches '
|
'microbatches', 250, 'Number of microbatches '
|
||||||
'(must evenly divide batch_size)')
|
'(must evenly divide batch_size)')
|
||||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||||
|
|
||||||
|
@ -121,9 +121,8 @@ def main(unused_argv):
|
||||||
optimizer = DPGradientDescentGaussianOptimizer(
|
optimizer = DPGradientDescentGaussianOptimizer(
|
||||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||||
noise_multiplier=FLAGS.noise_multiplier,
|
noise_multiplier=FLAGS.noise_multiplier,
|
||||||
num_microbatches=FLAGS.num_microbatches,
|
num_microbatches=FLAGS.microbatches,
|
||||||
learning_rate=FLAGS.learning_rate,
|
learning_rate=FLAGS.learning_rate)
|
||||||
unroll_microbatches=True)
|
|
||||||
# Compute vector of per-example loss rather than its mean over a minibatch.
|
# Compute vector of per-example loss rather than its mean over a minibatch.
|
||||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.losses.Reduction.NONE)
|
from_logits=True, reduction=tf.losses.Reduction.NONE)
|
||||||
|
|
Loading…
Reference in a new issue