Merge pull request #7 from tensorflow/master

Updating to match tensorflow master
This commit is contained in:
Christopher Choquette Choo 2019-08-21 14:15:41 -04:00 committed by GitHub
commit 6d3776e4a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 565 additions and 1 deletions

View file

@ -227,8 +227,8 @@ def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
batch_size: The desired batch_size
generator: False for array, True for generator
batch_size: The desired batch_size.
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_outputs)

View file

@ -0,0 +1,153 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vectorized differentially private optimizers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
parent_code = tf.train.Optimizer.compute_gradients.__code__
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
else:
nest = tf.nest
AdagradOptimizer = tf.optimizers.Adagrad
AdamOptimizer = tf.optimizers.Adam
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
GATE_OP = None # pylint: disable=invalid-name
def make_vectorized_optimizer_class(cls):
"""Constructs a vectorized DP optimizer class from an existing one."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
child_code = cls.compute_gradients.__code__
else:
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
if child_code is not parent_code:
tf.logging.warning(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls."""
def __init__(
self,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initialize the DPOptimizerClass.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients)
noise_multiplier: Ratio of the standard deviation to the clipping norm
num_microbatches: How many microbatches into which the minibatch is
split. If None, will default to the size of the minibatch, and
per-example gradients will be computed.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
def compute_gradients(self,
loss,
var_list,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None,
gradient_tape=None):
if callable(loss):
# TF is running in Eager mode
raise NotImplementedError('Vectorized optimizer unavailable for TF2.')
else:
# TF is running in graph mode, check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
batch_size = tf.shape(loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be
# sampling from the dataset without replacement.
microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
def process_microbatch(microbatch_loss):
"""Compute clipped grads for one microbatch."""
microbatch_loss = tf.reduce_mean(microbatch_loss)
grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
microbatch_loss,
var_list,
gate_gradients,
aggregation_method,
colocate_gradients_with_ops,
grad_loss))
grads_list = [
g if g is not None else tf.zeros_like(v)
for (g, v) in zip(list(grads), var_list)
]
# Clip gradients to have L2 norm of l2_norm_clip.
# Here, we use TF primitives rather than the built-in
# tf.clip_by_global_norm() so that operations can be vectorized
# across microbatches.
grads_flat = nest.flatten(grads_list)
squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat]
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
clipped_flat = [g / div for g in grads_flat]
clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat)
return clipped_grads
clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses)
def reduce_noise_normalize_batch(stacked_grads):
summed_grads = tf.reduce_sum(stacked_grads, axis=0)
noise_stddev = self._l2_norm_clip * self._noise_multiplier
noise = tf.random.normal(tf.shape(summed_grads),
stddev=noise_stddev)
noised_grads = summed_grads + noise
return noised_grads / tf.cast(self._num_microbatches, tf.float32)
final_grads = nest.map_structure(reduce_noise_normalize_batch,
clipped_grads)
return list(zip(final_grads, var_list))
return DPOptimizerClass
VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer)
VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer)
VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer)

View file

