A callback and a function to be called in the end of training for keras to perform membership inference attack.

PiperOrigin-RevId: 323805663
This commit is contained in:
Shuang Song 2020-07-29 09:41:40 -07:00 committed by Steve Chien
parent dcbfaa3f5e
commit cea9e01670
3 changed files with 284 additions and 0 deletions

View file

@ -0,0 +1,108 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""A callback and a function in keras for membership inference attack."""
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
def calculate_losses(model, data, labels):
"""Calculate losses of model prediction on data, provided true labels.
Args:
model: model to make prediction
data: samples
labels: true labels of samples (integer valued)
Returns:
preds: probability vector of each sample
loss: cross entropy loss of each sample
"""
pred = model.predict(data)
loss = log_loss(labels, pred)
return pred, loss
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
"""Callback to perform membership inference attack on epoch end."""
def __init__(self, in_train, out_train, attack_classifiers,
tensorboard_dir=None):
"""Initalizes the callback.
Args:
in_train: (in_training samples, in_training labels)
out_train: (out_training samples, out_training labels)
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
tensorboard_dir: directory for tensorboard summary
"""
self._in_train_data, self._in_train_labels = in_train
self._out_train_data, self._out_train_labels = out_train
self._attack_classifiers = attack_classifiers
# Setup tensorboard writer if tensorboard_dir is specified
if tensorboard_dir:
with tf.Graph().as_default():
self._writer = tf.summary.FileWriter(tensorboard_dir)
logging.info('Will write to tensorboard.')
else:
self._writer = None
def on_epoch_end(self, epoch, logs=None):
results = run_attack_on_keras_model(
self.model,
(self._in_train_data, self._in_train_labels),
(self._out_train_data, self._out_train_labels),
self._attack_classifiers)
print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
logging.info(results)
# Write to tensorboard if tensorboard_dir is specified
write_to_tensorboard(self._writer, ['attack advantage'],
[results['all_thresh_loss_advantage']], epoch)
def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
"""Performs the attack on a trained model.
Args:
model: model to be tested
in_train: a (in_training samples, in_training labels) tuple
out_train: a (out_training samples, out_training labels) tuple
attack_classifiers: a list of classifiers to be used by attacker, must be
a subset of ['lr', 'mlp', 'rf', 'knn']
Returns:
Results of the attack
"""
in_train_data, in_train_labels = in_train
out_train_data, out_train_labels = out_train
# Compute predictions and losses
in_train_pred, in_train_loss = calculate_losses(model, in_train_data,
in_train_labels)
out_train_pred, out_train_loss = calculate_losses(model, out_train_data,
out_train_labels)
results = mia.run_all_attacks(in_train_loss, out_train_loss,
in_train_pred, out_train_pred,
in_train_labels, out_train_labels,
attack_classifiers=attack_classifiers)
return results

View file

@ -0,0 +1,104 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""An example for using keras_evaluation."""
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 10, 'Number of epochs')
flags.DEFINE_string('model_dir', None, 'Model directory.')
def cnn_model():
"""Define a CNN model."""
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 8, strides=2, padding='same',
activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Conv2D(32, 4, strides=2, padding='valid',
activation='relu'),
tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(10)
])
return model
def load_mnist():
"""Loads MNIST and preprocesses to combine training and validation data."""
(train_data,
train_labels), (test_data,
test_labels) = tf.keras.datasets.mnist.load_data()
train_data = np.array(train_data, dtype=np.float32) / 255
test_data = np.array(test_data, dtype=np.float32) / 255
train_data = train_data.reshape((train_data.shape[0], 28, 28, 1))
test_data = test_data.reshape((test_data.shape[0], 28, 28, 1))
train_labels = np.array(train_labels, dtype=np.int32)
test_labels = np.array(test_labels, dtype=np.int32)
return train_data, train_labels, test_data, test_labels
def main(unused_argv):
# Load training and test data.
train_data, train_labels, test_data, test_labels = load_mnist()
# Get model, optimizer and specify loss.
model = cnn_model()
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
# Get callback for membership inference attack.
mia_callback = MembershipInferenceCallback((train_data, train_labels),
(test_data, test_labels),
[],
FLAGS.model_dir)
# Train model with Keras
model.fit(train_data, train_labels,
epochs=FLAGS.epochs,
validation_data=(test_data, test_labels),
batch_size=FLAGS.batch_size,
callbacks=[mia_callback],
verbose=2)
print('End of training attack')
attack_results = run_attack_on_keras_model(model,
(train_data, train_labels),
(test_data, test_labels),
[])
print('all_thresh_loss_advantage',
attack_results['all_thresh_loss_advantage'])
if __name__ == '__main__':
app.run(main)

View file

@ -0,0 +1,72 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation."""
from absl.testing import absltest
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
class UtilsTest(absltest.TestCase):
def __init__(self, methodname):
"""Initialize the test class."""
super().__init__(methodname)
self.ntrain, self.ntest = 50, 100
self.nclass = 5
self.ndim = 10
# Generate random training and test data
self.train_data = np.random.rand(self.ntrain, self.ndim)
self.test_data = np.random.rand(self.ntest, self.ndim)
self.train_labels = np.random.randint(self.nclass, size=self.ntrain)
self.test_labels = np.random.randint(self.nclass, size=self.ntest)
self.model = tf.keras.Sequential([tf.keras.layers.Dense(self.nclass)])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
self.model.compile(optimizer='Adam', loss=loss, metrics=['accuracy'])
def test_calculate_losses(self):
"""Test calculating the loss."""
pred, loss = keras_evaluation.calculate_losses(self.model, self.train_data,
self.train_labels)
self.assertEqual(pred.shape, (self.ntrain, self.nclass))
self.assertEqual(loss.shape, (self.ntrain,))
pred, loss = keras_evaluation.calculate_losses(self.model, self.test_data,
self.test_labels)
self.assertEqual(pred.shape, (self.ntest, self.nclass))
self.assertEqual(loss.shape, (self.ntest,))
def test_run_attack_on_keras_model(self):
"""Test the attack."""
results = keras_evaluation.run_attack_on_keras_model(
self.model,
(self.train_data, self.train_labels),
(self.test_data, self.test_labels),
[])
self.assertIsInstance(results, dict)
self.assertIn('all_thresh_loss_auc', results)
self.assertIn('all_thresh_loss_advantage', results)
if __name__ == '__main__':
absltest.main()