def calculate_losses(estimator, input_fn, labels):
  """Computes per-sample predictions and cross-entropy losses.

  Two assumptions are made: 1) the model uses cross-entropy loss, and 2) the
  model function returns class probabilities in prediction mode, e.g., when
  mode == tf.estimator.ModeKeys.PREDICT it returns
  tf.estimator.EstimatorSpec(mode=mode, predictions=tf.nn.softmax(logits)).

  Args:
    estimator: model used to make the predictions
    input_fn: input function handed to estimator.predict
    labels: true labels of the samples

  Returns:
    pred: probability vector of each sample
    loss: cross-entropy loss of each sample
  """
  # estimator.predict yields one prediction per sample; collect them all
  # before stacking into a single array.
  per_sample_predictions = list(estimator.predict(input_fn=input_fn))
  pred = np.array(per_sample_predictions)
  return pred, log_loss(labels, pred)


class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
  """Session hook that runs a membership inference attack in `end`.

  When the hook is passed to `estimator.train(..., hooks=[hook])`, the attack
  is evaluated each time that training call finishes.
  """

  def __init__(self,
               estimator,
               in_train,
               out_train,
               input_fn_constructor,
               attack_classifiers,
               writer=None):
    """Initializes the hook.

    Args:
      estimator: model to be tested
      in_train: (in_training samples, in_training labels)
      out_train: (out_training samples, out_training labels)
      input_fn_constructor: a function that receives samples and labels and
        constructs the input_fn used for model prediction
      attack_classifiers: a list of classifiers to be used by the attacker,
        must be a subset of ['lr', 'mlp', 'rf', 'knn']
      writer: optional summary writer for tensorboard
    """
    in_samples, in_labels = in_train
    out_samples, out_labels = out_train
    self._in_train_labels = in_labels
    self._out_train_labels = out_labels

    # Build the prediction input functions once; they are reused every time
    # the attack runs.
    self._in_train_input_fn = input_fn_constructor(in_samples, in_labels)
    self._out_train_input_fn = input_fn_constructor(out_samples, out_labels)

    self._estimator = estimator
    self._attack_classifiers = attack_classifiers
    self._writer = writer
    if self._writer:
      logging.info('Will write to tensorboard.')

  def end(self, session):
    """Runs the attack, reports it, and optionally writes a summary.

    Args:
      session: the session being closed; not used by this hook.
    """
    results = run_attack_helper(
        self._estimator,
        self._in_train_input_fn, self._out_train_input_fn,
        self._in_train_labels, self._out_train_labels,
        self._attack_classifiers)
    advantage = results['all_thresh_loss_advantage']
    print('all_thresh_loss_advantage', advantage)
    logging.info(results)

    if not self._writer:
      return

    # Record the attacker advantage against the estimator's global step.
    summary = tf.Summary()
    summary.value.add(tag='attack advantage', simple_value=advantage)
    step = self._estimator.get_variable_value('global_step')
    self._writer.add_summary(summary, step)
    self._writer.flush()


def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
                                     input_fn_constructor, attack_classifiers):
  """Performs the attack on a trained estimator and reports the results.

  Args:
    estimator: model to be tested
    in_train: (in_training samples, in_training labels)
    out_train: (out_training samples, out_training labels)
    input_fn_constructor: a function that receives samples and labels and
      constructs the input_fn used for model prediction
    attack_classifiers: a list of classifiers to be used by the attacker,
      must be a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  in_samples, in_labels = in_train
  out_samples, out_labels = out_train

  # Input functions for the in-training and out-of-training samples.
  in_input_fn = input_fn_constructor(in_samples, in_labels)
  out_input_fn = input_fn_constructor(out_samples, out_labels)

  attack_results = run_attack_helper(estimator,
                                     in_input_fn, out_input_fn,
                                     in_labels, out_labels,
                                     attack_classifiers)
  print('all_thresh_loss_advantage',
        attack_results['all_thresh_loss_advantage'])
  logging.info('End of training attack:')
  logging.info(attack_results)
  return attack_results


def run_attack_helper(estimator,
                      in_train_input_fn, out_train_input_fn,
                      in_train_labels, out_train_labels,
                      attack_classifiers):
  """Computes predictions and losses, then runs all attacks on them.

  Args:
    estimator: model to be tested
    in_train_input_fn: input_fn for in-training data
    out_train_input_fn: input_fn for out-of-training data
    in_train_labels: in-training labels
    out_train_labels: out-of-training labels
    attack_classifiers: a list of classifiers to be used by the attacker,
      must be a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  in_pred, in_loss = calculate_losses(
      estimator, in_train_input_fn, in_train_labels)
  out_pred, out_loss = calculate_losses(
      estimator, out_train_input_fn, out_train_labels)
  return mia.run_all_attacks(in_loss, out_loss,
                             in_pred, out_pred,
                             in_train_labels, out_train_labels,
                             attack_classifiers=attack_classifiers)
GradientDescentOptimizer = tf.train.GradientDescentOptimizer

FLAGS = flags.FLAGS

flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 10, 'Number of epochs')
flags.DEFINE_string('model_dir', None, 'Model directory.')


