diff --git a/tutorials/mnist_dpsgd_tutorial_keras.py b/tutorials/mnist_dpsgd_tutorial_keras.py index 927cb07..7e78a6d 100644 --- a/tutorials/mnist_dpsgd_tutorial_keras.py +++ b/tutorials/mnist_dpsgd_tutorial_keras.py @@ -22,19 +22,17 @@ from absl import flags from absl import logging import numpy as np -import tensorflow.compat.v1 as tf +import tensorflow as tf from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent -from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer - -GradientDescentOptimizer = tf.train.GradientDescentOptimizer +from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer flags.DEFINE_boolean( 'dpsgd', True, 'If True, train with DP-SGD. If False, ' 'train with vanilla SGD.') flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training') -flags.DEFINE_float('noise_multiplier', 1.1, +flags.DEFINE_float('noise_multiplier', 0.1, 'Ratio of the standard deviation to the clipping norm') flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') flags.DEFINE_integer('batch_size', 250, 'Batch size') @@ -70,8 +68,8 @@ def load_mnist(): train_data = np.array(train_data, dtype=np.float32) / 255 test_data = np.array(test_data, dtype=np.float32) / 255 - train_data = train_data.reshape(train_data.shape[0], 28, 28, 1) - test_data = test_data.reshape(test_data.shape[0], 28, 28, 1) + train_data = train_data.reshape((train_data.shape[0], 28, 28, 1)) + test_data = test_data.reshape((test_data.shape[0], 28, 28, 1)) train_labels = np.array(train_labels, dtype=np.int32) test_labels = np.array(test_labels, dtype=np.int32) @@ -114,7 +112,7 @@ def main(unused_argv): ]) if FLAGS.dpsgd: - optimizer = DPGradientDescentGaussianOptimizer( + optimizer = DPKerasSGDOptimizer( l2_norm_clip=FLAGS.l2_norm_clip, noise_multiplier=FLAGS.noise_multiplier, num_microbatches=FLAGS.microbatches, @@ -123,7 +121,7 @@ def main(unused_argv): loss = tf.keras.losses.CategoricalCrossentropy( from_logits=True, reduction=tf.losses.Reduction.NONE) else: - optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate) + optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate) loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) # Compile model with Keras