Merge pull request #6 from tensorflow/master

Merging original repo
This commit is contained in:
Christopher Choquette Choo 2019-08-06 10:57:01 -04:00 committed by GitHub
commit 136200d0c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 146 additions and 102 deletions

View file

@ -57,6 +57,10 @@ GitHub pull requests. To speed the code review process, we ask that:
your pull requests. In most cases this can be done by running `autopep8 -i your pull requests. In most cases this can be done by running `autopep8 -i
--indent-size 2 <file>` on the files you have edited. --indent-size 2 <file>` on the files you have edited.
* You should also check your code with pylint and TensorFlow's pylint
[configuration file](https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc)
by running `pylint --rcfile=/path/to/the/tf/rcfile <edited file.py>`.
* When making your first pull request, you * When making your first pull request, you
[sign the Google CLA](https://cla.developers.google.com/clas) [sign the Google CLA](https://cla.developers.google.com/clas)

View file

@ -9,6 +9,7 @@ py_library(
srcs = ["__init__.py"], srcs = ["__init__.py"],
deps = [ deps = [
"//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger", "//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger",
"//third_party/py/tensorflow_privacy/privacy/analysis:rdp_accountant",
"//third_party/py/tensorflow_privacy/privacy/dp_query", "//third_party/py/tensorflow_privacy/privacy/dp_query",
"//third_party/py/tensorflow_privacy/privacy/dp_query:gaussian_query", "//third_party/py/tensorflow_privacy/privacy/dp_query:gaussian_query",
"//third_party/py/tensorflow_privacy/privacy/dp_query:nested_query", "//third_party/py/tensorflow_privacy/privacy/dp_query:nested_query",

View file

@ -13,6 +13,10 @@
# limitations under the License. # limitations under the License.
"""TensorFlow Privacy library.""" """TensorFlow Privacy library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys import sys
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
@ -42,8 +46,11 @@ else:
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
try:
from privacy.bolt_on.models import BoltOnModel from privacy.bolt_on.models import BoltOnModel
from privacy.bolt_on.optimizers import BoltOn from privacy.bolt_on.optimizers import BoltOn
from privacy.bolt_on.losses import StrongConvexMixin from privacy.bolt_on.losses import StrongConvexMixin
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from privacy.bolt_on.losses import StrongConvexHuber from privacy.bolt_on.losses import StrongConvexHuber
except ImportError:
print('module `bolt_on` was not found in this version of TF Privacy')

View file

@ -61,6 +61,10 @@ flags.mark_flag_as_required('epochs')
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta): def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
"""Compute and print results of DP-SGD analysis.""" """Compute and print results of DP-SGD analysis."""
# compute_rdp requires that sigma be the ratio of the standard deviation of
# the Gaussian noise to the l2-sensitivity of the function to which it is
# added. Hence, sigma here corresponds to the `noise_multiplier` parameter
# in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
rdp = compute_rdp(q, sigma, steps, orders) rdp = compute_rdp(q, sigma, steps, orders)
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta) eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
@ -80,13 +84,10 @@ def main(argv):
del argv # argv is not used. del argv # argv is not used.
q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio. q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio.
if q > 1: if q > 1:
raise app.UsageError('N must be larger than the batch size.') raise app.UsageError('N must be larger than the batch size.')
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
list(range(5, 64)) + [128, 256, 512]) list(range(5, 64)) + [128, 256, 512])
steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size)) steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size))
apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta) apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta)

View file

@ -226,9 +226,9 @@ class QueryWithLedger(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._query.derive_sample_params(global_state) return self._query.derive_sample_params(global_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return self._query.initial_sample_state(global_state, template) return self._query.initial_sample_state(template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
"""See base class.""" """See base class."""

View file

@ -26,6 +26,7 @@ py_test(
name = "gaussian_query_test", name = "gaussian_query_test",
size = "small", size = "small",
srcs = ["gaussian_query_test.py"], srcs = ["gaussian_query_test.py"],
python_version = "PY2",
deps = [ deps = [
":gaussian_query", ":gaussian_query",
":test_utils", ":test_utils",
@ -50,6 +51,7 @@ py_test(
name = "no_privacy_query_test", name = "no_privacy_query_test",
size = "small", size = "small",
srcs = ["no_privacy_query_test.py"], srcs = ["no_privacy_query_test.py"],
python_version = "PY2",
deps = [ deps = [
":no_privacy_query", ":no_privacy_query",
":test_utils", ":test_utils",
@ -72,6 +74,7 @@ py_test(
name = "normalized_query_test", name = "normalized_query_test",
size = "small", size = "small",
srcs = ["normalized_query_test.py"], srcs = ["normalized_query_test.py"],
python_version = "PY2",
deps = [ deps = [
":gaussian_query", ":gaussian_query",
":normalized_query", ":normalized_query",
@ -94,6 +97,7 @@ py_test(
name = "nested_query_test", name = "nested_query_test",
size = "small", size = "small",
srcs = ["nested_query_test.py"], srcs = ["nested_query_test.py"],
python_version = "PY2",
deps = [ deps = [
":gaussian_query", ":gaussian_query",
":nested_query", ":nested_query",
@ -119,6 +123,7 @@ py_library(
py_test( py_test(
name = "quantile_adaptive_clip_sum_query_test", name = "quantile_adaptive_clip_sum_query_test",
srcs = ["quantile_adaptive_clip_sum_query_test.py"], srcs = ["quantile_adaptive_clip_sum_query_test.py"],
python_version = "PY2",
deps = [ deps = [
":quantile_adaptive_clip_sum_query", ":quantile_adaptive_clip_sum_query",
":test_utils", ":test_utils",

View file

@ -88,11 +88,10 @@ class DPQuery(object):
return () return ()
@abc.abstractmethod @abc.abstractmethod
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""Returns an initial state to use for the next sample. """Returns an initial state to use for the next sample.
Args: Args:
global_state: The current global state.
template: A nested structure of tensors, TensorSpecs, or numpy arrays used template: A nested structure of tensors, TensorSpecs, or numpy arrays used
as a template to create the initial sample state. It is assumed that the as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has leaves of the structure are python scalars or some type that has
@ -216,8 +215,7 @@ def zeros_like(arg):
class SumAggregationDPQuery(DPQuery): class SumAggregationDPQuery(DPQuery):
"""Base class for DPQueries that aggregate via sum.""" """Base class for DPQueries that aggregate via sum."""
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
del global_state # unused.
return nest.map_structure(zeros_like, template) return nest.map_structure(zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record): def accumulate_preprocessed_record(self, sample_state, preprocessed_record):

View file

@ -69,7 +69,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def derive_sample_params(self, global_state): def derive_sample_params(self, global_state):
return global_state.l2_norm_clip return global_state.l2_norm_clip
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
return nest.map_structure( return nest.map_structure(
dp_query.zeros_like, template) dp_query.zeros_like, template)

View file

@ -99,7 +99,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0) query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, records[0]) sample_state = query.initial_sample_state(records[0])
for record in records: for record in records:
sample_state = query.accumulate_record(params, sample_state, record) sample_state = query.accumulate_record(params, sample_state, record)
return sample_state return sample_state

View file

@ -73,9 +73,9 @@ class NestedQuery(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._map_to_queries('derive_sample_params', global_state) return self._map_to_queries('derive_sample_params', global_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return self._map_to_queries('initial_sample_state', global_state, template) return self._map_to_queries('initial_sample_state', template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
"""See base class.""" """See base class."""

View file

@ -45,11 +45,9 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
Accumulates vectors and normalizes by the total number of accumulated vectors. Accumulates vectors and normalizes by the total number of accumulated vectors.
""" """
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
return ( return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
super(NoPrivacyAverageQuery, self).initial_sample_state(
global_state, template),
tf.constant(0.0)) tf.constant(0.0))
def preprocess_record(self, params, record, weight=1): def preprocess_record(self, params, record, weight=1):

View file

@ -68,11 +68,10 @@ class NormalizedQuery(dp_query.DPQuery):
"""See base class.""" """See base class."""
return self._numerator.derive_sample_params(global_state.numerator_state) return self._numerator.derive_sample_params(global_state.numerator_state)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
# NormalizedQuery has no sample state beyond the numerator state. # NormalizedQuery has no sample state beyond the numerator state.
return self._numerator.initial_sample_state( return self._numerator.initial_sample_state(template)
global_state.numerator_state, template)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):
return self._numerator.preprocess_record(params, record) return self._numerator.preprocess_record(params, record)

View file

@ -26,6 +26,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
from distutils.version import LooseVersion
import tensorflow as tf import tensorflow as tf
@ -33,7 +34,10 @@ from privacy.dp_query import dp_query
from privacy.dp_query import gaussian_query from privacy.dp_query import gaussian_query
from privacy.dp_query import normalized_query from privacy.dp_query import normalized_query
nest = tf.contrib.framework.nest if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class QuantileAdaptiveClipSumQuery(dp_query.DPQuery): class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
@ -144,12 +148,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
global_state.clipped_fraction_state) global_state.clipped_fraction_state)
return self._SampleParams(sum_params, clipped_fraction_params) return self._SampleParams(sum_params, clipped_fraction_params)
def initial_sample_state(self, global_state, template): def initial_sample_state(self, template):
"""See base class.""" """See base class."""
sum_state = self._sum_query.initial_sample_state( sum_state = self._sum_query.initial_sample_state(template)
global_state.sum_state, template)
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state( clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
global_state.clipped_fraction_state, tf.constant(0.0)) tf.constant(0.0))
return self._SampleState(sum_state, clipped_fraction_state) return self._SampleState(sum_state, clipped_fraction_state)
def preprocess_record(self, params, record): def preprocess_record(self, params, record):

