Add a training hook, and a function to be called at the end of training, for running membership inference attacks on TF Estimator models.

PiperOrigin-RevId: 321648371
This commit is contained in:
Shuang Song 2020-07-16 14:38:39 -07:00 committed by A. Unique TensorFlower
parent 51eb7c3712
commit a0e1b72838
5 changed files with 497 additions and 0 deletions

View file

@ -0,0 +1,166 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""A hook and a function in tf estimator for membership inference attack."""
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
def calculate_losses(estimator, input_fn, labels):
  """Computes per-sample predictions and cross-entropy losses.

  Two assumptions are made: 1) the loss is cross-entropy loss, and 2) the
  model function returns the class probabilities in prediction mode, e.g.,
  when mode == tf.estimator.ModeKeys.PREDICT, it returns
  tf.estimator.EstimatorSpec(mode=mode, predictions=tf.nn.softmax(logits)).

  Args:
    estimator: model used to make the predictions
    input_fn: input function handed to estimator.predict
    labels: true labels of the samples

  Returns:
    pred: probability vector of each sample
    loss: cross entropy loss of each sample
  """
  # estimator.predict yields one prediction at a time; collect them all before
  # stacking into a single array.
  per_sample_predictions = list(estimator.predict(input_fn=input_fn))
  pred = np.array(per_sample_predictions)
  return pred, log_loss(labels, pred)
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
  """Training hook that runs a membership inference attack in `end`."""

  def __init__(self,
               estimator,
               in_train,
               out_train,
               input_fn_constructor,
               attack_classifiers,
               writer=None):
    """Initializes the hook.

    Args:
      estimator: model to be tested
      in_train: (in_training samples, in_training labels)
      out_train: (out_training samples, out_training labels)
      input_fn_constructor: a function that receives samples and labels and
        constructs the input_fn used for model prediction
      attack_classifiers: a list of classifiers to be used by attacker, must be
        a subset of ['lr', 'mlp', 'rf', 'knn']
      writer: summary writer for tensorboard
    """
    self._estimator = estimator
    self._attack_classifiers = attack_classifiers
    self._writer = writer

    in_samples, self._in_train_labels = in_train
    out_samples, self._out_train_labels = out_train

    # Build the prediction input functions once, for both the in-training and
    # the out-of-training samples.
    self._in_train_input_fn = input_fn_constructor(in_samples,
                                                   self._in_train_labels)
    self._out_train_input_fn = input_fn_constructor(out_samples,
                                                    self._out_train_labels)

    if self._writer:
      logging.info('Will write to tensorboard.')

  def end(self, session):
    """Runs the attack, logs the results and optionally writes a summary.

    Args:
      session: the session passed in by the hook machinery (not used here)
    """
    results = run_attack_helper(self._estimator,
                                self._in_train_input_fn,
                                self._out_train_input_fn,
                                self._in_train_labels,
                                self._out_train_labels,
                                self._attack_classifiers)
    advantage = results['all_thresh_loss_advantage']
    print('all_thresh_loss_advantage', advantage)
    logging.info(results)

    # Nothing more to do when no summary writer was supplied.
    if not self._writer:
      return

    summary = tf.Summary()
    summary.value.add(tag='attack advantage', simple_value=advantage)
    step = self._estimator.get_variable_value('global_step')
    self._writer.add_summary(summary, step)
    self._writer.flush()
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
                                     input_fn_constructor, attack_classifiers):
  """Performs the membership inference attack at the end of training.

  Args:
    estimator: model to be tested
    in_train: (in_training samples, in_training labels)
    out_train: (out_training samples, out_training labels)
    input_fn_constructor: a function that receives samples and labels and
      constructs the input_fn used for model prediction
    attack_classifiers: a list of classifiers to be used by attacker, must be
      a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  in_samples, in_labels = in_train
  out_samples, out_labels = out_train

  # The input functions for the in- and out-of-training samples are built
  # inline and handed straight to the helper that runs the attack.
  results = run_attack_helper(
      estimator,
      input_fn_constructor(in_samples, in_labels),
      input_fn_constructor(out_samples, out_labels),
      in_labels,
      out_labels,
      attack_classifiers)

  print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
  logging.info('End of training attack:')
  logging.info(results)
  return results
