From 10335f61775faabe5e931f30b0ecf91f0719b727 Mon Sep 17 00:00:00 2001
From: Steve Chien
Date: Thu, 7 May 2020 12:05:31 -0700
Subject: [PATCH] Refactor MNIST tutorials and create new TPU tutorial:

1. Move common code to new file mnist_dpsgd_tutorial_common.py.
2. Move epsilon computation function out of binary into its own library.
3. Create new TPU tutorial.

PiperOrigin-RevId: 310409308
---
 .../analysis/compute_dp_sgd_privacy.py      |  42 +----
 .../analysis/compute_dp_sgd_privacy_lib.py  |  66 +++++++
 .../analysis/compute_dp_sgd_privacy_test.py |   4 +-
 tutorials/mnist_dpsgd_tutorial.py           | 153 ++++------
 tutorials/mnist_dpsgd_tutorial_common.py    |  75 ++++++++
 tutorials/mnist_dpsgd_tutorial_tpu.py       | 167 ++++++++++++++++++
 6 files changed, 351 insertions(+), 156 deletions(-)
 create mode 100644 tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py
 create mode 100644 tutorials/mnist_dpsgd_tutorial_common.py
 create mode 100644 tutorials/mnist_dpsgd_tutorial_tpu.py

diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py
index c1bf270..e4a0efb 100644
--- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py
+++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py
@@ -32,18 +32,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
 import sys
 
 from absl import app
 from absl import flags
 
+from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
+
 # Opting out of loading all sibling packages and their dependencies.
 sys.skip_tf_privacy_import = True
 
-from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp  # pylint: disable=g-import-not-at-top
-from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
-
 FLAGS = flags.FLAGS
 
 flags.DEFINE_integer('N', None, 'Total number of examples')
@@ -53,42 +51,6 @@ flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
 flags.DEFINE_float('delta', 1e-6, 'Target delta')
 
 
-def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
-  """Compute and print results of DP-SGD analysis."""
-
-  # compute_rdp requires that sigma be the ratio of the standard deviation of
-  # the Gaussian noise to the l2-sensitivity of the function to which it is
-  # added. Hence, sigma here corresponds to the `noise_multiplier` parameter
-  # in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
-  rdp = compute_rdp(q, sigma, steps, orders)
-
-  eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
-
-  print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
-        ' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ')
-  print('differential privacy with eps = {:.3g} and delta = {}.'.format(
-      eps, delta))
-  print('The optimal RDP order is {}.'.format(opt_order))
-
-  if opt_order == max(orders) or opt_order == min(orders):
-    print('The privacy estimate is likely to be improved by expanding '
-          'the set of orders.')
-
-  return eps, opt_order
-
-
-def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
-  """Compute epsilon based on the given hyperparameters."""
-  q = batch_size / n  # q - the sampling ratio.
-  if q > 1:
-    raise app.UsageError('n must be larger than the batch size.')
-  orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
-            list(range(5, 64)) + [128, 256, 512])
-  steps = int(math.ceil(epochs * n / batch_size))
-
-  return apply_dp_sgd_analysis(q, noise_multiplier, steps, orders, delta)
-
-
 def main(argv):
   del argv  # argv is not used.
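With the epsilon computation now living in compute_dp_sgd_privacy_lib, the analysis can be driven directly from Python as well as through the flag-based script above. A minimal sketch using MNIST-scale hyperparameters (the values are illustrative, not outputs of this patch):

```python
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy

# 60000 training examples, batch size 256, noise multiplier 1.1, 30 epochs,
# delta = 1e-5 (recommended to be o(1/dataset_size), hence 1e-5 for MNIST).
eps, opt_order = compute_dp_sgd_privacy(
    n=60000, batch_size=256, noise_multiplier=1.1, epochs=30, delta=1e-5)
print('eps = %.3g at RDP order %s' % (eps, opt_order))
```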
diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py
new file mode 100644
index 0000000..c8c9821
--- /dev/null
+++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py
@@ -0,0 +1,66 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for computing privacy values for DP-SGD."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import sys
+
+from absl import app
+
+# Opting out of loading all sibling packages and their dependencies.
+sys.skip_tf_privacy_import = True
+
+from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp  # pylint: disable=g-import-not-at-top
+from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
+
+
+def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
+  """Compute and print results of DP-SGD analysis."""
+
+  # compute_rdp requires that sigma be the ratio of the standard deviation of
+  # the Gaussian noise to the l2-sensitivity of the function to which it is
+  # added. Hence, sigma here corresponds to the `noise_multiplier` parameter
+  # in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
+  rdp = compute_rdp(q, sigma, steps, orders)
+
+  eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
+
+  print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
+        ' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ')
+  print('differential privacy with eps = {:.3g} and delta = {}.'.format(
+      eps, delta))
+  print('The optimal RDP order is {}.'.format(opt_order))
+
+  if opt_order == max(orders) or opt_order == min(orders):
+    print('The privacy estimate is likely to be improved by expanding '
+          'the set of orders.')
+
+  return eps, opt_order
+
+
+def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
+  """Compute epsilon based on the given hyperparameters."""
+  q = batch_size / n  # q - the sampling ratio.
+  if q > 1:
+    raise app.UsageError('n must be larger than the batch size.')
+  orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
+            list(range(5, 64)) + [128, 256, 512])
+  steps = int(math.ceil(epochs * n / batch_size))
+
+  return apply_dp_sgd_analysis(q, noise_multiplier, steps, orders, delta)
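The library is a thin wrapper over the RDP accountant: it derives the sampling ratio q = batch_size / n and the step count steps = ceil(epochs * n / batch_size), then delegates to apply_dp_sgd_analysis. With the MNIST defaults (n = 60000, batch_size = 256, 30 epochs) that works out to q = 256 / 60000 ≈ 0.00427 and steps = ceil(7031.25) = 7032. The equivalent direct use of the accountant, as a sketch:

```python
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
          list(range(5, 64)) + [128, 256, 512])
# One Renyi-DP value per order, composed over all steps of the subsampled
# Gaussian mechanism.
rdp = compute_rdp(q=256 / 60000, noise_multiplier=1.1, steps=7032, orders=orders)
# Convert to (eps, delta)-DP at the target delta; the optimal order is returned.
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=1e-5)
```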
diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py
index 96bfb00..7f4b66c 100644
--- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py
+++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 from absl.testing import absltest
 from absl.testing import parameterized
 
-from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
+from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
 
 
 class ComputeDpSgdPrivacyTest(parameterized.TestCase):
@@ -32,7 +32,7 @@ class ComputeDpSgdPrivacyTest(parameterized.TestCase):
   )
   def test_compute_dp_sgd_privacy(self, n, batch_size, noise_multiplier,
                                   epochs, delta, expected_eps, expected_order):
-    eps, order = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
+    eps, order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
         n, batch_size, noise_multiplier, epochs, delta)
     self.assertAlmostEqual(eps, expected_eps)
     self.assertAlmostEqual(order, expected_order)
diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py
index 0e884a5..e9cf3f8 100644
--- a/tutorials/mnist_dpsgd_tutorial.py
+++ b/tutorials/mnist_dpsgd_tutorial.py
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2020, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,26 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Training a CNN on MNIST with differentially private SGD optimizer."""
+"""Train a CNN on MNIST with differentially private SGD optimizer."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import time
+
 from absl import app
 from absl import flags
+from absl import logging
 
-import numpy as np
 import tensorflow.compat.v1 as tf
 
-from tensorflow_privacy.privacy.analysis import privacy_ledger
-from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger
-from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
+from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
 from tensorflow_privacy.privacy.optimizers import dp_optimizer
-
-GradientDescentOptimizer = tf.train.GradientDescentOptimizer
-
-FLAGS = flags.FLAGS
+from tensorflow_privacy.tutorials import mnist_dpsgd_tutorial_common as common
 
 flags.DEFINE_boolean(
     'dpsgd', True, 'If True, train with DP-SGD. If False, '
@@ -41,62 +38,20 @@ flags.DEFINE_float('noise_multiplier', 1.1,
                    'Ratio of the standard deviation to the clipping norm')
 flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
 flags.DEFINE_integer('batch_size', 256, 'Batch size')
-flags.DEFINE_integer('epochs', 60, 'Number of epochs')
+flags.DEFINE_integer('epochs', 30, 'Number of epochs')
 flags.DEFINE_integer(
     'microbatches', 256, 'Number of microbatches '
     '(must evenly divide batch_size)')
 flags.DEFINE_string('model_dir', None, 'Model directory')
 
-
-class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
-  """Training hook to print current value of epsilon after an epoch."""
-
-  def __init__(self, ledger):
-    """Initalizes the EpsilonPrintingTrainingHook.
-
-    Args:
-      ledger: The privacy ledger.
-    """
-    self._samples, self._queries = ledger.get_unformatted_ledger()
-
-  def end(self, session):
-
-    # Any RDP order (for order > 1) corresponds to one epsilon value. We
-    # enumerate through a few orders and pick the one that gives lowest epsilon.
-    # The variable orders may be extended for different use cases. Usually, the
-    # search is set to be finer-grained for small orders and coarser-grained for
-    # larger orders.
-    orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
-    samples = session.run(self._samples)
-    queries = session.run(self._queries)
-    formatted_ledger = privacy_ledger.format_ledger(samples, queries)
-    rdp = compute_rdp_from_ledger(formatted_ledger, orders)
-
-    # It is recommended that delta is o(1/dataset_size). In the case of MNIST,
-    # dataset_size is 60000, so we set delta to be 1e-5. For larger datasets,
-    # delta should be set smaller.
-    eps = get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
-    print('For delta=1e-5, the current epsilon is: %.2f' % eps)
+FLAGS = flags.FLAGS
 
 
-def cnn_model_fn(features, labels, mode):
+def cnn_model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
   """Model function for a CNN."""
 
-  # Define CNN architecture using tf.keras.layers.
-  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
-  y = tf.keras.layers.Conv2D(16, 8,
-                             strides=2,
-                             padding='same',
-                             activation='relu').apply(input_layer)
-  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
-  y = tf.keras.layers.Conv2D(32, 4,
-                             strides=2,
-                             padding='valid',
-                             activation='relu').apply(y)
-  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
-  y = tf.keras.layers.Flatten().apply(y)
-  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
-  logits = tf.keras.layers.Dense(10).apply(y)
+  # Define CNN architecture.
+  logits = common.get_cnn_model(features)
 
   # Calculate loss as a vector (to support microbatches in DP-SGD).
   vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -106,12 +61,7 @@ def cnn_model_fn(features, labels, mode):
 
   # Configure the training op (for TRAIN mode).
   if mode == tf.estimator.ModeKeys.TRAIN:
     if FLAGS.dpsgd:
-      ledger = privacy_ledger.PrivacyLedger(
-          population_size=60000,
-          selection_probability=(FLAGS.batch_size / 60000))
-
       # Use DP version of GradientDescentOptimizer. Other optimizers are
       # available in dp_optimizer. Most optimizers inheriting from
       # tf.train.Optimizer should be wrappable in differentially private
@@ -120,26 +70,22 @@
           l2_norm_clip=FLAGS.l2_norm_clip,
           noise_multiplier=FLAGS.noise_multiplier,
           num_microbatches=FLAGS.microbatches,
-          ledger=ledger,
           learning_rate=FLAGS.learning_rate)
-      training_hooks = [
-          EpsilonPrintingTrainingHook(ledger)
-      ]
       opt_loss = vector_loss
     else:
-      optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
-      training_hooks = []
+      optimizer = tf.train.GradientDescentOptimizer(
+          learning_rate=FLAGS.learning_rate)
       opt_loss = scalar_loss
+
     global_step = tf.train.get_global_step()
     train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
+
     # In the following, we pass the mean of the loss (scalar_loss) rather than
     # the vector_loss because tf.estimator requires a scalar loss. This is only
    # used for evaluation and debugging by tf.estimator. The actual loss being
     # minimized is opt_loss defined above and passed to optimizer.minimize().
-    return tf.estimator.EstimatorSpec(mode=mode,
-                                      loss=scalar_loss,
-                                      train_op=train_op,
-                                      training_hooks=training_hooks)
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=scalar_loss, train_op=train_op)
 
   # Add evaluation metrics (for EVAL mode).
   elif mode == tf.estimator.ModeKeys.EVAL:
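The loss is kept per example (vector_loss) because the DP optimizer needs one gradient per microbatch: each step it clips every microbatch gradient to l2_norm_clip, sums them, adds Gaussian noise with standard deviation l2_norm_clip * noise_multiplier, and averages over the microbatches. A conceptual numpy sketch of that aggregation (an illustration of the mechanism, not dp_optimizer's actual implementation):

```python
import numpy as np

def dp_sgd_aggregate(microbatch_grads, l2_norm_clip, noise_multiplier):
  """Clip / sum / noise / average over per-microbatch gradients."""
  clipped = []
  for g in microbatch_grads:
    norm = np.linalg.norm(g)
    clipped.append(g / max(1.0, norm / l2_norm_clip))  # rescale onto the clip ball
  noisy_sum = np.sum(clipped, axis=0) + np.random.normal(
      0.0, l2_norm_clip * noise_multiplier, size=microbatch_grads[0].shape)
  return noisy_sum / len(microbatch_grads)
```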
@@ -149,69 +95,48 @@ def cnn_model_fn(features, labels, mode):
                 labels=labels, predictions=tf.argmax(input=logits, axis=1))
     }
-
     return tf.estimator.EstimatorSpec(mode=mode,
                                       loss=scalar_loss,
                                       eval_metric_ops=eval_metric_ops)
 
 
-def load_mnist():
-  """Loads MNIST and preprocesses to combine training and validation data."""
-  train, test = tf.keras.datasets.mnist.load_data()
-  train_data, train_labels = train
-  test_data, test_labels = test
-
-  train_data = np.array(train_data, dtype=np.float32) / 255
-  test_data = np.array(test_data, dtype=np.float32) / 255
-
-  train_labels = np.array(train_labels, dtype=np.int32)
-  test_labels = np.array(test_labels, dtype=np.int32)
-
-  assert train_data.min() == 0.
-  assert train_data.max() == 1.
-  assert test_data.min() == 0.
-  assert test_data.max() == 1.
-  assert train_labels.ndim == 1
-  assert test_labels.ndim == 1
-
-  return train_data, train_labels, test_data, test_labels
-
-
 def main(unused_argv):
-  tf.logging.set_verbosity(tf.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
     raise ValueError('Number of microbatches should evenly divide batch_size')
 
-  # Load training and test data.
-  train_data, train_labels, test_data, test_labels = load_mnist()
-
   # Instantiate the tf.Estimator.
   mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                             model_dir=FLAGS.model_dir)
 
-  # Create tf.Estimator input functions for the training and test data.
-  train_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={'x': train_data},
-      y=train_labels,
-      batch_size=FLAGS.batch_size,
-      num_epochs=FLAGS.epochs,
-      shuffle=True)
-  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={'x': test_data},
-      y=test_labels,
-      num_epochs=1,
-      shuffle=False)
-
   # Training loop.
   steps_per_epoch = 60000 // FLAGS.batch_size
   for epoch in range(1, FLAGS.epochs + 1):
+    start_time = time.time()
     # Train the model for one epoch.
-    mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)
+    mnist_classifier.train(
+        input_fn=common.make_input_fn('train', FLAGS.batch_size),
+        steps=steps_per_epoch)
+    end_time = time.time()
+    logging.info('Epoch %d time in seconds: %.2f', epoch, end_time - start_time)
 
     # Evaluate the model and print results
-    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+    eval_results = mnist_classifier.evaluate(
+        input_fn=common.make_input_fn('test', FLAGS.batch_size, 1))
     test_accuracy = eval_results['accuracy']
     print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
 
+    # Compute the privacy budget expended.
+    if FLAGS.dpsgd:
+      if FLAGS.noise_multiplier > 0.0:
+        eps, _ = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
+            60000, FLAGS.batch_size, FLAGS.noise_multiplier, epoch, 1e-5)
+        print('For delta=1e-5, the current epsilon is: %.2f' % eps)
+      else:
+        print('Trained with DP-SGD but with zero noise.')
+    else:
+      print('Trained with vanilla non-private SGD optimizer')
+
 
 if __name__ == '__main__':
   app.run(main)
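In place of the removed EpsilonPrintingTrainingHook and privacy ledger, the loop above recomputes the cumulative (epsilon, delta) guarantee analytically after every epoch by passing `epoch` as the epochs argument. A sketch of collecting the budget curve with the tutorial's defaults (epsilon grows sublinearly in the number of epochs):

```python
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib

# Cumulative epsilon after each of 30 epochs for n=60000, batch_size=256,
# noise_multiplier=1.1, delta=1e-5.
budget = [
    compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
        60000, 256, 1.1, epoch, 1e-5)[0]
    for epoch in range(1, 31)
]
```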
diff --git a/tutorials/mnist_dpsgd_tutorial_common.py b/tutorials/mnist_dpsgd_tutorial_common.py
new file mode 100644
index 0000000..656ad54
--- /dev/null
+++ b/tutorials/mnist_dpsgd_tutorial_common.py
@@ -0,0 +1,75 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common tools for DP-SGD MNIST tutorials."""
+
+# These are not necessary in a Python 3-only module.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import google_type_annotations
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+import tensorflow_datasets as tfds
+
+
+def get_cnn_model(features):
+  """Given input features, returns the logits from a simple CNN model."""
+  input_layer = tf.reshape(features, [-1, 28, 28, 1])
+  y = tf.keras.layers.Conv2D(
+      16, 8, strides=2, padding='same', activation='relu').apply(input_layer)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Conv2D(
+      32, 4, strides=2, padding='valid', activation='relu').apply(y)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Flatten().apply(y)
+  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
+  logits = tf.keras.layers.Dense(10).apply(y)
+
+  return logits
+
+
+def make_input_fn(split, input_batch_size=256, repetitions=-1, tpu=False):
+  """Make input function on given MNIST split."""
+
+  def input_fn(params=None):
+    """A simple input function."""
+    batch_size = params.get('batch_size', input_batch_size)
+
+    def parser(example):
+      image, label = example['image'], example['label']
+      image = tf.cast(image, tf.float32)
+      image /= 255.0
+      label = tf.cast(label, tf.int32)
+      return image, label
+
+    dataset = tfds.load(name='mnist', split=split)
+    dataset = dataset.map(parser).shuffle(60000).repeat(repetitions).batch(
+        batch_size)
+
+    # If this input function is not meant for TPUs, we can stop here.
+    # Otherwise, we need to explicitly set the tensors' shapes. Note that, for
+    # unknown reasons, returning the tensors (rather than the dataset) causes
+    # a performance regression on non-TPU devices.
+    if not tpu:
+      return dataset
+
+    # Give inputs statically known shapes; needed for TPUs.
+    images, labels = tf.data.make_one_shot_iterator(dataset).get_next()
+    # The leading batch dimension is None until set_shape pins it below.
+    images.set_shape([batch_size, 28, 28, 1])
+    labels.set_shape([
+        batch_size,
+    ])
+    return images, labels
+
+  return input_fn
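make_input_fn returns a callable that both tf.Estimator and TPUEstimator can consume (TPUEstimator overrides the batch size through params); repetitions=-1 repeats the split indefinitely for training, while evaluation passes 1. Assuming the module is imported as `common`, as in the tutorials:

```python
train_input_fn = common.make_input_fn('train', input_batch_size=256)  # repeats forever
eval_input_fn = common.make_input_fn('test', input_batch_size=256, repetitions=1)
```

The set_shape calls exist because XLA needs a statically known batch dimension; batching with drop_remainder=True would be another way to obtain one (at the cost of discarding a final partial batch), but the explicit set_shape leaves the non-TPU path of the pipeline untouched.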
diff --git a/tutorials/mnist_dpsgd_tutorial_tpu.py b/tutorials/mnist_dpsgd_tutorial_tpu.py
new file mode 100644
index 0000000..e9def32
--- /dev/null
+++ b/tutorials/mnist_dpsgd_tutorial_tpu.py
@@ -0,0 +1,167 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Train a CNN on MNIST with DP-SGD optimizer on TPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import time
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import tensorflow.compat.v1 as tf
+
+from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
+from tensorflow_privacy.privacy.optimizers import dp_optimizer
+from tensorflow_privacy.tutorials import mnist_dpsgd_tutorial_common as common
+
+flags.DEFINE_boolean(
+    'dpsgd', True, 'If True, train with DP-SGD. If False, '
+    'train with vanilla SGD.')
+flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
+flags.DEFINE_float('noise_multiplier', 0.77,
+                   'Ratio of the standard deviation to the clipping norm')
+flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
+flags.DEFINE_integer('batch_size', 200, 'Batch size')
+flags.DEFINE_integer('cores', 2, 'Number of TPU cores')
+flags.DEFINE_integer('epochs', 60, 'Number of epochs')
+flags.DEFINE_integer(
+    'microbatches', 100, 'Number of microbatches '
+    '(must evenly divide batch_size / cores)')
+flags.DEFINE_string('model_dir', None, 'Model directory')
+flags.DEFINE_string('master', None, 'Master')
+
+FLAGS = flags.FLAGS
+
+
+def cnn_model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
+  """Model function for a CNN."""
+
+  # Define CNN architecture using tf.keras.layers.
+  logits = common.get_cnn_model(features)
+
+  # Calculate loss as a vector (to support microbatches in DP-SGD).
+  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  # Define mean of loss across minibatch (for reporting through tf.Estimator).
+  scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
+
+  # Configure the training op (for TRAIN mode).
+  if mode == tf.estimator.ModeKeys.TRAIN:
+    if FLAGS.dpsgd:
+      # Use DP version of GradientDescentOptimizer. Other optimizers are
+      # available in dp_optimizer. Most optimizers inheriting from
+      # tf.train.Optimizer should be wrappable in differentially private
+      # counterparts by calling dp_optimizer.optimizer_from_args().
+      optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
+          l2_norm_clip=FLAGS.l2_norm_clip,
+          noise_multiplier=FLAGS.noise_multiplier,
+          num_microbatches=FLAGS.microbatches,
+          learning_rate=FLAGS.learning_rate)
+      opt_loss = vector_loss
+    else:
+      optimizer = tf.train.GradientDescentOptimizer(
+          learning_rate=FLAGS.learning_rate)
+      opt_loss = scalar_loss
+
+    # Training with TPUs requires wrapping the optimizer in a
+    # CrossShardOptimizer.
+    optimizer = tf.tpu.CrossShardOptimizer(optimizer)
+
+    global_step = tf.train.get_global_step()
+    train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
+
+    # In the following, we pass the mean of the loss (scalar_loss) rather than
+    # the vector_loss because tf.estimator requires a scalar loss. This is only
+    # used for evaluation and debugging by tf.estimator. The actual loss being
+    # minimized is opt_loss defined above and passed to optimizer.minimize().
+    return tf.estimator.tpu.TPUEstimatorSpec(
+        mode=mode, loss=scalar_loss, train_op=train_op)
+
+  # Add evaluation metrics (for EVAL mode).
+  elif mode == tf.estimator.ModeKeys.EVAL:
+
+    def metric_fn(labels, logits):
+      predictions = tf.argmax(logits, 1)
+      return {
+          'accuracy':
+              tf.metrics.accuracy(labels=labels, predictions=predictions),
+      }
+
+    return tf.estimator.tpu.TPUEstimatorSpec(
+        mode=mode,
+        loss=scalar_loss,
+        eval_metrics=(metric_fn, {
+            'labels': labels,
+            'logits': logits,
+        }))
+
+
+def main(unused_argv):
+  logging.set_verbosity(logging.INFO)
+  if FLAGS.dpsgd and (FLAGS.batch_size // FLAGS.cores) % FLAGS.microbatches != 0:
+    raise ValueError('Number of microbatches should evenly divide batch_size / cores')
+
+  # Instantiate the tf.Estimator.
+  run_config = tf.estimator.tpu.RunConfig(master=FLAGS.master)
+  mnist_classifier = tf.estimator.tpu.TPUEstimator(
+      train_batch_size=FLAGS.batch_size,
+      eval_batch_size=FLAGS.batch_size,
+      model_fn=cnn_model_fn,
+      model_dir=FLAGS.model_dir,
+      config=run_config)
+
+  # Training loop.
+  steps_per_epoch = 60000 // FLAGS.batch_size
+  eval_steps_per_epoch = 10000 // FLAGS.batch_size
+  for epoch in range(1, FLAGS.epochs + 1):
+    start_time = time.time()
+    # Train the model for one epoch.
+    mnist_classifier.train(
+        input_fn=common.make_input_fn(
+            'train', FLAGS.batch_size // FLAGS.cores, tpu=True),
+        steps=steps_per_epoch)
+    end_time = time.time()
+    logging.info('Epoch %d time in seconds: %.2f', epoch, end_time - start_time)
+
+    # Evaluate the model and print results
+    eval_results = mnist_classifier.evaluate(
+        input_fn=common.make_input_fn(
+            'test', FLAGS.batch_size // FLAGS.cores, 1, tpu=True),
+        steps=eval_steps_per_epoch)
+    test_accuracy = eval_results['accuracy']
+    print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
+
+    # Compute the privacy budget expended.
+    if FLAGS.dpsgd:
+      if FLAGS.noise_multiplier > 0.0:
+        # Each core adds independent Gaussian noise, so the summed noise has
+        # standard deviation FLAGS.noise_multiplier * sqrt(FLAGS.cores).
+        eps, _ = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
+            60000, FLAGS.batch_size,
+            FLAGS.noise_multiplier * math.sqrt(FLAGS.cores), epoch, 1e-5)
+        print('For delta=1e-5, the current epsilon is: %.2f' % eps)
+      else:
+        print('Trained with DP-SGD but with zero noise.')
+    else:
+      print('Trained with vanilla non-private SGD optimizer')
+
+
+if __name__ == '__main__':
+  app.run(main)
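On the accounting side, the analysis above scales the noise multiplier by sqrt(cores): per the comment in the privacy computation, each core adds independent Gaussian noise, and CrossShardOptimizer's cross-replica sum of those shards yields noise whose standard deviation is sqrt(cores) times larger. A quick check of the defaults (an observation, not something stated in the patch):

```python
import math

# Effective noise multiplier the TPU run is credited with (flags above):
effective = 0.77 * math.sqrt(2)  # ~1.089, close to the 1.1 default used in
                                 # the single-host mnist_dpsgd_tutorial.py
```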