Merge pull request #147 from TheSalon:master
PiperOrigin-RevId: 351680116
This commit is contained in:
commit
aed49d0087
4 changed files with 254 additions and 0 deletions
15
tensorflow_privacy/privacy/keras_models/BUILD
Normal file
15
tensorflow_privacy/privacy/keras_models/BUILD
Normal file
|
@ -0,0 +1,15 @@
|
|||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "dp_keras_model",
|
||||
srcs = [
|
||||
"dp_keras_model.py",
|
||||
],
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
"//third_party/tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//third_party/tensorflow/compiler/jit:xla_gpu_jit",
|
||||
],
|
||||
)
|
0
tensorflow_privacy/privacy/keras_models/__init__.py
Normal file
0
tensorflow_privacy/privacy/keras_models/__init__.py
Normal file
94
tensorflow_privacy/privacy/keras_models/dp_keras_model.py
Normal file
94
tensorflow_privacy/privacy/keras_models/dp_keras_model.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keras Model for vectorized dpsgd with XLA acceleration."""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def make_dp_model_class(cls):
|
||||
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
|
||||
|
||||
class DPModelClass(cls):
|
||||
"""A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
use_xla=True,
|
||||
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
||||
**kwargs):
|
||||
"""Initializes the DPModelClass.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: Clipping norm (max L2 norm of per microbatch
|
||||
gradients).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping
|
||||
norm.
|
||||
use_xla: If True, compiles train_step to XLA.
|
||||
"""
|
||||
super(DPModelClass, self).__init__(*args, **kwargs)
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._noise_multiplier = noise_multiplier
|
||||
|
||||
if use_xla:
|
||||
self.train_step = tf.function(
|
||||
self.train_step, experimental_compile=True)
|
||||
|
||||
def _process_per_example_grads(self, grads):
|
||||
grads_flat = tf.nest.flatten(grads)
|
||||
squared_l2_norms = [
|
||||
tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
|
||||
]
|
||||
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
|
||||
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
|
||||
clipped_flat = [g / div for g in grads_flat]
|
||||
return tf.nest.pack_sequence_as(grads, clipped_flat)
|
||||
|
||||
def _reduce_per_example_grads(self, stacked_grads):
|
||||
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
|
||||
noise_stddev = self._l2_norm_clip * self._noise_multiplier
|
||||
noise = tf.random.normal(
|
||||
tf.shape(input=summed_grads), stddev=noise_stddev)
|
||||
noised_grads = summed_grads + noise
|
||||
return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype)
|
||||
|
||||
def _compute_per_example_grads(self, data):
|
||||
x, y = data
|
||||
with tf.GradientTape() as tape:
|
||||
# We need to add the extra dimension to x and y because model
|
||||
# expects batched input.
|
||||
y_pred = self(x[None], training=True)
|
||||
loss = self.compiled_loss(
|
||||
y[None], y_pred, regularization_losses=self.losses)
|
||||
|
||||
grads_list = tape.gradient(loss, self.trainable_variables)
|
||||
clipped_grads = self._process_per_example_grads(grads_list)
|
||||
return tf.squeeze(y_pred, axis=0), loss, clipped_grads
|
||||
|
||||
def train_step(self, data):
|
||||
_, y = data
|
||||
y_pred, _, per_eg_grads = tf.vectorized_map(
|
||||
self._compute_per_example_grads, data)
|
||||
grads = tf.nest.map_structure(self._reduce_per_example_grads,
|
||||
per_eg_grads)
|
||||
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
|
||||
self.compiled_metrics.update_state(y, y_pred)
|
||||
return {m.name: m.result() for m in self.metrics}
|
||||
|
||||
return DPModelClass
|
||||
|
||||
|
||||
DPModel = make_dp_model_class(tf.keras.Model)
|
||||
DPSequential = make_dp_model_class(tf.keras.Sequential)
|
145
tutorials/mnist_dpsgd_tutorial_keras_model.py
Normal file
145
tutorials/mnist_dpsgd_tutorial_keras_model.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Training a CNN on MNIST with Keras and the DP SGD optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.keras_models.dp_keras_model import DPSequential
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
'train with vanilla SGD.')
|
||||
flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
|
||||
flags.DEFINE_float('noise_multiplier', 0.1,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||
flags.DEFINE_integer(
|
||||
'microbatches', 250, 'Number of microbatches '
|
||||
'(must evenly divide batch_size)')
|
||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def compute_epsilon(steps):
|
||||
"""Computes epsilon value for given hyperparameters."""
|
||||
if FLAGS.noise_multiplier == 0.0:
|
||||
return float('inf')
|
||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||
sampling_probability = FLAGS.batch_size / 60000
|
||||
rdp = compute_rdp(
|
||||
q=sampling_probability,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
steps=steps,
|
||||
orders=orders)
|
||||
# Delta is set to 1e-5 because MNIST has 60000 training points.
|
||||
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
|
||||
|
||||
def load_mnist():
|
||||
"""Loads MNIST and preprocesses to combine training and validation data."""
|
||||
train, test = tf.keras.datasets.mnist.load_data()
|
||||
train_data, train_labels = train
|
||||
test_data, test_labels = test
|
||||
|
||||
train_data = np.array(train_data, dtype=np.float32) / 255
|
||||
test_data = np.array(test_data, dtype=np.float32) / 255
|
||||
|
||||
train_data = train_data.reshape((train_data.shape[0], 28, 28, 1))
|
||||
test_data = test_data.reshape((test_data.shape[0], 28, 28, 1))
|
||||
|
||||
train_labels = np.array(train_labels, dtype=np.int32)
|
||||
test_labels = np.array(test_labels, dtype=np.int32)
|
||||
|
||||
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
|
||||
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
|
||||
|
||||
assert train_data.min() == 0.
|
||||
assert train_data.max() == 1.
|
||||
assert test_data.min() == 0.
|
||||
assert test_data.max() == 1.
|
||||
|
||||
return train_data, train_labels, test_data, test_labels
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
# Load training and test data.
|
||||
train_data, train_labels, test_data, test_labels = load_mnist()
|
||||
|
||||
# Define a sequential Keras model
|
||||
layers = [
|
||||
tf.keras.layers.Conv2D(
|
||||
16,
|
||||
8,
|
||||
strides=2,
|
||||
padding='same',
|
||||
activation='relu',
|
||||
input_shape=(28, 28, 1)),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Conv2D(
|
||||
32, 4, strides=2, padding='valid', activation='relu'),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(32, activation='relu'),
|
||||
tf.keras.layers.Dense(10)
|
||||
]
|
||||
if FLAGS.dpsgd:
|
||||
model = DPSequential(
|
||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
layers=layers)
|
||||
else:
|
||||
model = tf.keras.Sequential(layers=layers)
|
||||
|
||||
optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate)
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
|
||||
|
||||
# Compile model with Keras
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(
|
||||
train_data,
|
||||
train_labels,
|
||||
epochs=FLAGS.epochs,
|
||||
validation_data=(test_data, test_labels),
|
||||
batch_size=FLAGS.batch_size)
|
||||
|
||||
# Compute the privacy budget expended.
|
||||
if FLAGS.dpsgd:
|
||||
eps = compute_epsilon(FLAGS.epochs * 60000 // FLAGS.batch_size)
|
||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||
else:
|
||||
print('Trained with vanilla non-private SGD optimizer')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
Loading…
Reference in a new issue