def run_attack_helper(estimator,
                      in_train_input_fn, out_train_input_fn,
                      in_train_labels, out_train_labels,
                      attack_classifiers):
  """Computes predictions and losses, then runs all attacks on them.

  Args:
    estimator: model to be tested
    in_train_input_fn: input_fn for in training data
    out_train_input_fn: input_fn for out of training data
    in_train_labels: in training labels
    out_train_labels: out of training labels
    attack_classifiers: a list of classifiers to be used by attacker, must be
      a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  # Predictions and losses of the samples seen during training ...
  in_pred, in_loss = calculate_losses(
      estimator, in_train_input_fn, in_train_labels)
  # ... and of the held-out samples.
  out_pred, out_loss = calculate_losses(
      estimator, out_train_input_fn, out_train_labels)

  return mia.run_all_attacks(
      in_loss, out_loss,
      in_pred, out_pred,
      in_train_labels, out_train_labels,
      attack_classifiers=attack_classifiers)

View file

@ -0,0 +1,162 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""An example for using tf_estimator_evaluation."""
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
# Short alias for the optimizer used by the model function.
GradientDescentOptimizer = tf.train.GradientDescentOptimizer

# Command-line flags configuring the example.
FLAGS = flags.FLAGS
flags.DEFINE_float(
    name='learning_rate', default=0.15, help='Learning rate for training')
flags.DEFINE_integer(name='batch_size', default=256, help='Batch size')
flags.DEFINE_integer(name='epochs', default=10, help='Number of epochs')
flags.DEFINE_string(name='model_dir', default=None, help='Model directory.')
def cnn_model_fn(features, labels, mode):
  """Model function for a CNN.

  Args:
    features: dict of input tensors; features['x'] is reshaped to
      [-1, 28, 28, 1] before being fed to the network
    labels: integer class labels, used to compute the sparse softmax
      cross-entropy loss (not used in PREDICT mode)
    mode: one of the tf.estimator.ModeKeys values

  Returns:
    A tf.estimator.EstimatorSpec for the requested mode. In PREDICT mode the
    predictions are the softmax probabilities of the logits.

  Raises:
    ValueError: if mode is not TRAIN, EVAL or PREDICT.
  """
  # Define CNN architecture using tf.keras.layers. Layers are invoked by
  # calling them directly; `Layer.apply` is a deprecated alias of `__call__`.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(
      16, 8, strides=2, padding='same', activation='relu')(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1)(y)
  y = tf.keras.layers.Conv2D(
      32, 4, strides=2, padding='valid', activation='relu')(y)
  y = tf.keras.layers.MaxPool2D(2, 1)(y)
  y = tf.keras.layers.Flatten()(y)
  y = tf.keras.layers.Dense(32, activation='relu')(y)
  logits = tf.keras.layers.Dense(10)(y)

  # Output the prediction probability (for PREDICT mode). No loss is built
  # here because labels are not needed for prediction.
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = tf.nn.softmax(logits)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Per-example loss, averaged into the scalar loss shared by TRAIN and EVAL.
  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  scalar_loss = tf.reduce_mean(input_tensor=vector_loss)

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=scalar_loss,
        train_op=train_op)

  # Add evaluation metrics (for EVAL mode).
  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(input=logits, axis=1))
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)

  # Fail loudly instead of silently returning None for an unknown mode.
  raise ValueError('Unsupported mode: %s' % mode)
def load_mnist():
  """Loads MNIST and converts the images and labels to fixed dtypes.

  Images are converted to float32 and divided by 255; labels are converted to
  int32.

  Returns:
    A tuple (train_data, train_labels, test_data, test_labels).
  """
  train, test = tf.keras.datasets.mnist.load_data()
  raw_train_data, raw_train_labels = train
  raw_test_data, raw_test_labels = test

  train_data = np.array(raw_train_data, dtype=np.float32) / 255
  test_data = np.array(raw_test_data, dtype=np.float32) / 255

  train_labels = np.array(raw_train_labels, dtype=np.int32)
  test_labels = np.array(raw_test_labels, dtype=np.int32)

  return train_data, train_labels, test_data, test_labels