@ -0,0 +1,204 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import mock
import numpy as np
import tensorflow as tf
from privacy.optimizers import dp_optimizer_vectorized
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
def _loss(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
# Parameters for testing: optimizer, num_microbatches, expected answer.
@parameterized.named_parameters(
('DPGradientDescent 1', VectorizedDPSGD, 1, [-2.5, -2.5]),
('DPGradientDescent 2', VectorizedDPSGD, 2, [-2.5, -2.5]),
('DPGradientDescent 4', VectorizedDPSGD, 4, [-2.5, -2.5]),
('DPAdagrad 1', VectorizedDPAdagrad, 1, [-2.5, -2.5]),
('DPAdagrad 2', VectorizedDPAdagrad, 2, [-2.5, -2.5]),
('DPAdagrad 4', VectorizedDPAdagrad, 4, [-2.5, -2.5]),
('DPAdam 1', VectorizedDPAdam, 1, [-2.5, -2.5]),
('DPAdam 2', VectorizedDPAdam, 2, [-2.5, -2.5]),
('DPAdam 4', VectorizedDPAdam, 4, [-2.5, -2.5]))
def testBaseline(self, cls, num_microbatches, expected_answer):
with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
opt = cls(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
# Expected gradient is sum of differences divided by number of
# microbatches.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testClippingNorm(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
opt = cls(l2_norm_clip=1.0,
noise_multiplier=0.,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
# Expected gradient is sum of differences.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testNoiseMultiplier(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(l2_norm_clip=4.0,
noise_multiplier=8.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
@mock.patch.object(tf, 'logging')
def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer):
def compute_gradients(self):
return 0
dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer)
mock_logging.warning.assert_called_once_with(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
'SimpleOptimizer')
def testEstimator(self):
"""Tests that DP optimizers work with tf.estimator."""
def linear_model_fn(features, labels, mode):
preds = tf.keras.layers.Dense(
1, activation='linear', name='dense').apply(features['x'])
vector_loss = tf.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(vector_loss)
optimizer = VectorizedDPSGD(
l2_norm_clip=1.0,
noise_multiplier=0.,
num_microbatches=1,
learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = 6.0
train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.1, size=(200, 1)).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=20,
num_epochs=10,
shuffle=True)
linear_regressor.train(input_fn=train_input_fn, steps=100)
self.assertAllClose(
linear_regressor.get_variable_value('dense/kernel'),
true_weights,
atol=1.0)
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testDPGaussianOptimizerClass(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(
l2_norm_clip=4.0,
noise_multiplier=2.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,207 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training a CNN on MNIST with vectorized DP-SGD optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer_vectorized
flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 200, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer(
'microbatches', 200, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')
FLAGS = flags.FLAGS
NUM_TRAIN_EXAMPLES = 60000
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
def compute_epsilon(steps):
"""Computes epsilon value for given hyperparameters."""
if FLAGS.noise_multiplier == 0.0:
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / NUM_TRAIN_EXAMPLES
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to approximate 1 / (number of training points).
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
def cnn_model_fn(features, labels, mode):
"""Model function for a CNN."""
# Define CNN architecture using tf.keras.layers.
input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
y = tf.keras.layers.Conv2D(16, 8,
strides=2,
padding='same',
activation='relu').apply(input_layer)
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
y = tf.keras.layers.Conv2D(32, 4,
strides=2,
padding='valid',
activation='relu').apply(y)
y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
y = tf.keras.layers.Flatten().apply(y)
y = tf.keras.layers.Dense(32, activation='relu').apply(y)
logits = tf.keras.layers.Dense(10).apply(y)
# Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
# Define mean of loss across minibatch (for reporting through tf.Estimator).
scalar_loss = tf.reduce_mean(vector_loss)
# Configure the training op (for TRAIN mode).
if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd:
# Use DP version of GradientDescentOptimizer. Other optimizers are
# available in dp_optimizer. Most optimizers inheriting from
# tf.train.Optimizer should be wrappable in differentially private
# counterparts by calling dp_optimizer.optimizer_from_args().
optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
learning_rate=FLAGS.learning_rate)
opt_loss = vector_loss
else:
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
opt_loss = scalar_loss
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
# In the following, we pass the mean of the loss (scalar_loss) rather than
# the vector_loss because tf.estimator requires a scalar loss. This is only
# used for evaluation and debugging by tf.estimator. The actual loss being
# minimized is opt_loss defined above and passed to optimizer.minimize().
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
train_op=train_op)
# Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = {
'accuracy':
tf.metrics.accuracy(
labels=labels,
predictions=tf.argmax(input=logits, axis=1))
}
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
eval_metric_ops=eval_metric_ops)
def load_mnist():
"""Loads MNIST and preprocesses to combine training and validation data."""
train, test = tf.keras.datasets.mnist.load_data()
train_data, train_labels = train
test_data, test_labels = test
train_data = np.array(train_data, dtype=np.float32) / 255
test_data = np.array(test_data, dtype=np.float32) / 255
train_labels = np.array(train_labels, dtype=np.int32)
test_labels = np.array(test_labels, dtype=np.int32)
assert train_data.min() == 0.
assert train_data.max() == 1.
assert test_data.min() == 0.
assert test_data.max() == 1.
assert train_labels.ndim == 1
assert test_labels.ndim == 1
return train_data, train_labels, test_data, test_labels
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size')
# Load training and test data.
train_data, train_labels, test_data, test_labels = load_mnist()
# Instantiate the tf.Estimator.
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
model_dir=FLAGS.model_dir)
# Create tf.Estimator input functions for the training and test data.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=FLAGS.batch_size,
num_epochs=FLAGS.epochs,
shuffle=True)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': test_data},
y=test_labels,
num_epochs=1,
shuffle=False)
# Training loop.
steps_per_epoch = NUM_TRAIN_EXAMPLES // FLAGS.batch_size
for epoch in range(1, FLAGS.epochs + 1):
# Train the model for one epoch.
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)
# Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
test_accuracy = eval_results['accuracy']
print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
# Compute the privacy budget expended.
if FLAGS.dpsgd:
eps = compute_epsilon(epoch * NUM_TRAIN_EXAMPLES // FLAGS.batch_size)
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
else:
print('Trained with vanilla non-private SGD optimizer')
if __name__ == '__main__':
app.run(main)