diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD new file mode 100644 index 0000000..3b3754f --- /dev/null +++ b/tensorflow_privacy/privacy/keras_models/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +py_library( + name = "dp_keras_model", + srcs = [ + "dp_keras_model.py", + ], + deps = [ + "//third_party/py/tensorflow", + "//third_party/tensorflow/compiler/jit:xla_cpu_jit", + "//third_party/tensorflow/compiler/jit:xla_gpu_jit", + ], +) diff --git a/tensorflow_privacy/privacy/keras_models/__init__.py b/tensorflow_privacy/privacy/keras_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py new file mode 100644 index 0000000..d644203 --- /dev/null +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -0,0 +1,94 @@ +# Copyright 2021, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Keras Model for vectorized dpsgd with XLA acceleration."""

import tensorflow as tf


def make_dp_model_class(cls):
  """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it.

  Args:
    cls: The class to derive from; should be a subclass of `tf.keras.Model`.

  Returns:
    A subclass of `cls` whose `train_step` computes one gradient per example,
    clips each to a maximum global L2 norm, sums them, adds Gaussian noise and
    averages the result before applying it with the compiled optimizer.
  """

  class DPModelClass(cls):
    """A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""

    def __init__(
        self,
        l2_norm_clip,
        noise_multiplier,
        use_xla=True,
        *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
        **kwargs):
      """Initializes the DPModelClass.

      Args:
        l2_norm_clip: Clipping norm (max L2 norm of per microbatch
          gradients).
        noise_multiplier: Ratio of the standard deviation to the clipping
          norm.
        use_xla: If True, compiles train_step to XLA.
        *args: Positional arguments forwarded to the constructor of `cls`.
        **kwargs: Keyword arguments forwarded to the constructor of `cls`.
      """
      super(DPModelClass, self).__init__(*args, **kwargs)
      self._l2_norm_clip = l2_norm_clip
      self._noise_multiplier = noise_multiplier

      if use_xla:
        # Shadow the bound method on the instance with a compiled version.
        self.train_step = tf.function(
            self.train_step, experimental_compile=True)

    def _process_per_example_grads(self, grads):
      """Clips one example's gradients to a global L2 norm of `l2_norm_clip`.

      Args:
        grads: Nested structure of gradient tensors for a single example.

      Returns:
        A structure matching `grads` in which every tensor is divided by
        `max(global_norm / l2_norm_clip, 1)`, so gradients whose global norm
        is already within the bound are returned unchanged.
      """
      grads_flat = tf.nest.flatten(grads)
      squared_l2_norms = [
          tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
      ]
      # The norm is taken jointly over all tensors, not per tensor.
      global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
      div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
      clipped_flat = [g / div for g in grads_flat]
      return tf.nest.pack_sequence_as(grads, clipped_flat)

    def _reduce_per_example_grads(self, stacked_grads):
      """Sums stacked clipped gradients, adds noise and averages.

      Args:
        stacked_grads: Tensor of clipped gradients for one variable, stacked
          along axis 0 (one slice per example).

      Returns:
        The sum over axis 0 plus Gaussian noise with standard deviation
        `l2_norm_clip * noise_multiplier`, divided by the size of axis 0.
      """
      summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
      noise_stddev = self._l2_norm_clip * self._noise_multiplier
      noise = tf.random.normal(
          tf.shape(input=summed_grads), stddev=noise_stddev)
      noised_grads = summed_grads + noise
      # Use the dynamic shape: the static `stacked_grads.shape[0]` is None
      # when the batch dimension is not known at trace time, and casting
      # None raises.
      num_examples = tf.shape(input=stacked_grads)[0]
      return noised_grads / tf.cast(num_examples, noised_grads.dtype)

    def _compute_per_example_grads(self, data):
      """Computes prediction, loss and clipped gradients for one example.

      Args:
        data: An `(x, y)` pair holding a single, unbatched example.

      Returns:
        A tuple `(y_pred, loss, clipped_grads)`; `y_pred` has the added batch
        dimension squeezed out again.
      """
      x, y = data
      with tf.GradientTape() as tape:
        # We need to add the extra dimension to x and y because model
        # expects batched input.
        y_pred = self(x[None], training=True)
        loss = self.compiled_loss(
            y[None], y_pred, regularization_losses=self.losses)

      grads_list = tape.gradient(loss, self.trainable_variables)
      clipped_grads = self._process_per_example_grads(grads_list)
      return tf.squeeze(y_pred, axis=0), loss, clipped_grads

    def train_step(self, data):
      """Runs one DP-SGD update on a batch.

      Args:
        data: An `(x, y)` pair for the batch.

      Returns:
        A dict mapping each metric name in `self.metrics` to its current
        result.
      """
      _, y = data
      # The per-example losses returned by the vectorized map are not used
      # here.
      y_pred, _, per_eg_grads = tf.vectorized_map(
          self._compute_per_example_grads, data)
      grads = tf.nest.map_structure(self._reduce_per_example_grads,
                                    per_eg_grads)
      self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
      self.compiled_metrics.update_state(y, y_pred)
      return {m.name: m.result() for m in self.metrics}

  return DPModelClass


DPModel = make_dp_model_class(tf.keras.Model)
DPSequential = make_dp_model_class(tf.keras.Sequential)

# diff --git a/tutorials/mnist_dpsgd_tutorial_keras_model.py b/tutorials/mnist_dpsgd_tutorial_keras_model.py
# new file mode 100644
# index 0000000..de59a09
# --- /dev/null
# +++ b/tutorials/mnist_dpsgd_tutorial_keras_model.py
# @@ -0,0 +1,145 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Training a CNN on MNIST with Keras and the DP SGD optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +from absl import logging + +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.keras_models.dp_keras_model import DPSequential + +flags.DEFINE_boolean( + 'dpsgd', True, 'If True, train with DP-SGD. If False, ' + 'train with vanilla SGD.') +flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training') +flags.DEFINE_float('noise_multiplier', 0.1, + 'Ratio of the standard deviation to the clipping norm') +flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') +flags.DEFINE_integer('batch_size', 250, 'Batch size') +flags.DEFINE_integer('epochs', 60, 'Number of epochs') +flags.DEFINE_integer( + 'microbatches', 250, 'Number of microbatches ' + '(must evenly divide batch_size)') +flags.DEFINE_string('model_dir', None, 'Model directory') + +FLAGS = flags.FLAGS + + +def compute_epsilon(steps): + """Computes epsilon value for given hyperparameters.""" + if FLAGS.noise_multiplier == 0.0: + return float('inf') + orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) + sampling_probability = FLAGS.batch_size / 60000 + rdp = compute_rdp( + q=sampling_probability, + noise_multiplier=FLAGS.noise_multiplier, + steps=steps, + orders=orders) + # Delta is set to 1e-5 because MNIST has 60000 training points. 
+ return get_privacy_spent(orders, rdp, target_delta=1e-5)[0] + + +def load_mnist(): + """Loads MNIST and preprocesses to combine training and validation data.""" + train, test = tf.keras.datasets.mnist.load_data() + train_data, train_labels = train + test_data, test_labels = test + + train_data = np.array(train_data, dtype=np.float32) / 255 + test_data = np.array(test_data, dtype=np.float32) / 255 + + train_data = train_data.reshape((train_data.shape[0], 28, 28, 1)) + test_data = test_data.reshape((test_data.shape[0], 28, 28, 1)) + + train_labels = np.array(train_labels, dtype=np.int32) + test_labels = np.array(test_labels, dtype=np.int32) + + train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10) + test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10) + + assert train_data.min() == 0. + assert train_data.max() == 1. + assert test_data.min() == 0. + assert test_data.max() == 1. + + return train_data, train_labels, test_data, test_labels + + +def main(unused_argv): + logging.set_verbosity(logging.INFO) + if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0: + raise ValueError('Number of microbatches should divide evenly batch_size') + + # Load training and test data. 
+ train_data, train_labels, test_data, test_labels = load_mnist() + + # Define a sequential Keras model + layers = [ + tf.keras.layers.Conv2D( + 16, + 8, + strides=2, + padding='same', + activation='relu', + input_shape=(28, 28, 1)), + tf.keras.layers.MaxPool2D(2, 1), + tf.keras.layers.Conv2D( + 32, 4, strides=2, padding='valid', activation='relu'), + tf.keras.layers.MaxPool2D(2, 1), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation='relu'), + tf.keras.layers.Dense(10) + ] + if FLAGS.dpsgd: + model = DPSequential( + l2_norm_clip=FLAGS.l2_norm_clip, + noise_multiplier=FLAGS.noise_multiplier, + layers=layers) + else: + model = tf.keras.Sequential(layers=layers) + + optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate) + loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) + + # Compile model with Keras + model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) + + # Train model with Keras + model.fit( + train_data, + train_labels, + epochs=FLAGS.epochs, + validation_data=(test_data, test_labels), + batch_size=FLAGS.batch_size) + + # Compute the privacy budget expended. + if FLAGS.dpsgd: + eps = compute_epsilon(FLAGS.epochs * 60000 // FLAGS.batch_size) + print('For delta=1e-5, the current epsilon is: %.2f' % eps) + else: + print('Trained with vanilla non-private SGD optimizer') + + +if __name__ == '__main__': + app.run(main)