def cnn_model_fn(features, labels, mode):
  """Model function for a small CNN classifier over 28x28x1 inputs.

  Args:
    features: dict whose 'x' entry holds the input images; it is reshaped to
      [-1, 28, 28, 1]
    labels: integer class labels; not used in PREDICT mode
    mode: a tf.estimator.ModeKeys value

  Returns:
    A tf.estimator.EstimatorSpec for the given mode. In PREDICT mode the
    predictions are the softmax probabilities, which is what
    calculate_losses in tf_estimator_evaluation expects.
  """

  # Define CNN architecture using tf.keras.layers.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(
      16, 8, strides=2, padding='same', activation='relu').apply(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Conv2D(
      32, 4, strides=2, padding='valid', activation='relu').apply(y)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Flatten().apply(y)
  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
  logits = tf.keras.layers.Dense(10).apply(y)

  # The loss needs labels, which are unavailable in PREDICT mode.
  if mode != tf.estimator.ModeKeys.PREDICT:
    vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    scalar_loss = tf.reduce_mean(input_tensor=vector_loss)

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=scalar_loss,
        train_op=train_op)

  # Add evaluation metrics (for EVAL mode).
  elif mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(input=logits, axis=1))
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)

  # Output the prediction probability (for PREDICT mode).
  elif mode == tf.estimator.ModeKeys.PREDICT:
    predictions = tf.nn.softmax(logits)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)


def load_mnist():
  """Loads MNIST, scales the pixels by 1/255 and casts the labels to int32.

  Returns:
    train_data, train_labels, test_data, test_labels as numpy arrays; the
    data arrays are float32 and the label arrays are int32.
  """
  (train_data,
   train_labels), (test_data,
                   test_labels) = tf.keras.datasets.mnist.load_data()

  train_data = np.array(train_data, dtype=np.float32) / 255
  test_data = np.array(test_data, dtype=np.float32) / 255

  train_labels = np.array(train_labels, dtype=np.int32)
  test_labels = np.array(test_labels, dtype=np.int32)

  return train_data, train_labels, test_data, test_labels


def main(unused_argv):
  """Trains the CNN and runs the attack after each epoch and at the end."""
  tf.logging.set_verbosity(tf.logging.INFO)
  logging.set_verbosity(logging.INFO)
  logging.set_stderrthreshold(logging.INFO)
  logging.get_absl_handler().use_absl_log_file()

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                            model_dir=FLAGS.model_dir)

  # A function to construct input_fn given (data, label), to be used by the
  # membership inference training hook.
  def input_fn_constructor(x, y):
    return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

  with tf.Graph().as_default():
    # Get a summary writer for the hook to write to tensorboard.
    # Can set summary_writer to None if not needed.
    if FLAGS.model_dir:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    else:
      summary_writer = None
    mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
                                               (train_data, train_labels),
                                               (test_data, test_labels),
                                               input_fn_constructor,
                                               [],
                                               summary_writer)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)

  # Derive the epoch length from the data that was actually loaded instead of
  # hard-coding the MNIST training set size.
  steps_per_epoch = train_data.shape[0] // FLAGS.batch_size

  try:
    # Training loop.
    for epoch in range(1, FLAGS.epochs + 1):
      # Train the model, with the membership inference hook.
      mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch,
                             hooks=[mia_hook])

      # Evaluate the model and print results
      eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
      test_accuracy = eval_results['accuracy']
      print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))

    print('End of training attack')
    run_attack_on_tf_estimator_model(mnist_classifier,
                                     (train_data, train_labels),
                                     (test_data, test_labels),
                                     input_fn_constructor,
                                     ['lr'])
  finally:
    # Release the event file even if training or the attack fails.
    if summary_writer is not None:
      summary_writer.close()


if __name__ == '__main__':
  app.run(main)
