fix keras typo
PiperOrigin-RevId: 258434656
parent 973a1759aa
commit bb7956ed7e
1 changed file with 4 additions and 5 deletions
@@ -41,10 +41,10 @@ flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
 flags.DEFINE_float('noise_multiplier', 1.1,
                    'Ratio of the standard deviation to the clipping norm')
 flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
-flags.DEFINE_integer('batch_size', 256, 'Batch size')
+flags.DEFINE_integer('batch_size', 250, 'Batch size')
 flags.DEFINE_integer('epochs', 60, 'Number of epochs')
 flags.DEFINE_integer(
-    'microbatches', 256, 'Number of microbatches '
+    'microbatches', 250, 'Number of microbatches '
     '(must evenly divide batch_size)')
 flags.DEFINE_string('model_dir', None, 'Model directory')
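Context note: the microbatches flag's help text says it must evenly divide batch_size, and 250 (unlike 256) also divides the 60000-example MNIST training set evenly, which is presumably why both defaults move together. Below is a minimal sketch of how these flags and the divisibility check might be wired up with absl; the main body and the exact error message are illustrative assumptions, not part of this diff.

from absl import app
from absl import flags

flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer(
    'microbatches', 250, 'Number of microbatches '
    '(must evenly divide batch_size)')

FLAGS = flags.FLAGS


def main(unused_argv):
  # Mirror the constraint stated in the flag description before training.
  if FLAGS.batch_size % FLAGS.microbatches != 0:
    raise ValueError('Number of microbatches should divide batch_size evenly')


if __name__ == '__main__':
  app.run(main)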
@@ -121,9 +121,8 @@ def main(unused_argv):
   optimizer = DPGradientDescentGaussianOptimizer(
       l2_norm_clip=FLAGS.l2_norm_clip,
       noise_multiplier=FLAGS.noise_multiplier,
-      num_microbatches=FLAGS.num_microbatches,
-      learning_rate=FLAGS.learning_rate,
-      unroll_microbatches=True)
+      num_microbatches=FLAGS.microbatches,
+      learning_rate=FLAGS.learning_rate)
   # Compute vector of per-example loss rather than its mean over a minibatch.
   loss = tf.keras.losses.CategoricalCrossentropy(
       from_logits=True, reduction=tf.losses.Reduction.NONE)
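Context note: the fix passes FLAGS.microbatches (the flag actually defined above) instead of the undefined FLAGS.num_microbatches and drops the unroll_microbatches argument. Below is a hedged sketch of how the corrected optimizer and per-example loss typically feed into model.compile; the import path and the toy model are assumptions, since neither appears in this hunk.

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer import (
    DPGradientDescentGaussianOptimizer)

optimizer = DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,        # FLAGS.l2_norm_clip in the tutorial
    noise_multiplier=1.1,    # FLAGS.noise_multiplier
    num_microbatches=250,    # FLAGS.microbatches (the typo fixed above)
    learning_rate=0.15)      # FLAGS.learning_rate

# Vector (per-example) loss so the optimizer can clip and noise each
# microbatch gradient before averaging; Reduction.NONE enables that.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.losses.Reduction.NONE)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10)])  # placeholder architecture, not from the diff
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])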