diff --git a/privacy/bolt_on/models_test.py b/privacy/bolt_on/models_test.py index 580255a..772f792 100644 --- a/privacy/bolt_on/models_test.py +++ b/privacy/bolt_on/models_test.py @@ -227,8 +227,8 @@ def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False): n_samples: number of rows input_dim: input dimensionality n_classes: output dimensionality + batch_size: The desired batch_size generator: False for array, True for generator - batch_size: The desired batch_size. Returns: X as (n_samples, input_dim), Y as (n_samples, n_outputs) diff --git a/privacy/optimizers/dp_optimizer_vectorized.py b/privacy/optimizers/dp_optimizer_vectorized.py new file mode 100644 index 0000000..7295e1d --- /dev/null +++ b/privacy/optimizers/dp_optimizer_vectorized.py @@ -0,0 +1,153 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vectorized differentially private optimizers for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + nest = tf.contrib.framework.nest + AdagradOptimizer = tf.train.AdagradOptimizer + AdamOptimizer = tf.train.AdamOptimizer + GradientDescentOptimizer = tf.train.GradientDescentOptimizer + parent_code = tf.train.Optimizer.compute_gradients.__code__ + GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name +else: + nest = tf.nest + AdagradOptimizer = tf.optimizers.Adagrad + AdamOptimizer = tf.optimizers.Adam + GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name + parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access + GATE_OP = None # pylint: disable=invalid-name + + +def make_vectorized_optimizer_class(cls): + """Constructs a vectorized DP optimizer class from an existing one.""" + if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + child_code = cls.compute_gradients.__code__ + else: + child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access + if child_code is not parent_code: + tf.logging.warning( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + cls.__name__) + + class DPOptimizerClass(cls): + """Differentially private subclass of given class cls.""" + + def __init__( + self, + l2_norm_clip, + noise_multiplier, + num_microbatches=None, + *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args + **kwargs): + """Initialize the DPOptimizerClass. + + Args: + l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients) + noise_multiplier: Ratio of the standard deviation to the clipping norm + num_microbatches: How many microbatches into which the minibatch is + split. 
If None, will default to the size of the minibatch, and + per-example gradients will be computed. + """ + super(DPOptimizerClass, self).__init__(*args, **kwargs) + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + self._num_microbatches = num_microbatches + + def compute_gradients(self, + loss, + var_list, + gate_gradients=GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None, + gradient_tape=None): + if callable(loss): + # TF is running in Eager mode + raise NotImplementedError('Vectorized optimizer unavailable for TF2.') + else: + # TF is running in graph mode, check we did not receive a gradient tape. + if gradient_tape: + raise ValueError('When in graph mode, a tape should not be passed.') + + batch_size = tf.shape(loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be + # sampling from the dataset without replacement. + microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) + + if var_list is None: + var_list = ( + tf.trainable_variables() + tf.get_collection( + tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + + def process_microbatch(microbatch_loss): + """Compute clipped grads for one microbatch.""" + microbatch_loss = tf.reduce_mean(microbatch_loss) + grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients( + microbatch_loss, + var_list, + gate_gradients, + aggregation_method, + colocate_gradients_with_ops, + grad_loss)) + grads_list = [ + g if g is not None else tf.zeros_like(v) + for (g, v) in zip(list(grads), var_list) + ] + # Clip gradients to have L2 norm of l2_norm_clip. + # Here, we use TF primitives rather than the built-in + # tf.clip_by_global_norm() so that operations can be vectorized + # across microbatches. + grads_flat = nest.flatten(grads_list) + squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat] + global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) + div = tf.maximum(global_norm / self._l2_norm_clip, 1.) + clipped_flat = [g / div for g in grads_flat] + clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat) + return clipped_grads + + clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses) + + def reduce_noise_normalize_batch(stacked_grads): + summed_grads = tf.reduce_sum(stacked_grads, axis=0) + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal(tf.shape(summed_grads), + stddev=noise_stddev) + noised_grads = summed_grads + noise + return noised_grads / tf.cast(self._num_microbatches, tf.float32) + + final_grads = nest.map_structure(reduce_noise_normalize_batch, + clipped_grads) + + return list(zip(final_grads, var_list)) + + return DPOptimizerClass + + +VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer) +VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer) +VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer) diff --git a/privacy/optimizers/dp_optimizer_vectorized_test.py b/privacy/optimizers/dp_optimizer_vectorized_test.py new file mode 100644 index 0000000..e5afd1d --- /dev/null +++ b/privacy/optimizers/dp_optimizer_vectorized_test.py @@ -0,0 +1,204 @@ +# Copyright 2019, The TensorFlow Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for differentially private optimizers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import mock +import numpy as np +import tensorflow as tf + +from privacy.optimizers import dp_optimizer_vectorized +from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad +from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam +from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD + + +class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + + def _loss(self, val0, val1): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1) + + # Parameters for testing: optimizer, num_microbatches, expected answer. + @parameterized.named_parameters( + ('DPGradientDescent 1', VectorizedDPSGD, 1, [-2.5, -2.5]), + ('DPGradientDescent 2', VectorizedDPSGD, 2, [-2.5, -2.5]), + ('DPGradientDescent 4', VectorizedDPSGD, 4, [-2.5, -2.5]), + ('DPAdagrad 1', VectorizedDPAdagrad, 1, [-2.5, -2.5]), + ('DPAdagrad 2', VectorizedDPAdagrad, 2, [-2.5, -2.5]), + ('DPAdagrad 4', VectorizedDPAdagrad, 4, [-2.5, -2.5]), + ('DPAdam 1', VectorizedDPAdam, 1, [-2.5, -2.5]), + ('DPAdam 2', VectorizedDPAdam, 2, [-2.5, -2.5]), + ('DPAdam 4', VectorizedDPAdam, 4, [-2.5, -2.5])) + def testBaseline(self, cls, num_microbatches, expected_answer): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + opt = cls( + l2_norm_clip=1.0e9, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testClippingNorm(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + opt = cls(l2_norm_clip=1.0, + noise_multiplier=0., + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + + # Expected gradient is sum of differences. 
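+      # (The single microbatch averages the two per-example gradients to
+      # [-4.5, -6.0]; clipping to L2 norm 1 rescales this to [-0.6, -0.8].)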
+ gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testNoiseMultiplier(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls(l2_norm_clip=4.0, + noise_multiplier=8.0, + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(1000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 4.0 * 8.0, 0.5) + + @mock.patch.object(tf, 'logging') + def testComputeGradientsOverrideWarning(self, mock_logging): + + class SimpleOptimizer(tf.train.Optimizer): + + def compute_gradients(self): + return 0 + + dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer) + mock_logging.warning.assert_called_once_with( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + 'SimpleOptimizer') + + def testEstimator(self): + """Tests that DP optimizers work with tf.estimator.""" + + def linear_model_fn(features, labels, mode): + preds = tf.keras.layers.Dense( + 1, activation='linear', name='dense').apply(features['x']) + + vector_loss = tf.squared_difference(labels, preds) + scalar_loss = tf.reduce_mean(vector_loss) + optimizer = VectorizedDPSGD( + l2_norm_clip=1.0, + noise_multiplier=0., + num_microbatches=1, + learning_rate=1.0) + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss=vector_loss, global_step=global_step) + return tf.estimator.EstimatorSpec( + mode=mode, loss=scalar_loss, train_op=train_op) + + linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn) + true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) + true_bias = 6.0 + train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32) + + train_labels = np.matmul(train_data, + true_weights) + true_bias + np.random.normal( + scale=0.1, size=(200, 1)).astype(np.float32) + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={'x': train_data}, + y=train_labels, + batch_size=20, + num_epochs=10, + shuffle=True) + linear_regressor.train(input_fn=train_input_fn, steps=100) + self.assertAllClose( + linear_regressor.get_variable_value('dense/kernel'), + true_weights, + atol=1.0) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testDPGaussianOptimizerClass(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls( + l2_norm_clip=4.0, + noise_multiplier=2.0, + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(1000): + 
grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tutorials/mnist_dpsgd_tutorial_vectorized.py b/tutorials/mnist_dpsgd_tutorial_vectorized.py new file mode 100644 index 0000000..2b78f82 --- /dev/null +++ b/tutorials/mnist_dpsgd_tutorial_vectorized.py @@ -0,0 +1,207 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training a CNN on MNIST with vectorized DP-SGD optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags + +from distutils.version import LooseVersion + +import numpy as np +import tensorflow as tf + +from privacy.analysis.rdp_accountant import compute_rdp +from privacy.analysis.rdp_accountant import get_privacy_spent +from privacy.optimizers import dp_optimizer_vectorized + + +flags.DEFINE_boolean( + 'dpsgd', True, 'If True, train with DP-SGD. If False, ' + 'train with vanilla SGD.') +flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') +flags.DEFINE_float('noise_multiplier', 1.1, + 'Ratio of the standard deviation to the clipping norm') +flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') +flags.DEFINE_integer('batch_size', 200, 'Batch size') +flags.DEFINE_integer('epochs', 60, 'Number of epochs') +flags.DEFINE_integer( + 'microbatches', 200, 'Number of microbatches ' + '(must evenly divide batch_size)') +flags.DEFINE_string('model_dir', None, 'Model directory') + + +FLAGS = flags.FLAGS + + +NUM_TRAIN_EXAMPLES = 60000 + + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + GradientDescentOptimizer = tf.train.GradientDescentOptimizer +else: + GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name + + +def compute_epsilon(steps): + """Computes epsilon value for given hyperparameters.""" + if FLAGS.noise_multiplier == 0.0: + return float('inf') + orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) + sampling_probability = FLAGS.batch_size / NUM_TRAIN_EXAMPLES + rdp = compute_rdp(q=sampling_probability, + noise_multiplier=FLAGS.noise_multiplier, + steps=steps, + orders=orders) + # Delta is set to approximate 1 / (number of training points). + return get_privacy_spent(orders, rdp, target_delta=1e-5)[0] + + +def cnn_model_fn(features, labels, mode): + """Model function for a CNN.""" + + # Define CNN architecture using tf.keras.layers. 
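+  # (Two strided conv + max-pool blocks, then a 32-unit dense layer and a
+  # 10-class logits layer.)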
+  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
+  y = tf.keras.layers.Conv2D(16, 8,
+                             strides=2,
+                             padding='same',
+                             activation='relu').apply(input_layer)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Conv2D(32, 4,
+                             strides=2,
+                             padding='valid',
+                             activation='relu').apply(y)
+  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
+  y = tf.keras.layers.Flatten().apply(y)
+  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
+  logits = tf.keras.layers.Dense(10).apply(y)
+
+  # Calculate loss as a vector (to support microbatches in DP-SGD).
+  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  # Define mean of loss across minibatch (for reporting through tf.Estimator).
+  scalar_loss = tf.reduce_mean(vector_loss)
+
+  # Configure the training op (for TRAIN mode).
+  if mode == tf.estimator.ModeKeys.TRAIN:
+
+    if FLAGS.dpsgd:
+      # Use the vectorized DP version of GradientDescentOptimizer. Other
+      # vectorized optimizers (Adagrad, Adam) are available in
+      # dp_optimizer_vectorized; optimizers inheriting from tf.train.Optimizer
+      # can be wrapped with make_vectorized_optimizer_class().
+      optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
+          l2_norm_clip=FLAGS.l2_norm_clip,
+          noise_multiplier=FLAGS.noise_multiplier,
+          num_microbatches=FLAGS.microbatches,
+          learning_rate=FLAGS.learning_rate)
+      opt_loss = vector_loss
+    else:
+      optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
+      opt_loss = scalar_loss
+    global_step = tf.train.get_global_step()
+    train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
+    # In the following, we pass the mean of the loss (scalar_loss) rather than
+    # the vector_loss because tf.estimator requires a scalar loss. This is only
+    # used for evaluation and debugging by tf.estimator. The actual loss being
+    # minimized is opt_loss defined above and passed to optimizer.minimize().
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      loss=scalar_loss,
+                                      train_op=train_op)
+
+  # Add evaluation metrics (for EVAL mode).
+  elif mode == tf.estimator.ModeKeys.EVAL:
+    eval_metric_ops = {
+        'accuracy':
+            tf.metrics.accuracy(
+                labels=labels,
+                predictions=tf.argmax(input=logits, axis=1))
+    }
+
+    return tf.estimator.EstimatorSpec(mode=mode,
+                                      loss=scalar_loss,
+                                      eval_metric_ops=eval_metric_ops)
+
+
+def load_mnist():
+  """Loads MNIST and preprocesses to combine training and validation data."""
+  train, test = tf.keras.datasets.mnist.load_data()
+  train_data, train_labels = train
+  test_data, test_labels = test
+
+  train_data = np.array(train_data, dtype=np.float32) / 255
+  test_data = np.array(test_data, dtype=np.float32) / 255
+
+  train_labels = np.array(train_labels, dtype=np.int32)
+  test_labels = np.array(test_labels, dtype=np.int32)
+
+  assert train_data.min() == 0.
+  assert train_data.max() == 1.
+  assert test_data.min() == 0.
+  assert test_data.max() == 1.
+  assert train_labels.ndim == 1
+  assert test_labels.ndim == 1
+
+  return train_data, train_labels, test_data, test_labels
+
+
+def main(unused_argv):
+  tf.logging.set_verbosity(tf.logging.INFO)
+  if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
+    raise ValueError('Number of microbatches should evenly divide batch_size')
+
+  # Load training and test data.
+  train_data, train_labels, test_data, test_labels = load_mnist()
+
+  # Instantiate the tf.Estimator.
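+  # (If --model_dir is unset, the Estimator uses a temporary directory.)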
+ mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, + model_dir=FLAGS.model_dir) + + # Create tf.Estimator input functions for the training and test data. + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={'x': train_data}, + y=train_labels, + batch_size=FLAGS.batch_size, + num_epochs=FLAGS.epochs, + shuffle=True) + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={'x': test_data}, + y=test_labels, + num_epochs=1, + shuffle=False) + + # Training loop. + steps_per_epoch = NUM_TRAIN_EXAMPLES // FLAGS.batch_size + for epoch in range(1, FLAGS.epochs + 1): + # Train the model for one epoch. + mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch) + + # Evaluate the model and print results + eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) + test_accuracy = eval_results['accuracy'] + print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy)) + + # Compute the privacy budget expended. + if FLAGS.dpsgd: + eps = compute_epsilon(epoch * NUM_TRAIN_EXAMPLES // FLAGS.batch_size) + print('For delta=1e-5, the current epsilon is: %.2f' % eps) + else: + print('Trained with vanilla non-private SGD optimizer') + +if __name__ == '__main__': + app.run(main)
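
Usage sketch (illustrative only, assuming TF 1.x graph mode; the synthetic data,
variable names, and hyperparameter values below are placeholders): the new
optimizers are drop-in replacements for their tf.train counterparts, provided the
loss passed to minimize() or compute_gradients() is a vector with one entry per
example.

    import numpy as np
    import tensorflow as tf

    from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD

    # Toy linear-regression graph (illustrative). The loss is kept as a vector
    # (one entry per example) so the optimizer can form microbatches.
    features = tf.placeholder(tf.float32, [None, 4])
    labels = tf.placeholder(tf.float32, [None, 1])
    weights = tf.Variable(tf.zeros([4, 1]))
    bias = tf.Variable(tf.zeros([1]))
    predictions = tf.matmul(features, weights) + bias
    vector_loss = tf.reduce_mean(
        tf.squared_difference(labels, predictions), axis=1)

    optimizer = VectorizedDPSGD(
        l2_norm_clip=1.0,      # per-microbatch gradients are clipped to this norm
        noise_multiplier=1.1,  # noise stddev = l2_norm_clip * noise_multiplier
        num_microbatches=32,   # must evenly divide the batch size (128 below)
        learning_rate=0.1)
    train_op = optimizer.minimize(vector_loss)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(100):
        x = np.random.normal(size=(128, 4)).astype(np.float32)
        y = np.matmul(x, np.array([[1.], [2.], [3.], [4.]], dtype=np.float32))
        sess.run(train_op, feed_dict={features: x, labels: y})

Internally, compute_gradients() computes the per-microbatch gradients with
tf.vectorized_map, clips each one to l2_norm_clip, sums the clipped gradients
across microbatches, adds Gaussian noise with standard deviation
l2_norm_clip * noise_multiplier, and divides by num_microbatches before the
gradients are applied.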