def main(unused_argv):
  """Trains the MNIST CNN and runs membership inference attacks on it.

  The training hook runs the attack during training, and
  run_attack_on_tf_estimator_model runs it once more after the training loop.

  Args:
    unused_argv: command-line arguments left over after flag parsing (ignored)
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  logging.set_verbosity(logging.INFO)
  logging.set_stderrthreshold(logging.INFO)
  logging.get_absl_handler().use_absl_log_file()

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                            model_dir=FLAGS.model_dir)

  # A function to construct input_fn given (data, label), to be used by the
  # membership inference training hook.
  def input_fn_constructor(x, y):
    return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

  with tf.Graph().as_default():
    # Get a summary writer for the hook to write to tensorboard.
    # Can set summary_writer to None if not needed.
    if FLAGS.model_dir:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    else:
      summary_writer = None
    mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
                                               (train_data, train_labels),
                                               (test_data, test_labels),
                                               input_fn_constructor,
                                               [],
                                               summary_writer)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)

  # Training loop. The number of steps per epoch is derived from the data that
  # was actually loaded rather than from a hard-coded dataset size.
  steps_per_epoch = train_data.shape[0] // FLAGS.batch_size
  try:
    for epoch in range(1, FLAGS.epochs + 1):
      # Train the model, with the membership inference hook.
      mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch,
                             hooks=[mia_hook])

      # Evaluate the model and print results
      eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
      test_accuracy = eval_results['accuracy']
      print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
  finally:
    # The hook is the only user of the writer, so it can be released as soon as
    # training stops, even if training raised.
    if summary_writer is not None:
      summary_writer.close()

  print('End of training attack')
  run_attack_on_tf_estimator_model(mnist_classifier,
                                   (train_data, train_labels),
                                   (test_data, test_labels),
                                   input_fn_constructor,
                                   ['lr'])
# Script entry point: hand `main` to absl's app runner.
if __name__ == '__main__':
  app.run(main)

View file

@ -0,0 +1,107 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation."""
from absl.testing import absltest
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
class UtilsTest(absltest.TestCase):
  """Tests for the helpers in tf_estimator_evaluation."""

  def __init__(self, methodname):
    """Creates random data, a small estimator and its input functions."""
    super().__init__(methodname)
    self.ntrain, self.ntest = 50, 100
    self.nclass = 5
    self.ndim = 10

    # Generate random training and test data
    self.train_data = np.random.rand(self.ntrain, self.ndim)
    self.test_data = np.random.rand(self.ntest, self.ndim)
    self.train_labels = np.random.randint(self.nclass, size=self.ntrain)
    self.test_labels = np.random.randint(self.nclass, size=self.ntest)

    # Define a simple model function
    def model_fn(features, labels, mode):
      """Model function for logistic regression."""
      del labels
      input_layer = tf.reshape(features['x'], [-1, self.ndim])
      logits = tf.keras.layers.Dense(self.nclass).apply(input_layer)
      # Only the PREDICT mode is defined because that is all the tests need.
      if mode != tf.estimator.ModeKeys.PREDICT:
        return None
      return tf.estimator.EstimatorSpec(
          mode=mode, predictions=tf.nn.softmax(logits))

    def make_input_fn(data, labels):
      """Builds a single-pass, unshuffled input_fn over (data, labels)."""
      return tf.estimator.inputs.numpy_input_fn(
          x={'x': data}, y=labels, num_epochs=1, shuffle=False)

    # Define the classifier, input_fn for training and test data
    self.classifier = tf.estimator.Estimator(model_fn=model_fn)
    self.input_fn_train = make_input_fn(self.train_data, self.train_labels)
    self.input_fn_test = make_input_fn(self.test_data, self.test_labels)

  def _check_loss_shapes(self, input_fn, labels, num_samples):
    """Asserts that calculate_losses returns arrays of the expected shapes."""
    pred, loss = tf_estimator_evaluation.calculate_losses(
        self.classifier, input_fn, labels)
    self.assertEqual(pred.shape, (num_samples, self.nclass))
    self.assertEqual(loss.shape, (num_samples,))

  def _check_attack_results(self, results):
    """Asserts that the attack results are a dict with the threshold metrics."""
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)

  def test_calculate_losses(self):
    """Test calculating the loss."""
    self._check_loss_shapes(self.input_fn_train, self.train_labels,
                            self.ntrain)
    self._check_loss_shapes(self.input_fn_test, self.test_labels, self.ntest)

  def test_run_attack_helper(self):
    """Test the attack."""
    self._check_attack_results(
        tf_estimator_evaluation.run_attack_helper(
            self.classifier,
            self.input_fn_train,
            self.input_fn_test,
            self.train_labels,
            self.test_labels,
            []))

  def test_run_attack_on_tf_estimator_model(self):
    """Test the attack on the final models."""
    def input_fn_constructor(x, y):
      return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

    self._check_attack_results(
        tf_estimator_evaluation.run_attack_on_tf_estimator_model(
            self.classifier,
            (self.train_data, self.train_labels),
            (self.test_data, self.test_labels),
            input_fn_constructor,
            []))