View file

@ -38,7 +38,7 @@ def run_query(query, records, global_state=None, weights=None):
if not global_state: if not global_state:
global_state = query.initial_global_state() global_state = query.initial_global_state()
params = query.derive_sample_params(global_state) params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(global_state, next(iter(records))) sample_state = query.initial_sample_state(next(iter(records)))
if weights is None: if weights is None:
for record in records: for record in records:
sample_state = query.accumulate_record(params, sample_state, record) sample_state = query.accumulate_record(params, sample_state, record)

View file

@ -93,10 +93,7 @@ def make_optimizer_class(cls):
vector_loss = loss() vector_loss = loss()
if self._num_microbatches is None: if self._num_microbatches is None:
self._num_microbatches = tf.shape(vector_loss)[0] self._num_microbatches = tf.shape(vector_loss)[0]
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger): sample_state = self._dp_sum_query.initial_sample_state(var_list)
self._dp_sum_query.set_batch_size(self._num_microbatches)
sample_state = self._dp_sum_query.initial_sample_state(
self._global_state, var_list)
microbatches_losses = tf.reshape(vector_loss, microbatches_losses = tf.reshape(vector_loss,
[self._num_microbatches, -1]) [self._num_microbatches, -1])
sample_params = ( sample_params = (
@ -136,8 +133,6 @@ def make_optimizer_class(cls):
# sampling from the dataset without replacement. # sampling from the dataset without replacement.
if self._num_microbatches is None: if self._num_microbatches is None:
self._num_microbatches = tf.shape(loss)[0] self._num_microbatches = tf.shape(loss)[0]
if isinstance(self._dp_sum_query, privacy_ledger.QueryWithLedger):
self._dp_sum_query.set_batch_size(self._num_microbatches)
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
sample_params = ( sample_params = (
@ -162,8 +157,7 @@ def make_optimizer_class(cls):
tf.trainable_variables() + tf.get_collection( tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._dp_sum_query.initial_sample_state( sample_state = self._dp_sum_query.initial_sample_state(var_list)
self._global_state, var_list)
if self._unroll_microbatches: if self._unroll_microbatches:
for idx in range(self._num_microbatches): for idx in range(self._num_microbatches):

View file

@ -20,6 +20,10 @@ Here is a list of all the tutorials included:
* `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on MNIST * `mnist_dpsgd_tutorial_keras.py`: learn a convolutional neural network on MNIST
with differential privacy using tf.Keras. with differential privacy using tf.Keras.
* `mnist_lr_tutorial.py`: learn a differentially private logistic regression
model on MNIST. The model illustrates application of the
"amplification-by-iteration" analysis (https://arxiv.org/abs/1808.06651).
The rest of this README describes the different parameters used to configure The rest of this README describes the different parameters used to configure
DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial. DP-SGD as well as expected outputs for the `mnist_dpsgd_tutorial.py` tutorial.

View file

@ -41,10 +41,10 @@ flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1, flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm') 'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 256, 'Batch size') flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs') flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer( flags.DEFINE_integer(
'microbatches', 256, 'Number of microbatches ' 'microbatches', 250, 'Number of microbatches '
'(must evenly divide batch_size)') '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory') flags.DEFINE_string('model_dir', None, 'Model directory')
@ -121,9 +121,8 @@ def main(unused_argv):
optimizer = DPGradientDescentGaussianOptimizer( optimizer = DPGradientDescentGaussianOptimizer(
l2_norm_clip=FLAGS.l2_norm_clip, l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier, noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.num_microbatches, num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate, learning_rate=FLAGS.learning_rate)
unroll_microbatches=True)
# Compute vector of per-example loss rather than its mean over a minibatch. # Compute vector of per-example loss rather than its mean over a minibatch.
loss = tf.keras.losses.CategoricalCrossentropy( loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE) from_logits=True, reduction=tf.losses.Reduction.NONE)

View file

@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""DP Logistic Regression on MNIST. """DP Logistic Regression on MNIST.
DP Logistic Regression on MNIST with support for privacy-by-iteration analysis. DP Logistic Regression on MNIST with support for privacy-by-iteration analysis.
Feldman, Vitaly, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta. Vitaly Feldman, Ilya Mironov, Kunal Talwar, and Abhradeep Thakurta.
"Privacy amplification by iteration." "Privacy amplification by iteration."
In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS), In 2018 IEEE 59th Annual Symposium on Foundations of Computer Science (FOCS),
pp. 521-532. IEEE, 2018. pp. 521-532. IEEE, 2018.
@ -36,6 +35,8 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from privacy.optimizers import dp_optimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
@ -45,32 +46,30 @@ else:
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, ' flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
'train with vanilla SGD.') 'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training') flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 0.02, flags.DEFINE_float('noise_multiplier', 0.05,
'Ratio of the standard deviation to the clipping norm') 'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_integer('batch_size', 1, 'Batch size') flags.DEFINE_integer('batch_size', 5, 'Batch size')
flags.DEFINE_integer('epochs', 5, 'Number of epochs') flags.DEFINE_integer('epochs', 5, 'Number of epochs')
flags.DEFINE_integer('microbatches', 1, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient') flags.DEFINE_float('regularizer', 0, 'L2 regularizer coefficient')
flags.DEFINE_string('model_dir', None, 'Model directory') flags.DEFINE_string('model_dir', None, 'Model directory')
flags.DEFINE_float('data_l2_norm', 8, flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
'Bound on the L2 norm of normalized data.')
def lr_model_fn(features, labels, mode, nclasses, dim): def lr_model_fn(features, labels, mode, nclasses, dim):
"""Model function for logistic regression.""" """Model function for logistic regression."""
input_layer = tf.reshape(features['x'], tuple([-1]) + dim) input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
logits = tf.layers.dense(inputs=input_layer, logits = tf.layers.dense(
inputs=input_layer,
units=nclasses, units=nclasses,
kernel_regularizer=tf.contrib.layers.l2_regularizer( kernel_regularizer=tf.contrib.layers.l2_regularizer(
scale=FLAGS.regularizer), scale=FLAGS.regularizer),
bias_regularizer=tf.contrib.layers.l2_regularizer( bias_regularizer=tf.contrib.layers.l2_regularizer(
scale=FLAGS.regularizer) scale=FLAGS.regularizer))
)
# Calculate loss as a vector (to support microbatches in DP-SGD). # Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@ -80,18 +79,15 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
# Configure the training op (for TRAIN mode). # Configure the training op (for TRAIN mode).
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd: if FLAGS.dpsgd:
# Use DP version of GradientDescentOptimizer. Other optimizers are
# available in dp_optimizer. Most optimizers inheriting from
# tf.train.Optimizer should be wrappable in differentially private
# counterparts by calling dp_optimizer.optimizer_from_args().
# The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where # The loss function is L-Lipschitz with L = sqrt(2*(||x||^2 + 1)) where
# ||x|| is the norm of the data. # ||x|| is the norm of the data.
# We don't use microbatches (thus speeding up computation), since no
# clipping is necessary due to data normalization.
optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer( optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
l2_norm_clip=math.sqrt(2*(FLAGS.data_l2_norm**2 + 1)), l2_norm_clip=math.sqrt(2 * (FLAGS.data_l2_norm**2 + 1)),
noise_multiplier=FLAGS.noise_multiplier, noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches, num_microbatches=1,
learning_rate=FLAGS.learning_rate) learning_rate=FLAGS.learning_rate)
opt_loss = vector_loss opt_loss = vector_loss
else: else:
@ -103,21 +99,18 @@ def lr_model_fn(features, labels, mode, nclasses, dim):
# the vector_loss because tf.estimator requires a scalar loss. This is only # the vector_loss because tf.estimator requires a scalar loss. This is only
# used for evaluation and debugging by tf.estimator. The actual loss being # used for evaluation and debugging by tf.estimator. The actual loss being
# minimized is opt_loss defined above and passed to optimizer.minimize(). # minimized is opt_loss defined above and passed to optimizer.minimize().
return tf.estimator.EstimatorSpec(mode=mode, return tf.estimator.EstimatorSpec(
loss=scalar_loss, mode=mode, loss=scalar_loss, train_op=train_op)
train_op=train_op)
# Add evaluation metrics (for EVAL mode). # Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL: elif mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = { eval_metric_ops = {
'accuracy': 'accuracy':
tf.metrics.accuracy( tf.metrics.accuracy(
labels=labels, labels=labels, predictions=tf.argmax(input=logits, axis=1))
predictions=tf.argmax(input=logits, axis=1))
} }
return tf.estimator.EstimatorSpec(mode=mode, return tf.estimator.EstimatorSpec(
loss=scalar_loss, mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
eval_metric_ops=eval_metric_ops)
def normalize_data(data, data_l2_norm): def normalize_data(data, data_l2_norm):
@ -159,14 +152,50 @@ def load_mnist(data_l2_norm=float('inf')):
return train_data, train_labels, test_data, test_labels return train_data, train_labels, test_data, test_labels
def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
"""Tabulating position-dependent privacy guarantees."""
if noise_multiplier == 0:
print('No differential privacy (additive noise is 0).')
return
print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) '
'the training procedure results in the following privacy guarantees.')
print('Out of the total of {} samples:'.format(samples))
steps_per_epoch = samples // batch_size
orders = np.concatenate(
[np.linspace(2, 20, num=181),
np.linspace(20, 100, num=81)])
delta = 1e-5
for p in (.5, .9, .99):
steps = math.ceil(steps_per_epoch * p) # Steps in the last epoch.
coef = 2 * (noise_multiplier * batch_size)**-2 * (
# Accounting for privacy loss
(epochs - 1) / steps_per_epoch + # ... from all-but-last epochs
1 / (steps_per_epoch - steps + 1)) # ... due to the last epoch
# Using RDP accountant to compute eps. Doing computation analytically is
# an option.
rdp = [order * coef for order in orders]
eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format(
p * 100, eps, delta))
# Compute privacy guarantees for the Sampled Gaussian Mechanism.
rdp_sgm = compute_rdp(batch_size / samples, noise_multiplier,
epochs * steps_per_epoch, orders)
eps_sgm, _, _ = get_privacy_spent(orders, rdp_sgm, target_delta=delta)
print('By comparison, DP-SGD analysis for training done with the same '
'parameters and random shuffling in each epoch guarantees '
'({:.2f}, {})-DP for all samples.'.format(eps_sgm, delta))
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size')
if FLAGS.data_l2_norm <= 0: if FLAGS.data_l2_norm <= 0:
raise ValueError('FLAGS.data_l2_norm needs to be positive.') raise ValueError('data_l2_norm must be positive.')
if FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2: if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
raise ValueError('The amplification by iteration analysis requires' raise ValueError('The amplification-by-iteration analysis requires'
'learning_rate <= 2 / beta, where beta is the smoothness' 'learning_rate <= 2 / beta, where beta is the smoothness'
'of the loss function and is upper bounded by ||x||^2 / 4' 'of the loss function and is upper bounded by ||x||^2 / 4'
'with ||x|| being the largest L2 norm of the samples.') 'with ||x|| being the largest L2 norm of the samples.')
@ -178,15 +207,12 @@ def main(unused_argv):
train_data, train_labels, test_data, test_labels = load_mnist( train_data, train_labels, test_data, test_labels = load_mnist(
data_l2_norm=FLAGS.data_l2_norm) data_l2_norm=FLAGS.data_l2_norm)
# Instantiate the tf.Estimator. # Instantiate tf.Estimator.
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
model_fn = lambda features, labels, mode: lr_model_fn(features, labels, mode, model_fn = lambda features, labels, mode: lr_model_fn(
nclasses=10, features, labels, mode, nclasses=10, dim=train_data.shape[1:])
dim=train_data.shape[1:]
)
mnist_classifier = tf.estimator.Estimator( mnist_classifier = tf.estimator.Estimator(
model_fn=model_fn, model_fn=model_fn, model_dir=FLAGS.model_dir)
model_dir=FLAGS.model_dir)
# Create tf.Estimator input functions for the training and test data. # Create tf.Estimator input functions for the training and test data.
# To analyze the per-user privacy loss, we keep the same orders of samples in # To analyze the per-user privacy loss, we keep the same orders of samples in
@ -198,22 +224,27 @@ def main(unused_argv):
num_epochs=FLAGS.epochs, num_epochs=FLAGS.epochs,
shuffle=False) shuffle=False)
eval_input_fn = tf.estimator.inputs.numpy_input_fn( eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': test_data}, x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)
y=test_labels,
num_epochs=1,
shuffle=False)
# Train the model # Train the model.
steps_per_epoch = train_data.shape[0] // FLAGS.batch_size num_samples = train_data.shape[0]
mnist_classifier.train(input_fn=train_input_fn, steps_per_epoch = num_samples // FLAGS.batch_size
steps=steps_per_epoch * FLAGS.epochs)
# Evaluate the model and print results mnist_classifier.train(
input_fn=train_input_fn, steps=steps_per_epoch * FLAGS.epochs)
# Evaluate the model and print results.
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
test_accuracy = eval_results['accuracy'] print('Test accuracy after {} epochs is: {:.2f}'.format(
print('Test accuracy after %d epochs is: %.3f' % (FLAGS.epochs, FLAGS.epochs, eval_results['accuracy']))
test_accuracy))
if FLAGS.dpsgd:
print_privacy_guarantees(
epochs=FLAGS.epochs,
batch_size=FLAGS.batch_size,
samples=num_samples,
noise_multiplier=FLAGS.noise_multiplier,
)
if __name__ == '__main__': if __name__ == '__main__':
app.run(main) app.run(main)