From 574718706df9b4d02d16a359d18611499357a33b Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Mon, 4 Jan 2021 19:32:53 -0700 Subject: [PATCH 1/7] creating keras models directory --- tensorflow_privacy/privacy/keras_models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 tensorflow_privacy/privacy/keras_models/__init__.py diff --git a/tensorflow_privacy/privacy/keras_models/__init__.py b/tensorflow_privacy/privacy/keras_models/__init__.py new file mode 100644 index 0000000..139597f --- /dev/null +++ b/tensorflow_privacy/privacy/keras_models/__init__.py @@ -0,0 +1,2 @@ + + From 7a00a1cfefbe789afc71ac90e88099fe52011fd0 Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Tue, 5 Jan 2021 13:13:00 -0700 Subject: [PATCH 2/7] adding keras vectorized model initial commit --- .../dp_optimizer_keras_vectorized.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py diff --git a/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py new file mode 100644 index 0000000..96461ed --- /dev/null +++ b/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py @@ -0,0 +1,30 @@ +import tensorflow as tf + + +def make_dp_model_class(cls): + class DPModelClass(cls): + def __init__(self, l2_norm_clip, noise_multiplier, use_xla=True, *args, **kwargs): + super(DPModelClass, self).__init__(*args, **kwargs) + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + + if use_xla: + self.train_step = tf.function( + self.train_step, experimental_compile=True) + + def process_per_example_grads(self, grads): + grads_flat = tf.nest.flatten(grads) + squared_l2_norms = [tf.reduce_sum( + input_tensor=tf.square(g)) for g in grads_flat] + global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) + div = tf.maximum(global_norm / self._l2_norm_clip, 1.) + clipped_flat = [g / div for g in grads_flat] + return tf.nest.pack_sequence_as(grads, clipped_flat) + + def reduce_per_example_grads(self, stacked_grads): + summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_grads), stddev=noise_stddev) + noised_grads = summed_grads + noise + return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) From 7dad2d18e8b8af6c33b24d0d1b018c4d7d95f95a Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Tue, 5 Jan 2021 17:42:10 -0500 Subject: [PATCH 3/7] Update privacy/keras_models. --- .../privacy/keras_models/dp_keras_model.py | 59 +++++++++++++++++++ .../dp_optimizer_keras_vectorized.py | 30 ---------- 2 files changed, 59 insertions(+), 30 deletions(-) create mode 100644 tensorflow_privacy/privacy/keras_models/dp_keras_model.py delete mode 100644 tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py new file mode 100644 index 0000000..99b92c9 --- /dev/null +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -0,0 +1,59 @@ +import tensorflow as tf + + +def make_dp_model_class(cls): + class DPModelClass(cls): + def __init__(self, l2_norm_clip, noise_multiplier, use_xla=True, *args, **kwargs): + super(DPModelClass, self).__init__(*args, **kwargs) + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + + if use_xla: + self.train_step = tf.function( + self.train_step, experimental_compile=True) + + def process_per_example_grads(self, grads): + grads_flat = tf.nest.flatten(grads) + squared_l2_norms = [tf.reduce_sum( + input_tensor=tf.square(g)) for g in grads_flat] + global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) + div = tf.maximum(global_norm / self._l2_norm_clip, 1.) + clipped_flat = [g / div for g in grads_flat] + return tf.nest.pack_sequence_as(grads, clipped_flat) + + def reduce_per_example_grads(self, stacked_grads): + summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_grads), stddev=noise_stddev) + noised_grads = summed_grads + noise + return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) + + def compute_per_example_grads(self, data): + x, y = data + with tf.GradientTape() as tape: + # We need to add the extra dimension to x and y because model + # expects batched input. + y_pred = self(x[None], training=True) + loss = self.compiled_loss(y[None], y_pred, + regularization_losses=self.losses) + + grads_list = tape.gradient(loss, self.trainable_variables) + clipped_grads = self.process_per_example_grads(grads_list) + return tf.squeeze(y_pred, axis=0), loss, clipped_grads + + def train_step(self, data): + x, y = data + y_pred, per_eg_loss, per_eg_grads = tf.vectorized_map( + self.compute_per_example_grads, data) + loss = tf.reduce_mean(per_eg_loss, axis=0) + grads = tf.nest.map_structure(self.reduce_per_example_grads, per_eg_grads) + self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) + self.compiled_metrics.update_state(y, y_pred) + return {m.name: m.result() for m in self.metrics} + + return DPModelClass + + +DPModel = make_dp_model_class(tf.keras.Model) +DPSequential = make_dp_model_class(tf.keras.Sequential) diff --git a/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py b/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py deleted file mode 100644 index 96461ed..0000000 --- a/tensorflow_privacy/privacy/keras_models/dp_optimizer_keras_vectorized.py +++ /dev/null @@ -1,30 +0,0 @@ -import tensorflow as tf - - -def make_dp_model_class(cls): - class DPModelClass(cls): - def __init__(self, l2_norm_clip, noise_multiplier, use_xla=True, *args, **kwargs): - super(DPModelClass, self).__init__(*args, **kwargs) - self._l2_norm_clip = l2_norm_clip - self._noise_multiplier = noise_multiplier - - if use_xla: - self.train_step = tf.function( - self.train_step, experimental_compile=True) - - def process_per_example_grads(self, grads): - grads_flat = tf.nest.flatten(grads) - squared_l2_norms = [tf.reduce_sum( - input_tensor=tf.square(g)) for g in grads_flat] - global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) - div = tf.maximum(global_norm / self._l2_norm_clip, 1.) - clipped_flat = [g / div for g in grads_flat] - return tf.nest.pack_sequence_as(grads, clipped_flat) - - def reduce_per_example_grads(self, stacked_grads): - summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0) - noise_stddev = self._l2_norm_clip * self._noise_multiplier - noise = tf.random.normal( - tf.shape(input=summed_grads), stddev=noise_stddev) - noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) From 9d871b28c12fef304f6520f35f053e4c4cf26bb7 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Tue, 5 Jan 2021 17:43:00 -0500 Subject: [PATCH 4/7] Add keras_models example to tutorials. --- tutorials/mnist_dpsgd_tutorial_keras_model.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 tutorials/mnist_dpsgd_tutorial_keras_model.py diff --git a/tutorials/mnist_dpsgd_tutorial_keras_model.py b/tutorials/mnist_dpsgd_tutorial_keras_model.py new file mode 100644 index 0000000..402f795 --- /dev/null +++ b/tutorials/mnist_dpsgd_tutorial_keras_model.py @@ -0,0 +1,140 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Training a CNN on MNIST with Keras and the DP SGD optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +from absl import logging + +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.keras_models.dp_keras_model import DPSequential + +flags.DEFINE_boolean( + 'dpsgd', True, 'If True, train with DP-SGD. If False, ' + 'train with vanilla SGD.') +flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training') +flags.DEFINE_float('noise_multiplier', 0.1, + 'Ratio of the standard deviation to the clipping norm') +flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') +flags.DEFINE_integer('batch_size', 250, 'Batch size') +flags.DEFINE_integer('epochs', 60, 'Number of epochs') +flags.DEFINE_integer( + 'microbatches', 250, 'Number of microbatches ' + '(must evenly divide batch_size)') +flags.DEFINE_string('model_dir', None, 'Model directory') + +FLAGS = flags.FLAGS + + +def compute_epsilon(steps): + """Computes epsilon value for given hyperparameters.""" + if FLAGS.noise_multiplier == 0.0: + return float('inf') + orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) + sampling_probability = FLAGS.batch_size / 60000 + rdp = compute_rdp(q=sampling_probability, + noise_multiplier=FLAGS.noise_multiplier, + steps=steps, + orders=orders) + # Delta is set to 1e-5 because MNIST has 60000 training points. + return get_privacy_spent(orders, rdp, target_delta=1e-5)[0] + + +def load_mnist(): + """Loads MNIST and preprocesses to combine training and validation data.""" + train, test = tf.keras.datasets.mnist.load_data() + train_data, train_labels = train + test_data, test_labels = test + + train_data = np.array(train_data, dtype=np.float32) / 255 + test_data = np.array(test_data, dtype=np.float32) / 255 + + train_data = train_data.reshape((train_data.shape[0], 28, 28, 1)) + test_data = test_data.reshape((test_data.shape[0], 28, 28, 1)) + + train_labels = np.array(train_labels, dtype=np.int32) + test_labels = np.array(test_labels, dtype=np.int32) + + train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10) + test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10) + + assert train_data.min() == 0. + assert train_data.max() == 1. + assert test_data.min() == 0. + assert test_data.max() == 1. + + return train_data, train_labels, test_data, test_labels + + +def main(unused_argv): + logging.set_verbosity(logging.INFO) + if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0: + raise ValueError('Number of microbatches should divide evenly batch_size') + + # Load training and test data. + train_data, train_labels, test_data, test_labels = load_mnist() + + # Define a sequential Keras model + layers = [ + tf.keras.layers.Conv2D(16, 8, + strides=2, + padding='same', + activation='relu', + input_shape=(28, 28, 1)), + tf.keras.layers.MaxPool2D(2, 1), + tf.keras.layers.Conv2D(32, 4, + strides=2, + padding='valid', + activation='relu'), + tf.keras.layers.MaxPool2D(2, 1), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation='relu'), + tf.keras.layers.Dense(10) + ] + if FLAGS.dpsgd: + model = DPSequential(l2_norm_clip=FLAGS.l2_norm_clip, + noise_multiplier=FLAGS.noise_multiplier, + layers=layers) + else: + model = tf.keras.Sequential(layers=layers) + + optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate) + loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) + + # Compile model with Keras + model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) + + # Train model with Keras + model.fit(train_data, train_labels, + epochs=FLAGS.epochs, + validation_data=(test_data, test_labels), + batch_size=FLAGS.batch_size) + + # Compute the privacy budget expended. + if FLAGS.dpsgd: + eps = compute_epsilon(FLAGS.epochs * 60000 // FLAGS.batch_size) + print('For delta=1e-5, the current epsilon is: %.2f' % eps) + else: + print('Trained with vanilla non-private SGD optimizer') + +if __name__ == '__main__': + app.run(main) From 6982e027b58ab839dbea00f04864ded4688f2690 Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Fri, 8 Jan 2021 00:22:44 -0700 Subject: [PATCH 5/7] update dp keras model --- .../privacy/keras_models/dp_keras_model.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 99b92c9..6565082 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -27,7 +27,7 @@ def make_dp_model_class(cls): noise = tf.random.normal( tf.shape(input=summed_grads), stddev=noise_stddev) noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) + return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) def compute_per_example_grads(self, data): x, y = data @@ -35,23 +35,24 @@ def make_dp_model_class(cls): # We need to add the extra dimension to x and y because model # expects batched input. y_pred = self(x[None], training=True) - loss = self.compiled_loss(y[None], y_pred, + loss = self.compiled_loss(y[None], y_pred, regularization_losses=self.losses) - + grads_list = tape.gradient(loss, self.trainable_variables) clipped_grads = self.process_per_example_grads(grads_list) return tf.squeeze(y_pred, axis=0), loss, clipped_grads - + def train_step(self, data): x, y = data y_pred, per_eg_loss, per_eg_grads = tf.vectorized_map( self.compute_per_example_grads, data) loss = tf.reduce_mean(per_eg_loss, axis=0) - grads = tf.nest.map_structure(self.reduce_per_example_grads, per_eg_grads) + grads = tf.nest.map_structure( + self.reduce_per_example_grads, per_eg_grads) self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) self.compiled_metrics.update_state(y, y_pred) return {m.name: m.result() for m in self.metrics} - + return DPModelClass From 13b3a04a3e0d3072ce99ebd40d0a4b3939b7a77c Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Fri, 8 Jan 2021 00:23:32 -0700 Subject: [PATCH 6/7] update keras model --- tensorflow_privacy/privacy/keras_models/dp_keras_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 6565082..b854f54 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -27,7 +27,7 @@ def make_dp_model_class(cls): noise = tf.random.normal( tf.shape(input=summed_grads), stddev=noise_stddev) noised_grads = summed_grads + noise - return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32) + return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype) def compute_per_example_grads(self, data): x, y = data From 78ec3fa58a1fdeec0311dfa2096ddfdb91e1aaa5 Mon Sep 17 00:00:00 2001 From: pranav subramani Date: Fri, 8 Jan 2021 00:24:52 -0700 Subject: [PATCH 7/7] update dp keras model --- tensorflow_privacy/privacy/keras_models/dp_keras_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index b854f54..19b6fd9 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -1,3 +1,7 @@ +''' +Keras Model for vectorized dpsgd with XLA acceleration +''' + import tensorflow as tf