# Run the tests with absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

View file

@ -216,3 +216,23 @@ def compute_performance_metrics(true_labels: np.ndarray,
'advantage': advantage,
})
return ensure_1d(results)
# ------------------------------------------------------------------------------
# Loss functions
# ------------------------------------------------------------------------------
def log_loss(y, pred, small_value=1e-8):
  """Compute the cross entropy loss.

  Both `y` and `pred` may be numpy arrays or plain (nested) sequences; they are
  converted with np.asarray before use.

  Args:
    y: array-like, y[i] is the true label (scalar) of the i-th sample
    pred: array-like, pred[i] is the probability vector of the i-th sample
    small_value: np.log can become -inf if the probability is too close to 0,
      so the probability is clipped below by small_value.

  Returns:
    the cross-entropy loss of each sample
  """
  labels = np.asarray(y)
  probs = np.asarray(pred)
  # Pick, for every sample, the probability assigned to its true label.
  true_label_probs = probs[np.arange(labels.size), labels]
  return -np.log(np.maximum(true_label_probs, small_value))

View file

@ -100,6 +100,48 @@ class UtilsTest(absltest.TestCase):
self.assertEqual(x_test.shape, (n_test, 11))
self.assertEqual(y_test.shape, (n_test,))
def test_log_loss(self):
  """Checks utils.log_loss on binary, multiclass and boundary inputs."""

  def check_constant_labels(pred, expected_by_label):
    """Compares the losses when every sample carries the same true label."""
    num_samples = pred.shape[0]
    for label, expected in expected_by_label.items():
      y = np.full(num_samples, label, dtype=int)
      loss = utils.log_loss(y, pred)
      np.testing.assert_allclose(loss, expected, atol=1e-7)

  # Binary case with a few normal values; the true label (for all samples) is
  # 0 and then 1.
  binary_pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
                          [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
  check_constant_labels(binary_pred, {
      0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
                   0.10536052, 0.01005034]),
      1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
                   2.30258509, 4.60517019]),
  })

  # Multiclass case with a few normal values
  # (values from http://bit.ly/RJJHWA); the true label (for all samples) is
  # 0, 1, and then 2.
  multiclass_pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2],
                              [0.6, 0.1, 0.3], [0.99, 0.002, 0.008]])
  check_constant_labels(multiclass_pred, {
      0: np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
      1: np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
      2: np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374]),
  })

  # Boundary values 0 and 1: the zero probability is clipped to small_value.
  pred = np.array([[0, 1]] * 2)
  y = np.array([0, 1])
  small_values = [1e-8, 1e-20, 1e-50]
  clipped_losses = np.array([18.42068074, 46.05170186, 115.12925465])
  for small_value, clipped_loss in zip(small_values, clipped_losses):
    loss = utils.log_loss(y, pred, small_value)
    np.testing.assert_allclose(loss, np.array([clipped_loss, 0]), atol=1e-7)
# Run the tests with absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()