forked from 626_privacy/tensorflow_privacy
Refactor MNIST tutorials and create new TPU tutorial:
1. Move common code to new file mnist_dpsgd_tutorial_common.py. 2. Move epsilon computation function out of binary into its own library. 3. Create new TPU tutorial. PiperOrigin-RevId: 310409308
This commit is contained in:
parent
164a57546a
commit
10335f6177
6 changed files with 351 additions and 156 deletions
|
@ -32,18 +32,16 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import math
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
|
||||||
|
|
||||||
# Opting out of loading all sibling packages and their dependencies.
|
# Opting out of loading all sibling packages and their dependencies.
|
||||||
sys.skip_tf_privacy_import = True
|
sys.skip_tf_privacy_import = True
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
flags.DEFINE_integer('N', None, 'Total number of examples')
|
flags.DEFINE_integer('N', None, 'Total number of examples')
|
||||||
|
@ -53,42 +51,6 @@ flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
|
||||||
flags.DEFINE_float('delta', 1e-6, 'Target delta')
|
flags.DEFINE_float('delta', 1e-6, 'Target delta')
|
||||||
|
|
||||||
|
|
||||||
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
|
|
||||||
"""Compute and print results of DP-SGD analysis."""
|
|
||||||
|
|
||||||
# compute_rdp requires that sigma be the ratio of the standard deviation of
|
|
||||||
# the Gaussian noise to the l2-sensitivity of the function to which it is
|
|
||||||
# added. Hence, sigma here corresponds to the `noise_multiplier` parameter
|
|
||||||
# in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
|
|
||||||
rdp = compute_rdp(q, sigma, steps, orders)
|
|
||||||
|
|
||||||
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
|
|
||||||
|
|
||||||
print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
|
|
||||||
' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ')
|
|
||||||
print('differential privacy with eps = {:.3g} and delta = {}.'.format(
|
|
||||||
eps, delta))
|
|
||||||
print('The optimal RDP order is {}.'.format(opt_order))
|
|
||||||
|
|
||||||
if opt_order == max(orders) or opt_order == min(orders):
|
|
||||||
print('The privacy estimate is likely to be improved by expanding '
|
|
||||||
'the set of orders.')
|
|
||||||
|
|
||||||
return eps, opt_order
|
|
||||||
|
|
||||||
|
|
||||||
def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
|
|
||||||
"""Compute epsilon based on the given hyperparameters."""
|
|
||||||
q = batch_size / n # q - the sampling ratio.
|
|
||||||
if q > 1:
|
|
||||||
raise app.UsageError('n must be larger than the batch size.')
|
|
||||||
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
|
||||||
list(range(5, 64)) + [128, 256, 512])
|
|
||||||
steps = int(math.ceil(epochs * n / batch_size))
|
|
||||||
|
|
||||||
return apply_dp_sgd_analysis(q, noise_multiplier, steps, orders, delta)
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
del argv # argv is not used.
|
del argv # argv is not used.
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Library for computing privacy values for DP-SGD."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
|
||||||
|
# Opting out of loading all sibling packages and their dependencies.
|
||||||
|
sys.skip_tf_privacy_import = True
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||||
|
|
||||||
|
|
||||||
|
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
|
||||||
|
"""Compute and print results of DP-SGD analysis."""
|
||||||
|
|
||||||
|
# compute_rdp requires that sigma be the ratio of the standard deviation of
|
||||||
|
# the Gaussian noise to the l2-sensitivity of the function to which it is
|
||||||
|
# added. Hence, sigma here corresponds to the `noise_multiplier` parameter
|
||||||
|
# in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
|
||||||
|
rdp = compute_rdp(q, sigma, steps, orders)
|
||||||
|
|
||||||
|
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
|
||||||
|
|
||||||
|
print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
|
||||||
|
' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ')
|
||||||
|
print('differential privacy with eps = {:.3g} and delta = {}.'.format(
|
||||||
|
eps, delta))
|
||||||
|
print('The optimal RDP order is {}.'.format(opt_order))
|
||||||
|
|
||||||
|
if opt_order == max(orders) or opt_order == min(orders):
|
||||||
|
print('The privacy estimate is likely to be improved by expanding '
|
||||||
|
'the set of orders.')
|
||||||
|
|
||||||
|
return eps, opt_order
|
||||||
|
|
||||||
|
|
||||||
|
def compute_dp_sgd_privacy(n, batch_size, noise_multiplier, epochs, delta):
|
||||||
|
"""Compute epsilon based on the given hyperparameters."""
|
||||||
|
q = batch_size / n # q - the sampling ratio.
|
||||||
|
if q > 1:
|
||||||
|
raise app.UsageError('n must be larger than the batch size.')
|
||||||
|
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
||||||
|
list(range(5, 64)) + [128, 256, 512])
|
||||||
|
steps = int(math.ceil(epochs * n / batch_size))
|
||||||
|
|
||||||
|
return apply_dp_sgd_analysis(q, noise_multiplier, steps, orders, delta)
|
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
|
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
|
||||||
|
|
||||||
|
|
||||||
class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
||||||
|
@ -32,7 +32,7 @@ class ComputeDpSgdPrivacyTest(parameterized.TestCase):
|
||||||
)
|
)
|
||||||
def test_compute_dp_sgd_privacy(self, n, batch_size, noise_multiplier, epochs,
|
def test_compute_dp_sgd_privacy(self, n, batch_size, noise_multiplier, epochs,
|
||||||
delta, expected_eps, expected_order):
|
delta, expected_eps, expected_order):
|
||||||
eps, order = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
|
eps, order = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
|
||||||
n, batch_size, noise_multiplier, epochs, delta)
|
n, batch_size, noise_multiplier, epochs, delta)
|
||||||
self.assertAlmostEqual(eps, expected_eps)
|
self.assertAlmostEqual(eps, expected_eps)
|
||||||
self.assertAlmostEqual(order, expected_order)
|
self.assertAlmostEqual(order, expected_order)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2018, The TensorFlow Authors.
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -12,26 +12,23 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Training a CNN on MNIST with differentially private SGD optimizer."""
|
"""Train a CNN on MNIST with differentially private SGD optimizer."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.analysis import privacy_ledger
|
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger
|
|
||||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
|
||||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||||
|
from tensorflow_privacy.tutorials import mnist_dpsgd_tutorial_common as common
|
||||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
flags.DEFINE_boolean(
|
||||||
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||||
|
@ -41,62 +38,20 @@ flags.DEFINE_float('noise_multiplier', 1.1,
|
||||||
'Ratio of the standard deviation to the clipping norm')
|
'Ratio of the standard deviation to the clipping norm')
|
||||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||||
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||||
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
flags.DEFINE_integer('epochs', 30, 'Number of epochs')
|
||||||
flags.DEFINE_integer(
|
flags.DEFINE_integer(
|
||||||
'microbatches', 256, 'Number of microbatches '
|
'microbatches', 256, 'Number of microbatches '
|
||||||
'(must evenly divide batch_size)')
|
'(must evenly divide batch_size)')
|
||||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
class EpsilonPrintingTrainingHook(tf.estimator.SessionRunHook):
|
|
||||||
"""Training hook to print current value of epsilon after an epoch."""
|
|
||||||
|
|
||||||
def __init__(self, ledger):
|
|
||||||
"""Initalizes the EpsilonPrintingTrainingHook.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ledger: The privacy ledger.
|
|
||||||
"""
|
|
||||||
self._samples, self._queries = ledger.get_unformatted_ledger()
|
|
||||||
|
|
||||||
def end(self, session):
|
|
||||||
|
|
||||||
# Any RDP order (for order > 1) corresponds to one epsilon value. We
|
|
||||||
# enumerate through a few orders and pick the one that gives lowest epsilon.
|
|
||||||
# The variable orders may be extended for different use cases. Usually, the
|
|
||||||
# search is set to be finer-grained for small orders and coarser-grained for
|
|
||||||
# larger orders.
|
|
||||||
orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
|
|
||||||
samples = session.run(self._samples)
|
|
||||||
queries = session.run(self._queries)
|
|
||||||
formatted_ledger = privacy_ledger.format_ledger(samples, queries)
|
|
||||||
rdp = compute_rdp_from_ledger(formatted_ledger, orders)
|
|
||||||
|
|
||||||
# It is recommended that delta is o(1/dataset_size). In the case of MNIST,
|
|
||||||
# dataset_size is 60000, so we set delta to be 1e-5. For larger datasets,
|
|
||||||
# delta should be set smaller.
|
|
||||||
eps = get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
|
||||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
|
||||||
|
|
||||||
|
|
||||||
def cnn_model_fn(features, labels, mode):
|
def cnn_model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
"""Model function for a CNN."""
|
"""Model function for a CNN."""
|
||||||
|
|
||||||
# Define CNN architecture using tf.keras.layers.
|
# Define CNN architecture.
|
||||||
input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
|
logits = common.get_cnn_model(features)
|
||||||
y = tf.keras.layers.Conv2D(16, 8,
|
|
||||||
strides=2,
|
|
||||||
padding='same',
|
|
||||||
activation='relu').apply(input_layer)
|
|
||||||
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
|
|
||||||
y = tf.keras.layers.Conv2D(32, 4,
|
|
||||||
strides=2,
|
|
||||||
padding='valid',
|
|
||||||
activation='relu').apply(y)
|
|
||||||
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
|
|
||||||
y = tf.keras.layers.Flatten().apply(y)
|
|
||||||
y = tf.keras.layers.Dense(32, activation='relu').apply(y)
|
|
||||||
logits = tf.keras.layers.Dense(10).apply(y)
|
|
||||||
|
|
||||||
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
||||||
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||||
|
@ -106,12 +61,7 @@ def cnn_model_fn(features, labels, mode):
|
||||||
|
|
||||||
# Configure the training op (for TRAIN mode).
|
# Configure the training op (for TRAIN mode).
|
||||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||||
|
|
||||||
if FLAGS.dpsgd:
|
if FLAGS.dpsgd:
|
||||||
ledger = privacy_ledger.PrivacyLedger(
|
|
||||||
population_size=60000,
|
|
||||||
selection_probability=(FLAGS.batch_size / 60000))
|
|
||||||
|
|
||||||
# Use DP version of GradientDescentOptimizer. Other optimizers are
|
# Use DP version of GradientDescentOptimizer. Other optimizers are
|
||||||
# available in dp_optimizer. Most optimizers inheriting from
|
# available in dp_optimizer. Most optimizers inheriting from
|
||||||
# tf.train.Optimizer should be wrappable in differentially private
|
# tf.train.Optimizer should be wrappable in differentially private
|
||||||
|
@ -120,26 +70,22 @@ def cnn_model_fn(features, labels, mode):
|
||||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||||
noise_multiplier=FLAGS.noise_multiplier,
|
noise_multiplier=FLAGS.noise_multiplier,
|
||||||
num_microbatches=FLAGS.microbatches,
|
num_microbatches=FLAGS.microbatches,
|
||||||
ledger=ledger,
|
|
||||||
learning_rate=FLAGS.learning_rate)
|
learning_rate=FLAGS.learning_rate)
|
||||||
training_hooks = [
|
|
||||||
EpsilonPrintingTrainingHook(ledger)
|
|
||||||
]
|
|
||||||
opt_loss = vector_loss
|
opt_loss = vector_loss
|
||||||
else:
|
else:
|
||||||
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
optimizer = tf.train.GradientDescentOptimizer(
|
||||||
training_hooks = []
|
learning_rate=FLAGS.learning_rate)
|
||||||
opt_loss = scalar_loss
|
opt_loss = scalar_loss
|
||||||
|
|
||||||
global_step = tf.train.get_global_step()
|
global_step = tf.train.get_global_step()
|
||||||
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||||
|
|
||||||
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
||||||
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
||||||
# used for evaluation and debugging by tf.estimator. The actual loss being
|
# used for evaluation and debugging by tf.estimator. The actual loss being
|
||||||
# minimized is opt_loss defined above and passed to optimizer.minimize().
|
# minimized is opt_loss defined above and passed to optimizer.minimize().
|
||||||
return tf.estimator.EstimatorSpec(mode=mode,
|
return tf.estimator.EstimatorSpec(
|
||||||
loss=scalar_loss,
|
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||||
train_op=train_op,
|
|
||||||
training_hooks=training_hooks)
|
|
||||||
|
|
||||||
# Add evaluation metrics (for EVAL mode).
|
# Add evaluation metrics (for EVAL mode).
|
||||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||||
|
@ -149,69 +95,48 @@ def cnn_model_fn(features, labels, mode):
|
||||||
labels=labels,
|
labels=labels,
|
||||||
predictions=tf.argmax(input=logits, axis=1))
|
predictions=tf.argmax(input=logits, axis=1))
|
||||||
}
|
}
|
||||||
|
|
||||||
return tf.estimator.EstimatorSpec(mode=mode,
|
return tf.estimator.EstimatorSpec(mode=mode,
|
||||||
loss=scalar_loss,
|
loss=scalar_loss,
|
||||||
eval_metric_ops=eval_metric_ops)
|
eval_metric_ops=eval_metric_ops)
|
||||||
|
|
||||||
|
|
||||||
def load_mnist():
|
|
||||||
"""Loads MNIST and preprocesses to combine training and validation data."""
|
|
||||||
train, test = tf.keras.datasets.mnist.load_data()
|
|
||||||
train_data, train_labels = train
|
|
||||||
test_data, test_labels = test
|
|
||||||
|
|
||||||
train_data = np.array(train_data, dtype=np.float32) / 255
|
|
||||||
test_data = np.array(test_data, dtype=np.float32) / 255
|
|
||||||
|
|
||||||
train_labels = np.array(train_labels, dtype=np.int32)
|
|
||||||
test_labels = np.array(test_labels, dtype=np.int32)
|
|
||||||
|
|
||||||
assert train_data.min() == 0.
|
|
||||||
assert train_data.max() == 1.
|
|
||||||
assert test_data.min() == 0.
|
|
||||||
assert test_data.max() == 1.
|
|
||||||
assert train_labels.ndim == 1
|
|
||||||
assert test_labels.ndim == 1
|
|
||||||
|
|
||||||
return train_data, train_labels, test_data, test_labels
|
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
tf.logging.set_verbosity(tf.logging.INFO)
|
logging.set_verbosity(logging.INFO)
|
||||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||||
|
|
||||||
# Load training and test data.
|
|
||||||
train_data, train_labels, test_data, test_labels = load_mnist()
|
|
||||||
|
|
||||||
# Instantiate the tf.Estimator.
|
# Instantiate the tf.Estimator.
|
||||||
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
|
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
|
||||||
model_dir=FLAGS.model_dir)
|
model_dir=FLAGS.model_dir)
|
||||||
|
|
||||||
# Create tf.Estimator input functions for the training and test data.
|
|
||||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
|
||||||
x={'x': train_data},
|
|
||||||
y=train_labels,
|
|
||||||
batch_size=FLAGS.batch_size,
|
|
||||||
num_epochs=FLAGS.epochs,
|
|
||||||
shuffle=True)
|
|
||||||
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
|
|
||||||
x={'x': test_data},
|
|
||||||
y=test_labels,
|
|
||||||
num_epochs=1,
|
|
||||||
shuffle=False)
|
|
||||||
|
|
||||||
# Training loop.
|
# Training loop.
|
||||||
steps_per_epoch = 60000 // FLAGS.batch_size
|
steps_per_epoch = 60000 // FLAGS.batch_size
|
||||||
for epoch in range(1, FLAGS.epochs + 1):
|
for epoch in range(1, FLAGS.epochs + 1):
|
||||||
|
start_time = time.time()
|
||||||
# Train the model for one epoch.
|
# Train the model for one epoch.
|
||||||
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)
|
mnist_classifier.train(
|
||||||
|
input_fn=common.make_input_fn('train', FLAGS.batch_size),
|
||||||
|
steps=steps_per_epoch)
|
||||||
|
end_time = time.time()
|
||||||
|
logging.info('Epoch %d time in seconds: %.2f', epoch, end_time - start_time)
|
||||||
|
|
||||||
# Evaluate the model and print results
|
# Evaluate the model and print results
|
||||||
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
|
eval_results = mnist_classifier.evaluate(
|
||||||
|
input_fn=common.make_input_fn('test', FLAGS.batch_size, 1))
|
||||||
test_accuracy = eval_results['accuracy']
|
test_accuracy = eval_results['accuracy']
|
||||||
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||||
|
|
||||||
|
# Compute the privacy budget expended.
|
||||||
|
if FLAGS.dpsgd:
|
||||||
|
if FLAGS.noise_multiplier > 0.0:
|
||||||
|
eps, _ = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
|
||||||
|
60000, FLAGS.batch_size, FLAGS.noise_multiplier, epoch, 1e-5)
|
||||||
|
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||||
|
else:
|
||||||
|
print('Trained with DP-SGD but with zero noise.')
|
||||||
|
else:
|
||||||
|
print('Trained with vanilla non-private SGD optimizer')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(main)
|
app.run(main)
|
||||||
|
|
75
tutorials/mnist_dpsgd_tutorial_common.py
Normal file
75
tutorials/mnist_dpsgd_tutorial_common.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Common tools for DP-SGD MNIST tutorials."""
|
||||||
|
|
||||||
|
# These are not necessary in a Python 3-only module.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import google_type_annotations
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
|
||||||
|
def get_cnn_model(features):
|
||||||
|
"""Given input features, returns the logits from a simple CNN model."""
|
||||||
|
input_layer = tf.reshape(features, [-1, 28, 28, 1])
|
||||||
|
y = tf.keras.layers.Conv2D(
|
||||||
|
16, 8, strides=2, padding='same', activation='relu').apply(input_layer)
|
||||||
|
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
|
||||||
|
y = tf.keras.layers.Conv2D(
|
||||||
|
32, 4, strides=2, padding='valid', activation='relu').apply(y)
|
||||||
|
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
|
||||||
|
y = tf.keras.layers.Flatten().apply(y)
|
||||||
|
y = tf.keras.layers.Dense(32, activation='relu').apply(y)
|
||||||
|
logits = tf.keras.layers.Dense(10).apply(y)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def make_input_fn(split, input_batch_size=256, repetitions=-1, tpu=False):
|
||||||
|
"""Make input function on given MNIST split."""
|
||||||
|
|
||||||
|
def input_fn(params=None):
|
||||||
|
"""A simple input function."""
|
||||||
|
batch_size = params.get('batch_size', input_batch_size)
|
||||||
|
|
||||||
|
def parser(example):
|
||||||
|
image, label = example['image'], example['label']
|
||||||
|
image = tf.cast(image, tf.float32)
|
||||||
|
image /= 255.0
|
||||||
|
label = tf.cast(label, tf.int32)
|
||||||
|
return image, label
|
||||||
|
|
||||||
|
dataset = tfds.load(name='mnist', split=split)
|
||||||
|
dataset = dataset.map(parser).shuffle(60000).repeat(repetitions).batch(
|
||||||
|
batch_size)
|
||||||
|
# If this input function is not meant for TPUs, we can stop here.
|
||||||
|
# Otherwise, we need to explicitly set its shape. Note that for unknown
|
||||||
|
# reasons, returning the latter format causes performance regression
|
||||||
|
# on non-TPUs.
|
||||||
|
if not tpu:
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
# Give inputs statically known shapes; needed for TPUs.
|
||||||
|
images, labels = tf.data.make_one_shot_iterator(dataset).get_next()
|
||||||
|
# return images, labels
|
||||||
|
images.set_shape([batch_size, 28, 28, 1])
|
||||||
|
labels.set_shape([
|
||||||
|
batch_size,
|
||||||
|
])
|
||||||
|
return images, labels
|
||||||
|
|
||||||
|
return input_fn
|
167
tutorials/mnist_dpsgd_tutorial_tpu.py
Normal file
167
tutorials/mnist_dpsgd_tutorial_tpu.py
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Train a CNN on MNIST with DP-SGD optimizer on TPUs."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
|
||||||
|
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||||
|
from tensorflow_privacy.tutorials import mnist_dpsgd_tutorial_common as common
|
||||||
|
|
||||||
|
flags.DEFINE_boolean(
|
||||||
|
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||||
|
'train with vanilla SGD.')
|
||||||
|
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||||
|
flags.DEFINE_float('noise_multiplier', 0.77,
|
||||||
|
'Ratio of the standard deviation to the clipping norm')
|
||||||
|
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||||
|
flags.DEFINE_integer('batch_size', 200, 'Batch size')
|
||||||
|
flags.DEFINE_integer('cores', 2, 'Number of TPU cores')
|
||||||
|
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
'microbatches', 100, 'Number of microbatches '
|
||||||
|
'(must evenly divide batch_size / cores)')
|
||||||
|
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||||
|
flags.DEFINE_string('master', None, 'Master')
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def cnn_model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
|
"""Model function for a CNN."""
|
||||||
|
|
||||||
|
# Define CNN architecture using tf.keras.layers.
|
||||||
|
logits = common.get_cnn_model(features)
|
||||||
|
|
||||||
|
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
||||||
|
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||||
|
labels=labels, logits=logits)
|
||||||
|
# Define mean of loss across minibatch (for reporting through tf.Estimator).
|
||||||
|
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||||
|
|
||||||
|
# Configure the training op (for TRAIN mode).
|
||||||
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||||
|
if FLAGS.dpsgd:
|
||||||
|
# Use DP version of GradientDescentOptimizer. Other optimizers are
|
||||||
|
# available in dp_optimizer. Most optimizers inheriting from
|
||||||
|
# tf.train.Optimizer should be wrappable in differentially private
|
||||||
|
# counterparts by calling dp_optimizer.optimizer_from_args().
|
||||||
|
optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
|
||||||
|
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||||
|
noise_multiplier=FLAGS.noise_multiplier,
|
||||||
|
num_microbatches=FLAGS.microbatches,
|
||||||
|
learning_rate=FLAGS.learning_rate)
|
||||||
|
opt_loss = vector_loss
|
||||||
|
else:
|
||||||
|
optimizer = tf.train.GradientDescentOptimizer(
|
||||||
|
learning_rate=FLAGS.learning_rate)
|
||||||
|
opt_loss = scalar_loss
|
||||||
|
|
||||||
|
# Training with TPUs requires wrapping the optimizer in a
|
||||||
|
# CrossShardOptimizer.
|
||||||
|
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
|
||||||
|
|
||||||
|
global_step = tf.train.get_global_step()
|
||||||
|
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||||
|
|
||||||
|
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
||||||
|
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
||||||
|
# used for evaluation and debugging by tf.estimator. The actual loss being
|
||||||
|
# minimized is opt_loss defined above and passed to optimizer.minimize().
|
||||||
|
return tf.estimator.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||||
|
|
||||||
|
# Add evaluation metrics (for EVAL mode).
|
||||||
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||||
|
|
||||||
|
def metric_fn(labels, logits):
|
||||||
|
predictions = tf.argmax(logits, 1)
|
||||||
|
return {
|
||||||
|
'accuracy':
|
||||||
|
tf.metrics.accuracy(labels=labels, predictions=predictions),
|
||||||
|
}
|
||||||
|
|
||||||
|
return tf.estimator.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=scalar_loss,
|
||||||
|
eval_metrics=(metric_fn, {
|
||||||
|
'labels': labels,
|
||||||
|
'logits': logits,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def main(unused_argv):
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||||
|
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||||
|
|
||||||
|
# Instantiate the tf.Estimator.
|
||||||
|
run_config = tf.estimator.tpu.RunConfig(master=FLAGS.master)
|
||||||
|
mnist_classifier = tf.estimator.tpu.TPUEstimator(
|
||||||
|
train_batch_size=FLAGS.batch_size,
|
||||||
|
eval_batch_size=FLAGS.batch_size,
|
||||||
|
model_fn=cnn_model_fn,
|
||||||
|
model_dir=FLAGS.model_dir,
|
||||||
|
config=run_config)
|
||||||
|
|
||||||
|
# Training loop.
|
||||||
|
steps_per_epoch = 60000 // FLAGS.batch_size
|
||||||
|
eval_steps_per_epoch = 10000 // FLAGS.batch_size
|
||||||
|
for epoch in range(1, FLAGS.epochs + 1):
|
||||||
|
start_time = time.time()
|
||||||
|
# Train the model for one epoch.
|
||||||
|
mnist_classifier.train(
|
||||||
|
input_fn=common.make_input_fn(
|
||||||
|
'train', FLAGS.batch_size / FLAGS.cores, tpu=True),
|
||||||
|
steps=steps_per_epoch)
|
||||||
|
end_time = time.time()
|
||||||
|
logging.info('Epoch %d time in seconds: %.2f', epoch, end_time - start_time)
|
||||||
|
|
||||||
|
# Evaluate the model and print results
|
||||||
|
eval_results = mnist_classifier.evaluate(
|
||||||
|
input_fn=common.make_input_fn(
|
||||||
|
'test', FLAGS.batch_size / FLAGS.cores, 1, tpu=True),
|
||||||
|
steps=eval_steps_per_epoch)
|
||||||
|
test_accuracy = eval_results['accuracy']
|
||||||
|
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
|
||||||
|
|
||||||
|
# Compute the privacy budget expended.
|
||||||
|
if FLAGS.dpsgd:
|
||||||
|
if FLAGS.noise_multiplier > 0.0:
|
||||||
|
# Due to the nature of Gaussian noise, the actual noise applied is
|
||||||
|
# equal to FLAGS.noise_multiplier * sqrt(number of cores).
|
||||||
|
eps, _ = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
|
||||||
|
60000, FLAGS.batch_size,
|
||||||
|
FLAGS.noise_multiplier * math.sqrt(FLAGS.cores), epoch, 1e-5)
|
||||||
|
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||||
|
else:
|
||||||
|
print('Trained with DP-SGD but with zero noise.')
|
||||||
|
else:
|
||||||
|
print('Trained with vanilla non-private SGD optimizer')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
Loading…
Reference in a new issue