diff --git a/tutorials/mnist_dpsgd_tutorial_keras_model.py b/tutorials/mnist_dpsgd_tutorial_keras_model.py index abd3683..b000c03 100644 --- a/tutorials/mnist_dpsgd_tutorial_keras_model.py +++ b/tutorials/mnist_dpsgd_tutorial_keras_model.py @@ -113,7 +113,9 @@ def main(unused_argv): model = DPSequential( l2_norm_clip=FLAGS.l2_norm_clip, noise_multiplier=FLAGS.noise_multiplier, - layers=layers) + num_microbatches=FLAGS.microbatches, + layers=layers, + ) else: model = tf.keras.Sequential(layers=layers)