From 45bcb3a0e4d5ed7759af348725e7d031e0e22b96 Mon Sep 17 00:00:00 2001
From: Ilya Mironov
Date: Mon, 24 Jun 2019 12:49:48 -0700
Subject: [PATCH] Adding privacy analysis to the Logistic Regression for MNIST tutorial.

PiperOrigin-RevId: 254815428
---
 tutorials/README.md                           |   4 +
 ...gression_mnist.py => mnist_lr_tutorial.py} | 145 +++++++++++-------
 2 files changed, 92 insertions(+), 57 deletions(-)
 rename tutorials/{logistic_regression_mnist.py => mnist_lr_tutorial.py} (59%)

diff --git a/tutorials/README.md b/tutorials/README.md
index 94b5cef..d3f60d3 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -20,6 +20,10 @@ Here is a list of all the tutorials included:
 * `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on
   MNIST with differential privacy using tf.Keras.
 
+* `mnist_lr_tutorial.py`: learn a differentially private logistic regression
+  model on MNIST. The model illustrates application of the
+  "amplification-by-iteration" analysis (https://arxiv.org/abs/1808.06651).
+
 The rest of this README describes the different parameters used to configure
 DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial.
 
diff --git a/tutorials/logistic_regression_mnist.py b/tutorials/mnist_lr_tutorial.py
similarity index 59%
rename from tutorials/logistic_regression_mnist.py
rename to tutorials/mnist_lr_tutorial.py
index 694ee7d..62f446d 100644
--- a/tutorials/logistic_regression_mnist.py
+++ b/tutorials/mnist_lr_tutorial.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """DP Logistic Regression on MNIST.
 
 DP Logistic Regression on MNIST with support for privacy-by-iteration analysis.
-Feldman, Vitaly, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
+Vitaly Feldman, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
 "Privacy amplification by iteration."
 In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS),
 pp. 521-532. IEEE, 2018.
@@ -36,6 +35,8 @@ from distutils.version import LooseVersion
 import numpy as np
 import tensorflow as tf
 
+from privacy.analysis.rdp_accountant import compute_rdp
+from privacy.analysis.rdp_accountant import get_privacy_spent
 from privacy.optimizers import dp_optimizer
 
 if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
@@ -45,32 +46,30 @@ else:
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
-                     'train with vanilla SGD.')
+flags.DEFINE_boolean(
+    'dpsgd', True, 'If True, train with DP-SGD. If False, '
+    'train with vanilla SGD.')
 flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training')
-flags.DEFINE_float('noise_multiplier', 0.02,
+flags.DEFINE_float('noise_multiplier', 0.05,
                    'Ratio of the standard deviation to the clipping norm')
-flags.DEFINE_integer('batch_size', 1, 'Batch size')
+flags.DEFINE_integer('batch_size', 5, 'Batch size')
 flags.DEFINE_integer('epochs', 5, 'Number of epochs')
-flags.DEFINE_integer('microbatches', 1, 'Number of microbatches '
-                     '(must evenly divide batch_size)')
 flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient')
 flags.DEFINE_string('model_dir', None, 'Model directory')
-flags.DEFINE_float('data_l2_norm', 8,
-                   'Bound on the L2 norm of normalized data.')
+flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
 
 
 def lr_model_fn(features, labels, mode, nclasses, dim):
   """Model function for logistic regression."""
   input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
 
-  logits = tf.layers.dense(inputs=input_layer,
-                           units=nclasses,
-                           kernel_regularizer=tf.contrib.layers.l2_regularizer(
-                               scale=FLAGS.regularizer),
-                           bias_regularizer=tf.contrib.layers.l2_regularizer(
-                               scale=FLAGS.regularizer)
-                           )
+  logits = tf.layers.dense(
+      inputs=input_layer,
+      units=nclasses,
+      kernel_regularizer=tf.contrib.layers.l2_regularizer(
+          scale=FLAGS.regularizer),
+      bias_regularizer=tf.contrib.layers.l2_regularizer(
+          scale=FLAGS.regularizer))
 
   # Calculate loss as a vector (to support microbatches in DP-SGD).
   vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -80,18 +79,15 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
 
   # Configure the training op (for TRAIN mode).
   if mode == tf.estimator.ModeKeys.TRAIN:
-
     if FLAGS.dpsgd:
-      # Use DP version of GradientDescentOptimizer. Other optimizers are
-      # available in dp_optimizer. Most optimizers inheriting from
-      # tf.train.Optimizer should be wrappable in differentially private
-      # counterparts by calling dp_optimizer.optimizer_from_args().
       # The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where
       # ||x|| is the norm of the data.
+      # We don't use microbatches (thus speeding up computation), since no
+      # clipping is necessary due to data normalization.
      optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
-          l2_norm_clip=math.sqrt(2*(FLAGS.data_l2_norm**2 + 1)),
+          l2_norm_clip=math.sqrt(2 * (FLAGS.data_l2_norm**2 + 1)),
           noise_multiplier=FLAGS.noise_multiplier,
-          num_microbatches=FLAGS.microbatches,
+          num_microbatches=1,
           learning_rate=FLAGS.learning_rate)
       opt_loss = vector_loss
     else:
@@ -103,21 +99,18 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
     # the vector_loss because tf.estimator requires a scalar loss. This is only
     # used for evaluation and debugging by tf.estimator. The actual loss being
     # minimized is opt_loss defined above and passed to optimizer.minimize().
-    return tf.estimator.EstimatorSpec(mode=mode,
-                                      loss=scalar_loss,
-                                      train_op=train_op)
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=scalar_loss, train_op=train_op)
 
   # Add evaluation metrics (for EVAL mode).
   elif mode == tf.estimator.ModeKeys.EVAL:
     eval_metric_ops = {
         'accuracy':
             tf.metrics.accuracy(
-                labels=labels,
-                predictions=tf.argmax(input=logits, axis=1))
+                labels=labels, predictions=tf.argmax(input=logits, axis=1))
     }
-    return tf.estimator.EstimatorSpec(mode=mode,
-                                      loss=scalar_loss,
-                                      eval_metric_ops=eval_metric_ops)
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
 
 
 def normalize_data(data, data_l2_norm):
@@ -146,7 +139,7 @@ def load_mnist(data_l2_norm=float('inf')):
   train_data = train_data.reshape(train_data.shape[0], -1)
   test_data = test_data.reshape(test_data.shape[0], -1)
 
-  idx = np.random.permutation(len(train_data)) # shuffle data once
+  idx = np.random.permutation(len(train_data))  # shuffle data once
   train_data = train_data[idx]
   train_labels = train_labels[idx]
 
@@ -159,14 +152,50 @@ def load_mnist(data_l2_norm=float('inf')):
   return train_data, train_labels, test_data, test_labels
 
 
+def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
+  """Tabulating position-dependent privacy guarantees."""
+  if noise_multiplier == 0:
+    print('No differential privacy (additive noise is 0).')
+    return
+
+  print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) '
+        'the training procedure results in the following privacy guarantees.')
+
+  print('Out of the total of {} samples:'.format(samples))
+
+  steps_per_epoch = samples // batch_size
+  orders = np.concatenate(
+      [np.linspace(2, 20, num=181),
+       np.linspace(20, 100, num=81)])
+  delta = 1e-5
+  for p in (.5, .9, .99):
+    steps = math.ceil(steps_per_epoch * p)  # Steps in the last epoch.
+    coef = 2 * (noise_multiplier * batch_size)**-2 * (
+        # Accounting for privacy loss
+        (epochs - 1) / steps_per_epoch +  # ... from all-but-last epochs
+        1 / (steps_per_epoch - steps + 1))  # ... due to the last epoch
+    # Using the RDP accountant to compute eps. Doing the computation
+    # analytically is also an option.
+    rdp = [order * coef for order in orders]
+    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
+    print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
+        p * 100, eps, delta))
+
+  # Compute privacy guarantees for the Sampled Gaussian Mechanism.
+  rdp_sgm = compute_rdp(batch_size / samples, noise_multiplier,
+                        epochs * steps_per_epoch, orders)
+  eps_sgm, _, _ = get_privacy_spent(orders, rdp_sgm, target_delta=delta)
+  print('By comparison, DP-SGD analysis for training done with the same '
+        'parameters and random shuffling in each epoch guarantees '
+        '({:.2f}, {})-DP for all samples.'.format(eps_sgm, delta))
+
+
 def main(unused_argv):
   tf.logging.set_verbosity(tf.logging.INFO)
-  if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
-    raise ValueError('Number of microbatches should divide evenly batch_size')
   if FLAGS.data_l2_norm <= 0:
-    raise ValueError('FLAGS.data_l2_norm needs to be positive.')
-  if FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
-    raise ValueError('The amplification by iteration analysis requires'
+    raise ValueError('data_l2_norm must be positive.')
+  if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
+    raise ValueError('The amplification-by-iteration analysis requires '
                      'learning_rate <= 2 / beta, where beta is the smoothness '
                      'of the loss function and is upper bounded by ||x||^2 / 4 '
                      'with ||x|| being the largest L2 norm of the samples.')
@@ -178,15 +207,12 @@ def main(unused_argv):
   train_data, train_labels, test_data, test_labels = load_mnist(
       data_l2_norm=FLAGS.data_l2_norm)
 
-  # Instantiate the tf.Estimator.
+  # Instantiate tf.Estimator.
   # pylint: disable=g-long-lambda
-  model_fn = lambda features, labels, mode: lr_model_fn(features, labels, mode,
-                                                        nclasses=10,
-                                                        dim=train_data.shape[1:]
-                                                        )
+  model_fn = lambda features, labels, mode: lr_model_fn(
+      features, labels, mode, nclasses=10, dim=train_data.shape[1:])
   mnist_classifier = tf.estimator.Estimator(
-      model_fn=model_fn,
-      model_dir=FLAGS.model_dir)
+      model_fn=model_fn, model_dir=FLAGS.model_dir)
 
   # Create tf.Estimator input functions for the training and test data.
   # To analyze the per-user privacy loss, we keep the same orders of samples in
@@ -198,22 +224,27 @@
       num_epochs=FLAGS.epochs,
       shuffle=False)
   eval_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={'x': test_data},
-      y=test_labels,
-      num_epochs=1,
-      shuffle=False)
+      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)
 
-  # Train the model
-  steps_per_epoch = train_data.shape[0] // FLAGS.batch_size
-  mnist_classifier.train(input_fn=train_input_fn,
-                         steps=steps_per_epoch * FLAGS.epochs)
+  # Train the model.
+  num_samples = train_data.shape[0]
+  steps_per_epoch = num_samples // FLAGS.batch_size
 
-  # Evaluate the model and print results
+  mnist_classifier.train(
+      input_fn=train_input_fn, steps=steps_per_epoch * FLAGS.epochs)
+
+  # Evaluate the model and print results.
   eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-  test_accuracy = eval_results['accuracy']
-  print('Test accuracy after %d epochs is: %.3f' % (FLAGS.epochs,
-                                                    test_accuracy))
+  print('Test accuracy after {} epochs is: {:.2f}'.format(
+      FLAGS.epochs, eval_results['accuracy']))
 
+  if FLAGS.dpsgd:
+    print_privacy_guarantees(
+        epochs=FLAGS.epochs,
+        batch_size=FLAGS.batch_size,
+        samples=num_samples,
+        noise_multiplier=FLAGS.noise_multiplier,
+    )
 
 if __name__ == '__main__':
   app.run(main)
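
Editor's note (illustrative, not part of the patch): the per-sample numbers printed by print_privacy_guarantees come from the amplification-by-iteration bound (Theorem 34 of https://arxiv.org/abs/1808.06651). Because training visits the samples in a fixed order with no clipping needed after data normalization, the Renyi divergence at order alpha is bounded by alpha * coef, with coef exactly as computed in the function above, and the RDP accountant converts that linear curve into an (eps, delta) pair at delta = 1e-5. Below is a minimal standalone sketch of the same accounting. It assumes the privacy.analysis.rdp_accountant module imported by the tutorial is installed under that path, uses the tutorial's default flag values (noise_multiplier=0.05, batch_size=5, epochs=5), and takes 60000 as the size of the MNIST training set.

    # Standalone sketch; mirrors the per-position accounting in
    # print_privacy_guarantees() from the tutorial above.
    import math

    import numpy as np

    from privacy.analysis.rdp_accountant import get_privacy_spent

    noise_multiplier = 0.05  # tutorial default --noise_multiplier
    batch_size = 5           # tutorial default --batch_size
    epochs = 5               # tutorial default --epochs
    samples = 60000          # MNIST training set size
    delta = 1e-5

    steps_per_epoch = samples // batch_size
    orders = np.concatenate(
        [np.linspace(2, 20, num=181), np.linspace(20, 100, num=81)])

    for p in (.5, .9, .99):
      steps = math.ceil(steps_per_epoch * p)  # Steps in the last epoch.
      # RDP of order alpha is alpha * coef; the two terms account for the
      # all-but-last epochs and for the last epoch, respectively.
      coef = 2 * (noise_multiplier * batch_size)**-2 * (
          (epochs - 1) / steps_per_epoch +
          1 / (steps_per_epoch - steps + 1))
      rdp = [order * coef for order in orders]
      eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
      print('{:g}% of samples: at least ({:.2f}, {})-DP'.format(
          p * 100, eps, delta))

Running the sketch reproduces the three position-dependent guarantees the tutorial prints for its default flags; changing noise_multiplier or batch_size shows how the bound scales with (noise_multiplier * batch_size)**-2.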