From e547a10eec8df1efc2ba40724067ce2bfd68b54a Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Mon, 31 Aug 2020 15:24:46 -0400 Subject: [PATCH 1/3] fix softmax issue --- .../membership_inference_attack/codelab.ipynb | 771 +++++++++--------- 1 file changed, 392 insertions(+), 379 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb index bb24786..1bc1639 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb +++ b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb @@ -1,382 +1,395 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1eiwVljWpzM7" - }, - "source": [ - "Copyright 2020 The TensorFlow Authors.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": {}, - "colab_type": "code", - "id": "4rmwPgXeptiS" - }, - "outputs": [], - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "YM2gRaJMqvMi" - }, - "source": [ - "# Assess privacy risks with TensorFlow Privacy Membership Inference Attacks" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-B5ZvlSqqLaR" - }, - "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "9rMuytY7Nn8P" - }, - "source": [ - "##Overview\n", - "In this codelab we'll train a simple image classification model on the CIFAR10 dataset, and then use the \"membership inference attack\" against this model to assess if the attacker is able to \"guess\" whether a particular sample was present in the training set." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FUWqArj_q8vs" - }, - "source": [ - "## Setup\n", - "First, set this notebook's runtime to use a GPU, under Runtime \u003e Change runtime type \u003e Hardware accelerator. Then, begin importing the necessary libraries." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "Lr1pwHcbralz" - }, - "outputs": [], - "source": [ - "#@title Import statements.\n", - "import numpy as np\n", - "from typing import Tuple, Text\n", - "from scipy import special\n", - "\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds\n", - "\n", - "# Set verbosity.\n", - "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", - "from warnings import simplefilter\n", - "from sklearn.exceptions import ConvergenceWarning\n", - "simplefilter(action=\"ignore\", category=ConvergenceWarning)\n", - "simplefilter(action=\"ignore\", category=FutureWarning)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ucw81ar6ru-6" - }, - "source": [ - "### Install TensorFlow Privacy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": {}, - "colab_type": "code", - "id": "zcqAmiGH90kl" - }, - "outputs": [], - "source": [ - "!pip3 install git+https://github.com/tensorflow/privacy\n", - "\n", - "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "pBbcG86th_sW" - }, - "source": [ - "## Train a model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "vCyOWyyhXLib" - }, - "outputs": [], - "source": [ - "#@markdown Train a simple model on CIFAR10 with Keras.\n", - "\n", - "dataset = 'cifar10'\n", - "num_classes = 10\n", - "num_conv = 3\n", - "activation = 'relu'\n", - "optimizer = 'adam'\n", - "lr = 0.02\n", - "momentum = 0.9\n", - "batch_size = 250\n", - "epochs = 100 # Privacy risks are especially visible with lots of epochs.\n", - "\n", - "\n", - "def small_cnn(input_shape: Tuple[int],\n", - " num_classes: int,\n", - " num_conv: int,\n", - " activation: Text = 'relu') -\u003e tf.keras.models.Sequential:\n", - " \"\"\"Setup a small CNN for image classification.\n", - "\n", - " Args:\n", - " input_shape: Integer tuple for the shape of the images.\n", - " num_classes: Number of prediction classes.\n", - " num_conv: Number of convolutional layers.\n", - " activation: The activation function to use for conv and dense layers.\n", - "\n", - " Returns:\n", - " The Keras model.\n", - " \"\"\"\n", - " model = tf.keras.models.Sequential()\n", - " model.add(tf.keras.layers.Input(shape=input_shape))\n", - "\n", - " # Conv layers\n", - " for _ in range(num_conv):\n", - " model.add(tf.keras.layers.Conv2D(32, (3, 3), activation=activation))\n", - " model.add(tf.keras.layers.MaxPooling2D())\n", - "\n", - " model.add(tf.keras.layers.Flatten())\n", - " model.add(tf.keras.layers.Dense(64, activation=activation))\n", - " model.add(tf.keras.layers.Dense(num_classes))\n", - " return model\n", - "\n", - "\n", - "print('Loading the dataset.')\n", - "train_ds = tfds.as_numpy(\n", - " tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))\n", - "test_ds = tfds.as_numpy(\n", - " tfds.load(dataset, split=tfds.Split.TEST, batch_size=-1))\n", - "x_train = train_ds['image'].astype('float32') / 255.\n", - "y_train_indices = train_ds['label'][:, np.newaxis]\n", - "x_test = test_ds['image'].astype('float32') / 255.\n", - "y_test_indices = test_ds['label'][:, np.newaxis]\n", - "\n", - "# Convert class vectors to binary class matrices.\n", - "y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)\n", - "y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)\n", - "\n", - "input_shape = x_train.shape[1:]\n", - "\n", - "model = small_cnn(\n", - " input_shape, num_classes, num_conv=num_conv, activation=activation)\n", - "\n", - "print('Optimizer ', optimizer)\n", - "print('learning rate %f', lr)\n", - "\n", - "optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=momentum)\n", - "\n", - "loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n", - "model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])\n", - "model.summary()\n", - "model.fit(\n", - " x_train,\n", - " y_train,\n", - " batch_size=batch_size,\n", - " epochs=epochs,\n", - " validation_data=(x_test, y_test),\n", - " shuffle=True)\n", - "print('Finished training.')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ee-zjGGGV1DC" - }, - "source": [ - "## Calculate logits, probabilities and loss values for training and test sets.\n", - "\n", - "We will use these values later in the membership inference attack to separate training and test samples." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": {}, - "colab_type": "code", - "id": "um9r0tSiPx4u" - }, - "outputs": [], - "source": [ - "print('Predict on train...')\n", - "logits_train = model.predict(x_train, batch_size=batch_size)\n", - "print('Predict on test...')\n", - "logits_test = model.predict(x_test, batch_size=batch_size)\n", - "\n", - "print('Apply softmax to get probabilities from logits...')\n", - "prob_train = special.softmax(logits_train)\n", - "prob_test = special.softmax(logits_test)\n", - "\n", - "print('Compute losses...')\n", - "cce = tf.keras.backend.categorical_crossentropy\n", - "constant = tf.keras.backend.constant\n", - "\n", - "loss_train = cce(constant(y_train), constant(prob_train), from_logits=False).numpy()\n", - "loss_test = cce(constant(y_test), constant(prob_test), from_logits=False).numpy()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "QETxVOHLiHP4" - }, - "source": [ - "## Run membership inference attacks.\n", - "\n", - "We will now execute a membership inference attack against the previously trained CIFAR10 model. This will generate a number of scores, most notably, attacker advantage and AUC for the membership inference classifier.\n", - "\n", - "An AUC of close to 0.5 means that the attack wasn't able to identify training samples, which means that the model doesn't have privacy issues according to this test. Higher values, on the contrary, indicate potential privacy issues." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "B8NIwhVwQT7I" - }, - "outputs": [], - "source": [ - "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n", - "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n", - "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n", - "\n", - "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n", - "\n", - "labels_train = np.argmax(y_train, axis=1)\n", - "labels_test = np.argmax(y_test, axis=1)\n", - "\n", - "input = AttackInputData(\n", - " logits_train = logits_train,\n", - " logits_test = logits_test,\n", - " loss_train = loss_train,\n", - " loss_test = loss_test,\n", - " labels_train = labels_train,\n", - " labels_test = labels_test\n", - ")\n", - "\n", - "# Run several attacks for different data slices\n", - "attacks_result = mia.run_attacks(input,\n", - " SlicingSpec(\n", - " entire_dataset = True,\n", - " by_class = True,\n", - " by_classification_correctness = True\n", - " ),\n", - " attack_types = [\n", - " AttackType.THRESHOLD_ATTACK,\n", - " AttackType.LOGISTIC_REGRESSION])\n", - "\n", - "# Plot the ROC curve of the best classifier\n", - "fig = plotting.plot_roc_curve(\n", - " attacks_result.get_result_with_max_auc().roc_curve)\n", - "\n", - "# Print a user-friendly summary of the attacks\n", - "print(attacks_result.summary(by_slices = True))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "E9zwsPGFujVq" - }, - "source": [ - "This is the end of the codelab!\n", - "Feel free to change the parameters to see how the privacy risks change.\n", - "\n", - "You can try playing with:\n", - "* the number of training epochs\n", - "* different attack_types" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "last_runtime": { - "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Membership inference codelab", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1eiwVljWpzM7" + }, + "source": [ + "Copyright 2020 The TensorFlow Authors.\n" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "4rmwPgXeptiS" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YM2gRaJMqvMi" + }, + "source": [ + "# Assess privacy risks with TensorFlow Privacy Membership Inference Attacks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-B5ZvlSqqLaR" + }, + "source": [ + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9rMuytY7Nn8P" + }, + "source": [ + "##Overview\n", + "In this codelab we'll train a simple image classification model on the CIFAR10 dataset, and then use the \"membership inference attack\" against this model to assess if the attacker is able to \"guess\" whether a particular sample was present in the training set." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FUWqArj_q8vs" + }, + "source": [ + "## Setup\n", + "First, set this notebook's runtime to use a GPU, under Runtime > Change runtime type > Hardware accelerator. Then, begin importing the necessary libraries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "Lr1pwHcbralz" + }, + "outputs": [], + "source": [ + "#@title Import statements.\n", + "import numpy as np\n", + "from typing import Tuple, Text\n", + "from scipy import special\n", + "\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "\n", + "# Set verbosity.\n", + "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", + "from warnings import simplefilter\n", + "from sklearn.exceptions import ConvergenceWarning\n", + "simplefilter(action=\"ignore\", category=ConvergenceWarning)\n", + "simplefilter(action=\"ignore\", category=FutureWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ucw81ar6ru-6" + }, + "source": [ + "### Install TensorFlow Privacy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "zcqAmiGH90kl" + }, + "outputs": [], + "source": [ + "!pip3 install git+https://github.com/tensorflow/privacy\n", + "\n", + "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pBbcG86th_sW" + }, + "source": [ + "## Train a model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "vCyOWyyhXLib" + }, + "outputs": [], + "source": [ + "#@markdown Train a simple model on CIFAR10 with Keras.\n", + "\n", + "dataset = 'cifar10'\n", + "num_classes = 10\n", + "num_conv = 3\n", + "activation = 'relu'\n", + "optimizer = 'adam'\n", + "lr = 0.02\n", + "momentum = 0.9\n", + "batch_size = 250\n", + "epochs = 100 # Privacy risks are especially visible with lots of epochs.\n", + "\n", + "\n", + "def small_cnn(input_shape: Tuple[int],\n", + " num_classes: int,\n", + " num_conv: int,\n", + " activation: Text = 'relu') -> tf.keras.models.Sequential:\n", + " \"\"\"Setup a small CNN for image classification.\n", + "\n", + " Args:\n", + " input_shape: Integer tuple for the shape of the images.\n", + " num_classes: Number of prediction classes.\n", + " num_conv: Number of convolutional layers.\n", + " activation: The activation function to use for conv and dense layers.\n", + "\n", + " Returns:\n", + " The Keras model.\n", + " \"\"\"\n", + " model = tf.keras.models.Sequential()\n", + " model.add(tf.keras.layers.Input(shape=input_shape))\n", + "\n", + " # Conv layers\n", + " for _ in range(num_conv):\n", + " model.add(tf.keras.layers.Conv2D(32, (3, 3), activation=activation))\n", + " model.add(tf.keras.layers.MaxPooling2D())\n", + "\n", + " model.add(tf.keras.layers.Flatten())\n", + " model.add(tf.keras.layers.Dense(64, activation=activation))\n", + " model.add(tf.keras.layers.Dense(num_classes))\n", + " return model\n", + "\n", + "\n", + "print('Loading the dataset.')\n", + "train_ds = tfds.as_numpy(\n", + " tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))\n", + "test_ds = tfds.as_numpy(\n", + " tfds.load(dataset, split=tfds.Split.TEST, batch_size=-1))\n", + "x_train = train_ds['image'].astype('float32') / 255.\n", + "y_train_indices = train_ds['label'][:, np.newaxis]\n", + "x_test = test_ds['image'].astype('float32') / 255.\n", + "y_test_indices = test_ds['label'][:, np.newaxis]\n", + "\n", + "# Convert class vectors to binary class matrices.\n", + "y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)\n", + "y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)\n", + "\n", + "input_shape = x_train.shape[1:]\n", + "\n", + "model = small_cnn(\n", + " input_shape, num_classes, num_conv=num_conv, activation=activation)\n", + "\n", + "print('Optimizer ', optimizer)\n", + "print('learning rate %f', lr)\n", + "\n", + "optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=momentum)\n", + "\n", + "loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n", + "model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])\n", + "model.summary()\n", + "model.fit(\n", + " x_train,\n", + " y_train,\n", + " batch_size=batch_size,\n", + " epochs=epochs,\n", + " validation_data=(x_test, y_test),\n", + " shuffle=True)\n", + "print('Finished training.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ee-zjGGGV1DC" + }, + "source": [ + "## Calculate logits, probabilities and loss values for training and test sets.\n", + "\n", + "We will use these values later in the membership inference attack to separate training and test samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "um9r0tSiPx4u" + }, + "outputs": [], + "source": [ + "print('Predict on train...')\n", + "logits_train = model.predict(x_train, batch_size=batch_size)\n", + "print('Predict on test...')\n", + "logits_test = model.predict(x_test, batch_size=batch_size)\n", + "\n", + "print('Apply softmax to get probabilities from logits...')\n", + "prob_train = special.softmax(logits_train, axis=1)\n", + "prob_test = special.softmax(logits_test, axis=1)\n", + "\n", + "print('Compute losses...')\n", + "cce = tf.keras.backend.categorical_crossentropy\n", + "constant = tf.keras.backend.constant\n", + "\n", + "loss_train = cce(constant(y_train), constant(prob_train), from_logits=False).numpy()\n", + "loss_test = cce(constant(y_test), constant(prob_test), from_logits=False).numpy()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QETxVOHLiHP4" + }, + "source": [ + "## Run membership inference attacks.\n", + "\n", + "We will now execute a membership inference attack against the previously trained CIFAR10 model. This will generate a number of scores, most notably, attacker advantage and AUC for the membership inference classifier.\n", + "\n", + "An AUC of close to 0.5 means that the attack wasn't able to identify training samples, which means that the model doesn't have privacy issues according to this test. Higher values, on the contrary, indicate potential privacy issues." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "B8NIwhVwQT7I" + }, + "outputs": [], + "source": [ + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n", + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n", + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n", + "\n", + "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n", + "\n", + "labels_train = np.argmax(y_train, axis=1)\n", + "labels_test = np.argmax(y_test, axis=1)\n", + "\n", + "input = AttackInputData(\n", + " logits_train = logits_train,\n", + " logits_test = logits_test,\n", + " loss_train = loss_train,\n", + " loss_test = loss_test,\n", + " labels_train = labels_train,\n", + " labels_test = labels_test\n", + ")\n", + "\n", + "# Run several attacks for different data slices\n", + "attacks_result = mia.run_attacks(input,\n", + " SlicingSpec(\n", + " entire_dataset = True,\n", + " by_class = True,\n", + " by_classification_correctness = True\n", + " ),\n", + " attack_types = [\n", + " AttackType.THRESHOLD_ATTACK,\n", + " AttackType.LOGISTIC_REGRESSION])\n", + "\n", + "# Plot the ROC curve of the best classifier\n", + "fig = plotting.plot_roc_curve(\n", + " attacks_result.get_result_with_max_auc().roc_curve)\n", + "\n", + "# Print a user-friendly summary of the attacks\n", + "print(attacks_result.summary(by_slices = True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "E9zwsPGFujVq" + }, + "source": [ + "This is the end of the codelab!\n", + "Feel free to change the parameters to see how the privacy risks change.\n", + "\n", + "You can try playing with:\n", + "* the number of training epochs\n", + "* different attack_types" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", + "kind": "private" + }, + "name": "Membership inference codelab", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.10" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 } From 9b2e6a55b68ec5685c03d839761e37727d6a12e0 Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Mon, 31 Aug 2020 16:17:19 -0400 Subject: [PATCH 2/3] add entropy feature --- .../data_structures.py | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 082a90b..e323b2f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -22,6 +22,7 @@ from dataclasses import dataclass import numpy as np import pandas as pd from sklearn import metrics +from scipy import special ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' @@ -144,6 +145,9 @@ def _is_np_array(arr, arr_name): if arr is not None and not isinstance(arr, np.ndarray): raise ValueError('%s should be a numpy array.' % arr_name) +def _log_value(probs, small_value=1e-30): + """Compute the log value on the probability. Clip the probability in case it is close to 0""" + return -np.log(np.maximum(probs, small_value)) @dataclass class AttackInputData: @@ -165,6 +169,11 @@ class AttackInputData: loss_train: np.ndarray = None loss_test: np.ndarray = None + # Explicitly specified prediction entropy. If provided, this is used instead of deriving + # entropy from logits and labels (https://arxiv.org/pdf/2003.10595.pdf by Song and Mittal) + entropy_train: np.ndarray = None + entropy_test: np.ndarray = None + @property def num_classes(self): if self.labels_train is None or self.labels_test is None: @@ -177,6 +186,34 @@ class AttackInputData: def _get_loss(logits: np.ndarray, true_labels: np.ndarray): return logits[range(logits.shape[0]), true_labels] + @staticmethod + def _get_entropy(logits: np.ndarray, true_labels: np.ndarray): + if (np.absolute(np.sum(logits,axis=1)-1)<=1e-3).all(): + probs = logits + else: + """Using softmax to compute probability from logits""" + probs = special.softmax(logits, axis=1) + if true_labels is None: + ''' + When not given ground truth label, we compute the normal prediction entropy. + See the Equation (7) in https://arxiv.org/pdf/2003.10595.pdf + ''' + return np.sum(np.multiply(probs, _log_value(probs)),axis=1) + else: + ''' + When given the groud truth label, we compute the modified prediction entropy. + See the Equation (8) in https://arxiv.org/pdf/2003.10595.pdf + ''' + log_probs = _log_value(probs) + reverse_probs = 1-probs + log_reverse_probs = _log_value(reverse_probs) + modified_probs = np.copy(probs) + modified_probs[range(true_labels.size), true_labels] = reverse_probs[range(true_labels.size), true_labels] + modified_log_probs = np.copy(log_reverse_probs) + modified_log_probs[range(true_labels.size), true_labels] = log_probs[range(true_labels.size), true_labels] + return np.sum(np.multiply(modified_probs, modified_log_probs),axis=1) + + def get_loss_train(self): """Calculates cross-entropy losses for the training set.""" if self.loss_train is not None: @@ -189,6 +226,18 @@ class AttackInputData: return self.loss_test return self._get_loss(self.logits_test, self.labels_test) + def get_entropy_train(self): + """Calculates prediction entropy for the training set.""" + if self.entropy_train is not None: + return self.entropy_train + return self._get_entropy(self.logits_train, self.labels_train) + + def get_entropy_test(self): + """Calculates prediction entropy for the test set.""" + if self.entropy_test is not None: + return self.entropy_test + return self._get_entropy(self.logits_test, self.labels_test) + def get_train_size(self): """Returns size of the training set.""" if self.loss_train is not None: @@ -206,6 +255,10 @@ class AttackInputData: if (self.loss_train is None) != (self.loss_test is None): raise ValueError( 'loss_test and loss_train should both be either set or unset') + + if (self.entropy_train is None) != (self.entropy_test is None): + raise ValueError( + 'entropy_test and entropy_train should both be either set or unset') if (self.logits_train is None) != (self.logits_test is None): raise ValueError( @@ -216,8 +269,8 @@ class AttackInputData: 'labels_train and labels_test should both be either set or unset') if (self.labels_train is None and self.loss_train is None and - self.logits_train is None): - raise ValueError('At least one of labels, logits or losses should be set') + self.logits_train is None and self.entropy_train is None): + raise ValueError('At least one of labels, logits, losses or entropy should be set') if self.labels_train is not None and not _is_integer_type_array( self.labels_train): @@ -233,11 +286,15 @@ class AttackInputData: _is_np_array(self.labels_test, 'labels_test') _is_np_array(self.loss_train, 'loss_train') _is_np_array(self.loss_test, 'loss_test') + _is_np_array(self.entropy_train, 'entropy_train') + _is_np_array(self.entropy_test, 'entropy_test') _is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test, 'logits_test') _is_array_one_dimensional(self.loss_train, 'loss_train') _is_array_one_dimensional(self.loss_test, 'loss_test') + _is_array_one_dimensional(self.entropy_train, 'entropy_train') + _is_array_one_dimensional(self.entropy_test, 'entropy_test') _is_array_one_dimensional(self.labels_train, 'labels_train') _is_array_one_dimensional(self.labels_test, 'labels_test') @@ -246,6 +303,8 @@ class AttackInputData: result = ['AttackInputData('] _append_array_shape(self.loss_train, 'loss_train', result) _append_array_shape(self.loss_test, 'loss_test', result) + _append_array_shape(self.entropy_train, 'entropy_train', result) + _append_array_shape(self.entropy_test, 'entropy_test', result) _append_array_shape(self.logits_train, 'logits_train', result) _append_array_shape(self.logits_test, 'logits_test', result) _append_array_shape(self.labels_train, 'labels_train', result) From 0e1c1eeef392ffe816f8e24cbb652b6f280cdc99 Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Wed, 2 Sep 2020 11:37:12 -0400 Subject: [PATCH 3/3] add entropy tests --- .../data_structures_test.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 8314d95..199d32a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve @@ -65,6 +66,33 @@ class AttackInputDataTest(absltest.TestCase): np.testing.assert_equal(attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + def test_get_entropy(self): + attack_input = AttackInputData( + logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), + labels_train=np.array([0, 2]), + labels_test=np.array([0, 2])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [2*_log_value(0), 0]) + + attack_input = AttackInputData( + logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [0, 0]) + + def test_get_entropy_explicitly_provided(self): + attack_input = AttackInputData( + entropy_train=np.array([0.0, 2.0, 1.0]), + entropy_test=np.array([0.5, 3.0, 5.0])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), + [0.0, 2.0, 1.0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), + [0.5, 3.0, 5.0]) + def test_validator(self): self.assertRaises(ValueError, AttackInputData(logits_train=np.array([])).validate) @@ -72,12 +100,16 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData(labels_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(loss_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(entropy_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(logits_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(labels_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(loss_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(entropy_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData().validate)