class UtilsTest(absltest.TestCase):
  """Tests tf_estimator_evaluation with a small softmax-regression estimator."""

  def setUp(self):
    """Builds random data, a prediction-only estimator and its input_fns.

    Fixtures live in setUp rather than in an overridden __init__ so that the
    TestCase constructor keeps its standard signature.
    """
    super().setUp()

    self.ntrain, self.ntest = 50, 100
    self.nclass = 5
    self.ndim = 10

    # Generate random training and test data from a seeded generator so that
    # every run sees the same inputs.
    rng = np.random.RandomState(0)
    self.train_data = rng.rand(self.ntrain, self.ndim)
    self.test_data = rng.rand(self.ntest, self.ndim)
    self.train_labels = rng.randint(self.nclass, size=self.ntrain)
    self.test_labels = rng.randint(self.nclass, size=self.ntest)

    # Define a simple model function
    def model_fn(features, labels, mode):
      """Model function for logistic regression."""
      del labels

      input_layer = tf.reshape(features['x'], [-1, self.ndim])
      logits = tf.keras.layers.Dense(self.nclass).apply(input_layer)

      # Define the PREDICT mode because we only need that
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = tf.nn.softmax(logits)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Define the classifier, input_fn for training and test data
    self.classifier = tf.estimator.Estimator(model_fn=model_fn)
    self.input_fn_train = tf.estimator.inputs.numpy_input_fn(
        x={'x': self.train_data}, y=self.train_labels, num_epochs=1,
        shuffle=False)
    self.input_fn_test = tf.estimator.inputs.numpy_input_fn(
        x={'x': self.test_data}, y=self.test_labels, num_epochs=1,
        shuffle=False)

  def test_calculate_losses(self):
    """Test calculating the loss."""
    pred, loss = tf_estimator_evaluation.calculate_losses(self.classifier,
                                                          self.input_fn_train,
                                                          self.train_labels)
    self.assertEqual(pred.shape, (self.ntrain, self.nclass))
    self.assertEqual(loss.shape, (self.ntrain,))

    pred, loss = tf_estimator_evaluation.calculate_losses(self.classifier,
                                                          self.input_fn_test,
                                                          self.test_labels)
    self.assertEqual(pred.shape, (self.ntest, self.nclass))
    self.assertEqual(loss.shape, (self.ntest,))

  def test_run_attack_helper(self):
    """Test the attack."""
    results = tf_estimator_evaluation.run_attack_helper(self.classifier,
                                                        self.input_fn_train,
                                                        self.input_fn_test,
                                                        self.train_labels,
                                                        self.test_labels,
                                                        [])
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)

  def test_run_attack_on_tf_estimator_model(self):
    """Test the attack on the final models."""
    def input_fn_constructor(x, y):
      return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

    results = tf_estimator_evaluation.run_attack_on_tf_estimator_model(
        self.classifier,
        (self.train_data, self.train_labels),
        (self.test_data, self.test_labels),
        input_fn_constructor,
        [])
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)


if __name__ == '__main__':
  absltest.main()
def log_loss(y, pred, small_value=1e-8):
  """Computes the per-sample cross-entropy loss.

  Args:
    y: array-like holding n integer class indices; y[i] is the true label
      (scalar) of the i-th sample. Any shape with n entries, such as (n,) or
      (n, 1), is accepted and flattened.
    pred: array-like of shape (n, num_classes); pred[i] is the probability
      vector of the i-th sample.
    small_value: np.log can become -inf if the probability is too close to 0,
      so the probability is clipped below by small_value.

  Returns:
    A 1-D numpy array of length n with the cross-entropy loss of each sample.

  Raises:
    ValueError: if pred is not 2-D, if y and pred disagree on the number of
      samples, or if y contains values that are not whole numbers.
  """
  pred = np.asarray(pred)
  if pred.ndim != 2:
    raise ValueError(
        'pred must have shape (n, num_classes), got shape %s.' % (pred.shape,))

  # Flatten the labels: an (n, 1) label array would otherwise broadcast
  # against the row indices and produce an (n, n) matrix of losses.
  labels = np.asarray(y).reshape(-1)
  if labels.size != pred.shape[0]:
    raise ValueError(
        'y has %d labels but pred has %d rows.' % (labels.size, pred.shape[0]))

  # Index arrays must be integers; accept other dtypes only when every value
  # is a whole number so nothing is silently truncated.
  if not np.issubdtype(labels.dtype, np.integer):
    integer_labels = labels.astype(np.int64)
    if not np.array_equal(integer_labels, labels):
      raise ValueError('y must contain integer class indices.')
    labels = integer_labels

  true_class_prob = pred[np.arange(labels.size), labels]
  return -np.log(np.maximum(true_class_prob, small_value))
def test_log_loss(self):
  """Checks utils.log_loss against precomputed cross-entropy values."""
  # Binary case with a few normal values. The i-th entry of the expected
  # list is the loss when every sample has true label i.
  binary_pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
                          [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
  binary_expected = [
      np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
                0.10536052, 0.01005034]),
      np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
                2.30258509, 4.60517019]),
  ]
  for true_label, expected in enumerate(binary_expected):
    y = np.full(binary_pred.shape[0], true_label, dtype=int)
    np.testing.assert_allclose(
        utils.log_loss(y, binary_pred), expected, atol=1e-7)

  # Multiclass case with a few normal values
  # (values from http://bit.ly/RJJHWA)
  multi_pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3],
                         [0.99, 0.002, 0.008]])
  multi_expected = [
      np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
      np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
      np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374]),
  ]
  for true_label, expected in enumerate(multi_expected):
    y = np.full(multi_pred.shape[0], true_label, dtype=int)
    np.testing.assert_allclose(
        utils.log_loss(y, multi_pred), expected, atol=1e-7)

  # Boundary values 0 and 1: a zero probability is clipped to small_value,
  # so the first sample's loss is -log(small_value).
  boundary_pred = np.array([[0, 1]] * 2)
  boundary_y = np.array([0, 1])
  clipping_cases = zip([1e-8, 1e-20, 1e-50],
                       [18.42068074, 46.05170186, 115.12925465])
  for small_value, clipped_loss in clipping_cases:
    loss = utils.log_loss(boundary_y, boundary_pred, small_value)
    np.testing.assert_allclose(loss, np.array([clipped_loss, 0]), atol=1e-7)