forked from 626_privacy/tensorflow_privacy
Add support for the Eager mode
PiperOrigin-RevId: 235733975
This commit is contained in:
parent
bfba26801d
commit
c2d4b17881
2 changed files with 221 additions and 37 deletions
|
@ -58,51 +58,89 @@ def make_optimizer_class(cls):
|
|||
gate_gradients=tf.train.Optimizer.GATE_OP,
|
||||
aggregation_method=None,
|
||||
colocate_gradients_with_ops=False,
|
||||
grad_loss=None):
|
||||
grad_loss=None,
|
||||
gradient_tape=None):
|
||||
if callable(loss):
|
||||
# TF is running in Eager mode, check we received a vanilla tape.
|
||||
if not gradient_tape:
|
||||
raise ValueError('When in Eager mode, a tape needs to be passed.')
|
||||
|
||||
# Note: it would be closer to the correct i.i.d. sampling of records if
|
||||
# we sampled each microbatch from the appropriate binomial distribution,
|
||||
# although that still wouldn't be quite correct because it would be
|
||||
# sampling from the dataset without replacement.
|
||||
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
self._dp_average_query.derive_sample_params(self._global_state))
|
||||
vector_loss = loss()
|
||||
sample_state = self._dp_average_query.initial_sample_state(
|
||||
self._global_state, var_list)
|
||||
microbatches_losses = tf.reshape(vector_loss,
|
||||
[self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
self._dp_average_query.derive_sample_params(self._global_state))
|
||||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
grads, _ = zip(*super(cls, self).compute_gradients(
|
||||
tf.reduce_mean(tf.gather(microbatches_losses,
|
||||
[i])), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = list(grads)
|
||||
sample_state = self._dp_average_query.accumulate_record(
|
||||
sample_params, sample_state, grads_list)
|
||||
return sample_state
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
microbatch_loss = tf.gather(microbatches_losses, [i])
|
||||
grads = gradient_tape.gradient(microbatch_loss, var_list)
|
||||
sample_state = self._dp_average_query.accumulate_record(sample_params,
|
||||
sample_state,
|
||||
grads)
|
||||
return sample_state
|
||||
|
||||
if var_list is None:
|
||||
var_list = (
|
||||
tf.trainable_variables() + tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
sample_state = self._dp_average_query.initial_sample_state(
|
||||
self._global_state, var_list)
|
||||
|
||||
if self._unroll_microbatches:
|
||||
for idx in range(self._num_microbatches):
|
||||
sample_state = process_microbatch(idx, sample_state)
|
||||
|
||||
final_grads, self._global_state = (
|
||||
self._dp_average_query.get_noised_result(sample_state,
|
||||
self._global_state))
|
||||
|
||||
grads_and_vars = list(zip(final_grads, var_list))
|
||||
return grads_and_vars
|
||||
|
||||
else:
|
||||
# Use of while_loop here requires that sample_state be a nested
|
||||
# structure of tensors. In general, we would prefer to allow it to be
|
||||
# an arbitrary opaque type.
|
||||
cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
|
||||
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]
|
||||
idx = tf.constant(0)
|
||||
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
|
||||
# TF is running in graph mode, check we did not receive a gradient tape.
|
||||
if gradient_tape:
|
||||
raise ValueError('When in graph mode, a tape should not be passed.')
|
||||
|
||||
final_grads, self._global_state = (
|
||||
self._dp_average_query.get_noised_result(
|
||||
sample_state, self._global_state))
|
||||
# Note: it would be closer to the correct i.i.d. sampling of records if
|
||||
# we sampled each microbatch from the appropriate binomial distribution,
|
||||
# although that still wouldn't be quite correct because it would be
|
||||
# sampling from the dataset without replacement.
|
||||
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
self._dp_average_query.derive_sample_params(self._global_state))
|
||||
|
||||
return list(zip(final_grads, var_list))
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
grads, _ = zip(*super(cls, self).compute_gradients(
|
||||
tf.reduce_mean(tf.gather(microbatches_losses,
|
||||
[i])), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = list(grads)
|
||||
sample_state = self._dp_average_query.accumulate_record(
|
||||
sample_params, sample_state, grads_list)
|
||||
return sample_state
|
||||
|
||||
if var_list is None:
|
||||
var_list = (
|
||||
tf.trainable_variables() + tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
|
||||
sample_state = self._dp_average_query.initial_sample_state(
|
||||
self._global_state, var_list)
|
||||
|
||||
if self._unroll_microbatches:
|
||||
for idx in range(self._num_microbatches):
|
||||
sample_state = process_microbatch(idx, sample_state)
|
||||
else:
|
||||
# Use of while_loop here requires that sample_state be a nested
|
||||
# structure of tensors. In general, we would prefer to allow it to be
|
||||
# an arbitrary opaque type.
|
||||
cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
|
||||
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
|
||||
idx = tf.constant(0)
|
||||
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
|
||||
|
||||
final_grads, self._global_state = (
|
||||
self._dp_average_query.get_noised_result(
|
||||
sample_state, self._global_state))
|
||||
|
||||
return list(zip(final_grads, var_list))
|
||||
|
||||
return DPOptimizerClass
|
||||
|
||||
|
|
146
tutorials/mnist_dpsgd_tutorial_eager.py
Normal file
146
tutorials/mnist_dpsgd_tutorial_eager.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2019, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Training a CNN on MNIST in TF Eager mode with DP-SGD optimizer."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.analysis.rdp_accountant import compute_rdp
|
||||
from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||
from privacy.optimizers.gaussian_query import GaussianAverageQuery
|
||||
|
||||
tf.enable_eager_execution()
|
||||
|
||||
tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
'train with vanilla SGD.')
|
||||
tf.flags.DEFINE_float('learning_rate', 0.15, 'Learning rate for training')
|
||||
tf.flags.DEFINE_float('noise_multiplier', 1.1,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
tf.flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||
tf.flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||
tf.flags.DEFINE_integer('microbatches', 250, 'Number of microbatches '
|
||||
'(must evenly divide batch_size)')
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
|
||||
def compute_epsilon(steps):
|
||||
"""Computes epsilon value for given hyperparameters."""
|
||||
if FLAGS.noise_multiplier == 0.0:
|
||||
return float('inf')
|
||||
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
|
||||
sampling_probability = FLAGS.batch_size / 60000
|
||||
rdp = compute_rdp(q=sampling_probability,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
steps=steps,
|
||||
orders=orders)
|
||||
# Delta is set to 1e-5 because MNIST has 60000 training points.
|
||||
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
|
||||
|
||||
|
||||
def main(_):
|
||||
# Fetch the mnist data
|
||||
train, test = tf.keras.datasets.mnist.load_data()
|
||||
train_images, train_labels = train
|
||||
test_images, test_labels = test
|
||||
|
||||
# Create a dataset object and batch for the training data
|
||||
dataset = tf.data.Dataset.from_tensor_slices(
|
||||
(tf.cast(train_images[..., tf.newaxis]/255, tf.float32),
|
||||
tf.cast(train_labels, tf.int64)))
|
||||
dataset = dataset.shuffle(1000).batch(FLAGS.batch_size)
|
||||
|
||||
# Create a dataset object and batch for the test data
|
||||
eval_dataset = tf.data.Dataset.from_tensor_slices(
|
||||
(tf.cast(test_images[..., tf.newaxis]/255, tf.float32),
|
||||
tf.cast(test_labels, tf.int64)))
|
||||
eval_dataset = eval_dataset.batch(10000)
|
||||
|
||||
# Define the model using tf.keras.layers
|
||||
mnist_model = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2D(16, 8,
|
||||
strides=2,
|
||||
padding='same',
|
||||
activation='relu'),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Conv2D(32, 4, strides=2, activation='relu'),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(32, activation='relu'),
|
||||
tf.keras.layers.Dense(10)
|
||||
])
|
||||
|
||||
# Instantiate the optimizer
|
||||
if FLAGS.dpsgd:
|
||||
dp_average_query = GaussianAverageQuery(
|
||||
FLAGS.l2_norm_clip,
|
||||
FLAGS.l2_norm_clip * FLAGS.noise_multiplier,
|
||||
FLAGS.microbatches)
|
||||
opt = DPGradientDescentOptimizer(
|
||||
dp_average_query,
|
||||
FLAGS.microbatches,
|
||||
learning_rate=FLAGS.learning_rate)
|
||||
else:
|
||||
opt = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
|
||||
# Training loop.
|
||||
steps_per_epoch = 60000 // FLAGS.batch_size
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Train the model for one epoch.
|
||||
for (_, (images, labels)) in enumerate(dataset.take(-1)):
|
||||
with tf.GradientTape(persistent=True) as gradient_tape:
|
||||
# This dummy call is needed to obtain the var list.
|
||||
logits = mnist_model(images, training=True)
|
||||
var_list = mnist_model.trainable_variables
|
||||
|
||||
# In Eager mode, the optimizer takes a function that returns the loss.
|
||||
def loss_fn():
|
||||
logits = mnist_model(images, training=True) # pylint: disable=undefined-loop-variable,cell-var-from-loop
|
||||
loss = tf.losses.sparse_softmax_cross_entropy(
|
||||
labels, logits, reduction=tf.losses.Reduction.NONE) # pylint: disable=undefined-loop-variable,cell-var-from-loop
|
||||
# If training without privacy, the loss is a scalar not a vector.
|
||||
if not FLAGS.dpsgd:
|
||||
loss = tf.reduce_mean(loss)
|
||||
return loss
|
||||
|
||||
if FLAGS.dpsgd:
|
||||
grads_and_vars = opt.compute_gradients(loss_fn, var_list,
|
||||
gradient_tape=gradient_tape)
|
||||
else:
|
||||
grads_and_vars = opt.compute_gradients(loss_fn, var_list)
|
||||
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
opt.apply_gradients(grads_and_vars, global_step=global_step)
|
||||
|
||||
# Evaluate the model and print results
|
||||
for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
|
||||
logits = mnist_model(images, training=False)
|
||||
correct_preds = tf.equal(tf.argmax(logits, axis=1), labels)
|
||||
test_accuracy = np.mean(correct_preds.numpy())
|
||||
print('Test accuracy after epoch %d is: %.3f' % (epoch, test_accuracy))
|
||||
|
||||
# Compute the privacy budget expended so far.
|
||||
if FLAGS.dpsgd:
|
||||
eps = compute_epsilon(epoch * steps_per_epoch)
|
||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||
else:
|
||||
print('Trained with vanilla non-private SGD optimizer')
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run(main)
|
Loading…
Reference in a new issue