From 7a77d5d92cc9f5b84d4319069deaa35756f1420f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Aug 2020 07:21:12 -0700 Subject: [PATCH] Modify Colab to use the new membership inference API. PiperOrigin-RevId: 327805944 --- .../membership_inference_attack/codelab.ipynb | 99 ++++++++++++------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb index e774247..bb24786 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb +++ b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb @@ -7,14 +7,14 @@ "id": "1eiwVljWpzM7" }, "source": [ - "##### Copyright 2020 The TensorFlow Authors.\n" + "Copyright 2020 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", + "cellView": "both", "colab": {}, "colab_type": "code", "id": "4rmwPgXeptiS" @@ -87,6 +87,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "colab": {}, "colab_type": "code", "id": "Lr1pwHcbralz" @@ -116,7 +117,7 @@ "id": "ucw81ar6ru-6" }, "source": [ - "Install TensorFlow Privacy." + "### Install TensorFlow Privacy." ] }, { @@ -132,7 +133,7 @@ "source": [ "!pip3 install git+https://github.com/tensorflow/privacy\n", "\n", - "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia" + "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia" ] }, { @@ -142,19 +143,22 @@ "id": "pBbcG86th_sW" }, "source": [ - "## Train a simple model on CIFAR10 with Keras." + "## Train a model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "colab": {}, "colab_type": "code", "id": "vCyOWyyhXLib" }, "outputs": [], "source": [ + "#@markdown Train a simple model on CIFAR10 with Keras.\n", + "\n", "dataset = 'cifar10'\n", "num_classes = 10\n", "num_conv = 3\n", @@ -232,19 +236,29 @@ "print('Finished training.')" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ee-zjGGGV1DC" + }, + "source": [ + "## Calculate logits, probabilities and loss values for training and test sets.\n", + "\n", + "We will use these values later in the membership inference attack to separate training and test samples." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "both", "colab": {}, "colab_type": "code", "id": "um9r0tSiPx4u" }, "outputs": [], "source": [ - "#@title Calculate logits, probabilities and loss values for training and test sets.\n", - "#@markdown We will use these values later in the membership inference attack to\n", - "#@markdown separate training and test samples.\n", "print('Predict on train...')\n", "logits_train = model.predict(x_train, batch_size=batch_size)\n", "print('Predict on test...')\n", @@ -269,7 +283,11 @@ "id": "QETxVOHLiHP4" }, "source": [ - "## Run membership inference attacks." + "## Run membership inference attacks.\n", + "\n", + "We will now execute a membership inference attack against the previously trained CIFAR10 model. This will generate a number of scores, most notably, attacker advantage and AUC for the membership inference classifier.\n", + "\n", + "An AUC of close to 0.5 means that the attack wasn't able to identify training samples, which means that the model doesn't have privacy issues according to this test. Higher values, on the contrary, indicate potential privacy issues." ] }, { @@ -282,41 +300,41 @@ }, "outputs": [], "source": [ - "#@markdown We will now execute membership inference attack against the\n", - "#@markdown previously trained CIFAR10 model. This will generate a number of\n", - "#@markdown scores (most notably, attacker advantage and AUC for the membership\n", - "#@markdown inference classifier). An AUC of close to 0.5 means that the attack\n", - "#@markdown isn't able to identify training samples, which means that the model\n", - "#@markdown doesn't have privacy issues according to this test. Higher values,\n", - "#@markdown on the contrary, indicate potential privacy issues.\n", + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n", + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n", + "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n", + "\n", + "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n", "\n", "labels_train = np.argmax(y_train, axis=1)\n", "labels_test = np.argmax(y_test, axis=1)\n", "\n", - "results_without_classifiers = mia.run_all_attacks(\n", - " loss_train,\n", - " loss_test,\n", - " logits_train,\n", - " logits_test,\n", - " labels_train,\n", - " labels_test,\n", - " attack_classifiers=[],\n", + "input = AttackInputData(\n", + " logits_train = logits_train,\n", + " logits_test = logits_test,\n", + " loss_train = loss_train,\n", + " loss_test = loss_test,\n", + " labels_train = labels_train,\n", + " labels_test = labels_test\n", ")\n", - "print(results_without_classifiers)\n", "\n", - "# Note: This will take a while, since it also trains ML models to\n", - "# separate train/test examples. If it's taking too looking, use\n", - "# the `run_all_attacks` function instead.\n", - "attack_result_summary = mia.run_all_attacks_and_create_summary(\n", - " loss_train,\n", - " loss_test,\n", - " logits_train,\n", - " logits_test,\n", - " labels_train,\n", - " labels_test,\n", - ")[0]\n", + "# Run several attacks for different data slices\n", + "attacks_result = mia.run_attacks(input,\n", + " SlicingSpec(\n", + " entire_dataset = True,\n", + " by_class = True,\n", + " by_classification_correctness = True\n", + " ),\n", + " attack_types = [\n", + " AttackType.THRESHOLD_ATTACK,\n", + " AttackType.LOGISTIC_REGRESSION])\n", "\n", - "print(attack_result_summary)" + "# Plot the ROC curve of the best classifier\n", + "fig = plotting.plot_roc_curve(\n", + " attacks_result.get_result_with_max_auc().roc_curve)\n", + "\n", + "# Print a user-friendly summary of the attacks\n", + "print(attacks_result.summary(by_slices = True))" ] }, { @@ -326,7 +344,12 @@ "id": "E9zwsPGFujVq" }, "source": [ - "This is the end of the codelab! Feel free to change the parameters to see how the privacy risks change." + "This is the end of the codelab!\n", + "Feel free to change the parameters to see how the privacy risks change.\n", + "\n", + "You can try playing with:\n", + "* the number of training epochs\n", + "* different attack_types" ] } ],