Adding privacy analysis to the Logistic Regression for MNIST tutorial.

PiperOrigin-RevId: 254815428
This commit is contained in:
Ilya Mironov 2019-06-24 12:49:48 -07:00 committed by A. Unique TensorFlower
parent 2b97c7c735
commit 45bcb3a0e4
2 changed files with 92 additions and 57 deletions

View file

@ -20,6 +20,10 @@ Here is a list of all the tutorials included:
* `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on MNIST * `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on MNIST
with differential privacy using tf.Keras. with differential privacy using tf.Keras.
* `mnist_lr_tutorial.py`: learn a differentially private logistic regression
model on MNIST. The model illustrates application of the
"amplification-by-iteration" analysis (https://arxiv.org/abs/1808.06651).
The rest of this README describes the different parameters used to configure The rest of this README describes the different parameters used to configure
DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial. DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial.

View file

@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""DP Logistic Regression on MNIST. """DP Logistic Regression on MNIST.
DP Logistic Regression on MNIST with support for privacy-by-iteration analysis. DP Logistic Regression on MNIST with support for privacy-by-iteration analysis.
Feldman, Vitaly, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta. Vitaly Feldman, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
"Privacy amplification by iteration." "Privacy amplification by iteration."
In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS), In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS),
pp. 521-532. IEEE, 2018. pp. 521-532. IEEE, 2018.
@ -36,6 +35,8 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from privacy.optimizers import dp_optimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
@ -45,32 +46,30 @@ else:
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, ' flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
'train with vanilla SGD.') 'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training') flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 0.02, flags.DEFINE_float('noise_multiplier', 0.05,
'Ratio of the standard deviation to the clipping norm') 'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_integer('batch_size', 1, 'Batch size') flags.DEFINE_integer('batch_size', 5, 'Batch size')
flags.DEFINE_integer('epochs', 5, 'Number of epochs') flags.DEFINE_integer('epochs', 5, 'Number of epochs')
flags.DEFINE_integer('microbatches', 1, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient') flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient')
flags.DEFINE_string('model_dir', None, 'Model directory') flags.DEFINE_string('model_dir', None, 'Model directory')
flags.DEFINE_float('data_l2_norm', 8, flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
'Bound on the L2 norm of normalized data.')
def lr_model_fn(features, labels, mode, nclasses, dim): def lr_model_fn(features, labels, mode, nclasses, dim):
"""Model function for logistic regression.""" """Model function for logistic regression."""
input_layer = tf.reshape(features['x'], tuple([-1]) + dim) input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
logits = tf.layers.dense(inputs=input_layer, logits = tf.layers.dense(
inputs=input_layer,
units=nclasses, units=nclasses,
kernel_regularizer=tf.contrib.layers.l2_regularizer( kernel_regularizer=tf.contrib.layers.l2_regularizer(
scale=FLAGS.regularizer), scale=FLAGS.regularizer),
bias_regularizer=tf.contrib.layers.l2_regularizer( bias_regularizer=tf.contrib.layers.l2_regularizer(
scale=FLAGS.regularizer) scale=FLAGS.regularizer))
)
# Calculate loss as a vector (to support microbatches in DP-SGD). # Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@ -80,18 +79,15 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
# Configure the training op (for TRAIN mode). # Configure the training op (for TRAIN mode).
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd: if FLAGS.dpsgd:
# Use DP version of GradientDescentOptimizer. Other optimizers are
# available in dp_optimizer. Most optimizers inheriting from
# tf.train.Optimizer should be wrappable in differentially private
# counterparts by calling dp_optimizer.optimizer_from_args().
# The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where # The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where
# ||x|| is the norm of the data. # ||x|| is the norm of the data.
# We don't use microbatches (thus speeding up computation), since no
# clipping is necessary due to data normalization.
optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer( optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
l2_norm_clip=math.sqrt(2*(FLAGS.data_l2_norm**2 + 1)), l2_norm_clip=math.sqrt(2 * (FLAGS.data_l2_norm**2 + 1)),
noise_multiplier=FLAGS.noise_multiplier, noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches, num_microbatches=1,
learning_rate=FLAGS.learning_rate) learning_rate=FLAGS.learning_rate)
opt_loss = vector_loss opt_loss = vector_loss
else: else:
@ -103,21 +99,18 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
# the vector_loss because tf.estimator requires a scalar loss. This is only # the vector_loss because tf.estimator requires a scalar loss. This is only
# used for evaluation and debugging by tf.estimator. The actual loss being # used for evaluation and debugging by tf.estimator. The actual loss being
# minimized is opt_loss defined above and passed to optimizer.minimize(). # minimized is opt_loss defined above and passed to optimizer.minimize().
return tf.estimator.EstimatorSpec(mode=mode, return tf.estimator.EstimatorSpec(
loss=scalar_loss, mode=mode, loss=scalar_loss, train_op=train_op)
train_op=train_op)
# Add evaluation metrics (for EVAL mode). # Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL: elif mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = { eval_metric_ops = {
'accuracy': 'accuracy':
tf.metrics.accuracy( tf.metrics.accuracy(
labels=labels, labels=labels, predictions=tf.argmax(input=logits, axis=1))
predictions=tf.argmax(input=logits, axis=1))
} }
return tf.estimator.EstimatorSpec(mode=mode, return tf.estimator.EstimatorSpec(
loss=scalar_loss, mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
eval_metric_ops=eval_metric_ops)
def normalize_data(data, data_l2_norm): def normalize_data(data, data_l2_norm):
@ -159,14 +152,50 @@ def load_mnist(data_l2_norm=float('inf')):
return train_data, train_labels, test_data, test_labels return train_data, train_labels, test_data, test_labels
def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
"""Tabulating position-dependent privacy guarantees."""
if noise_multiplier == 0:
print('No differential privacy (additive noise is 0).')
return
print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) '
'the training procedure results in the following privacy guarantees.')
print('Out of the total of {} samples:'.format(samples))
steps_per_epoch = samples // batch_size
orders = np.concatenate(
[np.linspace(2, 20, num=181),
np.linspace(20, 100, num=81)])
delta = 1e-5
for p in (.5, .9, .99):
steps = math.ceil(steps_per_epoch * p) # Steps in the last epoch.
coef = 2 * (noise_multiplier * batch_size)**-2 * (
# Accounting for privacy loss
(epochs - 1) / steps_per_epoch + # ... from all-but-last epochs
1 / (steps_per_epoch - steps + 1)) # ... due to the last epoch
# Using RDP accountant to compute eps. Doing computation analytically is
# an option.
rdp = [order * coef for order in orders]
eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
p * 100, eps, delta))
# Compute privacy guarantees for the Sampled Gaussian Mechanism.
rdp_sgm = compute_rdp(batch_size / samples, noise_multiplier,
epochs * steps_per_epoch, orders)
eps_sgm, _, _ = get_privacy_spent(orders, rdp_sgm, target_delta=delta)
print('By comparison, DP-SGD analysis for training done with the same '
'parameters and random shuffling in each epoch guarantees '
'({:.2f}, {})-DP for all samples.'.format(eps_sgm, delta))
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size')
if FLAGS.data_l2_norm <= 0: if FLAGS.data_l2_norm <= 0:
raise ValueError('FLAGS.data_l2_norm needs to be positive.') raise ValueError('data_l2_norm must be positive.')
if FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2: if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
raise ValueError('The amplification by iteration analysis requires' raise ValueError('The amplification-by-iteration analysis requires'
'learning_rate <= 2 / beta, where beta is the smoothness' 'learning_rate <= 2 / beta, where beta is the smoothness'
'of the loss function and is upper bounded by ||x||^2 / 4' 'of the loss function and is upper bounded by ||x||^2 / 4'
'with ||x|| being the largest L2 norm of the samples.') 'with ||x|| being the largest L2 norm of the samples.')
@ -178,15 +207,12 @@ def main(unused_argv):
train_data, train_labels, test_data, test_labels = load_mnist( train_data, train_labels, test_data, test_labels = load_mnist(
data_l2_norm=FLAGS.data_l2_norm) data_l2_norm=FLAGS.data_l2_norm)
# Instantiate the tf.Estimator. # Instantiate tf.Estimator.
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
model_fn = lambda features, labels, mode: lr_model_fn(features, labels, mode, model_fn = lambda features, labels, mode: lr_model_fn(
nclasses=10, features, labels, mode, nclasses=10, dim=train_data.shape[1:])
dim=train_data.shape[1:]
)
mnist_classifier = tf.estimator.Estimator( mnist_classifier = tf.estimator.Estimator(
model_fn=model_fn, model_fn=model_fn, model_dir=FLAGS.model_dir)
model_dir=FLAGS.model_dir)
# Create tf.Estimator input functions for the training and test data. # Create tf.Estimator input functions for the training and test data.
# To analyze the per-user privacy loss, we keep the same orders of samples in # To analyze the per-user privacy loss, we keep the same orders of samples in
@ -198,22 +224,27 @@ def main(unused_argv):
num_epochs=FLAGS.epochs, num_epochs=FLAGS.epochs,
shuffle=False) shuffle=False)
eval_input_fn = tf.estimator.inputs.numpy_input_fn( eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': test_data}, x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)
y=test_labels,
num_epochs=1,
shuffle=False)
# Train the model # Train the model.
steps_per_epoch = train_data.shape[0] // FLAGS.batch_size num_samples = train_data.shape[0]
mnist_classifier.train(input_fn=train_input_fn, steps_per_epoch = num_samples // FLAGS.batch_size
steps=steps_per_epoch * FLAGS.epochs)
# Evaluate the model and print results mnist_classifier.train(
input_fn=train_input_fn, steps=steps_per_epoch * FLAGS.epochs)
# Evaluate the model and print results.
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
test_accuracy = eval_results['accuracy'] print('Test accuracy after {} epochs is: {:.2f}'.format(
print('Test accuracy after %d epochs is: %.3f' % (FLAGS.epochs, FLAGS.epochs, eval_results['accuracy']))
test_accuracy))
if FLAGS.dpsgd:
print_privacy_guarantees(
epochs=FLAGS.epochs,
batch_size=FLAGS.batch_size,
samples=num_samples,
noise_multiplier=FLAGS.noise_multiplier,
)
if __name__ == '__main__': if __name__ == '__main__':
app.run(main) app.run(main)