Modify Colab to use the new membership inference API.

PiperOrigin-RevId: 327805944
A. Unique TensorFlower 2020-08-21 07:21:12 -07:00
parent d23772e163
commit 7a77d5d92c


@ -7,14 +7,14 @@
"id": "1eiwVljWpzM7" "id": "1eiwVljWpzM7"
}, },
"source": [ "source": [
"##### Copyright 2020 The TensorFlow Authors.\n" "Copyright 2020 The TensorFlow Authors.\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "both",
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "4rmwPgXeptiS" "id": "4rmwPgXeptiS"
@ -87,6 +87,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "Lr1pwHcbralz" "id": "Lr1pwHcbralz"
@ -116,7 +117,7 @@
"id": "ucw81ar6ru-6" "id": "ucw81ar6ru-6"
}, },
"source": [ "source": [
"Install TensorFlow Privacy." "### Install TensorFlow Privacy."
] ]
}, },
{ {
@ -132,7 +133,7 @@
"source": [ "source": [
"!pip3 install git+https://github.com/tensorflow/privacy\n", "!pip3 install git+https://github.com/tensorflow/privacy\n",
"\n", "\n",
"from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia" "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia"
] ]
}, },
{ {
@ -142,19 +143,22 @@
"id": "pBbcG86th_sW" "id": "pBbcG86th_sW"
}, },
"source": [ "source": [
"## Train a simple model on CIFAR10 with Keras." "## Train a model"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form",
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "vCyOWyyhXLib" "id": "vCyOWyyhXLib"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@markdown Train a simple model on CIFAR10 with Keras.\n",
"\n",
"dataset = 'cifar10'\n", "dataset = 'cifar10'\n",
"num_classes = 10\n", "num_classes = 10\n",
"num_conv = 3\n", "num_conv = 3\n",
@ -232,19 +236,29 @@
"print('Finished training.')" "print('Finished training.')"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ee-zjGGGV1DC"
},
"source": [
"## Calculate logits, probabilities and loss values for training and test sets.\n",
"\n",
"We will use these values later in the membership inference attack to separate training and test samples."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"cellView": "both",
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "um9r0tSiPx4u" "id": "um9r0tSiPx4u"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Calculate logits, probabilities and loss values for training and test sets.\n",
"#@markdown We will use these values later in the membership inference attack to\n",
"#@markdown separate training and test samples.\n",
"print('Predict on train...')\n", "print('Predict on train...')\n",
"logits_train = model.predict(x_train, batch_size=batch_size)\n", "logits_train = model.predict(x_train, batch_size=batch_size)\n",
"print('Predict on test...')\n", "print('Predict on test...')\n",
@ -269,7 +283,11 @@
"id": "QETxVOHLiHP4" "id": "QETxVOHLiHP4"
}, },
"source": [ "source": [
"## Run membership inference attacks." "## Run membership inference attacks.\n",
"\n",
"We will now execute a membership inference attack against the previously trained CIFAR10 model. This will generate a number of scores, most notably, attacker advantage and AUC for the membership inference classifier.\n",
"\n",
"An AUC of close to 0.5 means that the attack wasn't able to identify training samples, which means that the model doesn't have privacy issues according to this test. Higher values, on the contrary, indicate potential privacy issues."
] ]
}, },
{ {
@ -282,41 +300,41 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@markdown We will now execute membership inference attack against the\n", "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n",
"#@markdown previously trained CIFAR10 model. This will generate a number of\n", "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n",
"#@markdown scores (most notably, attacker advantage and AUC for the membership\n", "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n",
"#@markdown inference classifier). An AUC of close to 0.5 means that the attack\n", "\n",
"#@markdown isn't able to identify training samples, which means that the model\n", "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n",
"#@markdown doesn't have privacy issues according to this test. Higher values,\n",
"#@markdown on the contrary, indicate potential privacy issues.\n",
"\n", "\n",
"labels_train = np.argmax(y_train, axis=1)\n", "labels_train = np.argmax(y_train, axis=1)\n",
"labels_test = np.argmax(y_test, axis=1)\n", "labels_test = np.argmax(y_test, axis=1)\n",
"\n", "\n",
"results_without_classifiers = mia.run_all_attacks(\n", "input = AttackInputData(\n",
" loss_train,\n", " logits_train = logits_train,\n",
" loss_test,\n", " logits_test = logits_test,\n",
" logits_train,\n", " loss_train = loss_train,\n",
" logits_test,\n", " loss_test = loss_test,\n",
" labels_train,\n", " labels_train = labels_train,\n",
" labels_test,\n", " labels_test = labels_test\n",
" attack_classifiers=[],\n",
")\n", ")\n",
"print(results_without_classifiers)\n",
"\n", "\n",
"# Note: This will take a while, since it also trains ML models to\n", "# Run several attacks for different data slices\n",
"# separate train/test examples. If it's taking too looking, use\n", "attacks_result = mia.run_attacks(input,\n",
"# the `run_all_attacks` function instead.\n", " SlicingSpec(\n",
"attack_result_summary = mia.run_all_attacks_and_create_summary(\n", " entire_dataset = True,\n",
" loss_train,\n", " by_class = True,\n",
" loss_test,\n", " by_classification_correctness = True\n",
" logits_train,\n", " ),\n",
" logits_test,\n", " attack_types = [\n",
" labels_train,\n", " AttackType.THRESHOLD_ATTACK,\n",
" labels_test,\n", " AttackType.LOGISTIC_REGRESSION])\n",
")[0]\n",
"\n", "\n",
"print(attack_result_summary)" "# Plot the ROC curve of the best classifier\n",
"fig = plotting.plot_roc_curve(\n",
" attacks_result.get_result_with_max_auc().roc_curve)\n",
"\n",
"# Print a user-friendly summary of the attacks\n",
"print(attacks_result.summary(by_slices = True))"
] ]
}, },
{ {
@ -326,7 +344,12 @@
"id": "E9zwsPGFujVq" "id": "E9zwsPGFujVq"
}, },
"source": [ "source": [
"This is the end of the codelab! Feel free to change the parameters to see how the privacy risks change." "This is the end of the codelab!\n",
"Feel free to change the parameters to see how the privacy risks change.\n",
"\n",
"You can try playing with:\n",
"* the number of training epochs\n",
"* different attack_types"
] ]
} }
], ],
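
Note on the scores mentioned in the "Run membership inference attacks" cell: the sketch below illustrates what AUC and attacker advantage measure for a simple loss-threshold attack. It is not the TensorFlow Privacy implementation. The function name `threshold_attack_metrics` and the synthetic loss values are hypothetical, and it assumes that a lower loss suggests a training sample and that attacker advantage is the largest gap between true-positive and false-positive rate along the ROC curve.

```python
import numpy as np


def threshold_attack_metrics(loss_train, loss_test):
    """Return (auc, advantage) for a loss-threshold membership attack.

    Hypothetical helper for illustration only; not part of TensorFlow Privacy.
    """
    loss_train = np.asarray(loss_train, dtype=float)
    loss_test = np.asarray(loss_test, dtype=float)

    # Assumption: lower loss suggests membership, so score = negated loss.
    member_scores = -loss_train
    nonmember_scores = -loss_test

    # Try every distinct score as a threshold, from strictest to loosest.
    thresholds = np.unique(np.concatenate([member_scores, nonmember_scores]))[::-1]
    tpr = np.array([0.0] + [(member_scores >= t).mean() for t in thresholds])
    fpr = np.array([0.0] + [(nonmember_scores >= t).mean() for t in thresholds])

    # Area under the ROC curve by the trapezoidal rule.
    auc = float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))
    # Assumed definition of attacker advantage: max |TPR - FPR| over thresholds.
    advantage = float(np.max(np.abs(tpr - fpr)))
    return auc, advantage


if __name__ == "__main__":
    # Synthetic losses: training samples get slightly lower loss on average.
    rng = np.random.default_rng(0)
    synthetic_train = rng.exponential(scale=0.5, size=500)
    synthetic_test = rng.exponential(scale=1.0, size=500)
    auc, advantage = threshold_attack_metrics(synthetic_train, synthetic_test)
    print(f"AUC: {auc:.3f}, attacker advantage: {advantage:.3f}")
```

When the train and test loss distributions coincide, the ROC curve follows the diagonal, so the AUC is 0.5 and the advantage is 0. The further the two distributions separate, the closer both values move to 1.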