Implement the membership inference attack using a keras-callback.

PiperOrigin-RevId: 389741018
This commit is contained in:
Mark Daoust 2021-08-09 15:38:17 -07:00 committed by A. Unique TensorFlower
parent f3af24b00e
commit b19e0b197a

View file

@ -95,7 +95,6 @@
"from sklearn import metrics\n", "from sklearn import metrics\n",
"\n", "\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"tf.compat.v1.disable_v2_behavior()\n",
"\n", "\n",
"import tensorflow_datasets as tfds\n", "import tensorflow_datasets as tfds\n",
"\n", "\n",
@ -137,14 +136,25 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n", "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec\n",
"from tensorflow_privacy.privacy.membership_inference_attack import privacy_report" "from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VpOdtnbPbPXE"
},
"outputs": [],
"source": [
"import tensorflow_privacy"
] ]
}, },
{ {
@ -171,13 +181,13 @@
"dataset = 'cifar10'\n", "dataset = 'cifar10'\n",
"num_classes = 10\n", "num_classes = 10\n",
"activation = 'relu'\n", "activation = 'relu'\n",
"lr = 0.02\n", "num_conv = 3\n",
"momentum = 0.9\n", "\n",
"batch_size = 250\n", "batch_size=50\n",
"epochs_per_report = 5\n", "epochs_per_report = 2\n",
"num_reports = 10\n", "total_epochs = 50\n",
"# Privacy risks are especially visible with lots of epochs.\n", "\n",
"total_epochs = epochs_per_report*num_reports " "lr = 0.001"
] ]
}, },
{ {
@ -197,7 +207,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Load the data\n", "#@title\n",
"print('Loading the dataset.')\n", "print('Loading the dataset.')\n",
"train_ds = tfds.as_numpy(\n", "train_ds = tfds.as_numpy(\n",
" tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))\n", " tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))\n",
@ -212,7 +222,9 @@
"y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)\n", "y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)\n",
"y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)\n", "y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)\n",
"\n", "\n",
"input_shape = x_train.shape[1:]" "input_shape = x_train.shape[1:]\n",
"\n",
"assert x_train.shape[0] % batch_size == 0, \"The tensorflow_privacy optimizer doesn't handle partial batches\""
] ]
}, },
{ {
@ -232,7 +244,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Define the models\n", "#@title\n",
"def small_cnn(input_shape: Tuple[int],\n", "def small_cnn(input_shape: Tuple[int],\n",
" num_classes: int,\n", " num_classes: int,\n",
" num_conv: int,\n", " num_conv: int,\n",
@ -259,7 +271,13 @@
" model.add(tf.keras.layers.Flatten())\n", " model.add(tf.keras.layers.Flatten())\n",
" model.add(tf.keras.layers.Dense(64, activation=activation))\n", " model.add(tf.keras.layers.Dense(64, activation=activation))\n",
" model.add(tf.keras.layers.Dense(num_classes))\n", " model.add(tf.keras.layers.Dense(num_classes))\n",
" return model\n" " \n",
" model.compile(\n",
" loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=lr),\n",
" metrics=['accuracy'])\n",
"\n",
" return model"
] ]
}, },
{ {
@ -268,7 +286,9 @@
"id": "hs0Smn24Dty-" "id": "hs0Smn24Dty-"
}, },
"source": [ "source": [
"Build two-layer and a three-layer CNN models using that function. Again there's nothing provacy specific about this code. It uses standard models, layers, losses, and optimizers." "Build two three-layer CNN models using that function.\n",
"\n",
"Configure the first to use a basic SGD optimizer, and the second to use a differentially private optimizer (`tf_privacy.DPKerasAdamOptimizer`), so you can compare the results."
] ]
}, },
{ {
@ -279,16 +299,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=momentum)\n", "model_2layers = small_cnn(\n",
"loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
"\n",
"three_layer_model = small_cnn(\n",
" input_shape, num_classes, num_conv=3, activation=activation)\n",
"three_layer_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])\n",
"\n",
"two_layer_model = small_cnn(\n",
" input_shape, num_classes, num_conv=2, activation=activation)\n", " input_shape, num_classes, num_conv=2, activation=activation)\n",
"two_layer_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])" "model_3layers = small_cnn(\n",
" input_shape, num_classes, num_conv=3, activation=activation)"
] ]
}, },
{ {
@ -318,42 +332,42 @@
" def __init__(self, epochs_per_report, model_name):\n", " def __init__(self, epochs_per_report, model_name):\n",
" self.epochs_per_report = epochs_per_report\n", " self.epochs_per_report = epochs_per_report\n",
" self.model_name = model_name\n", " self.model_name = model_name\n",
" self.epochs = []\n",
" self.attack_results = []\n", " self.attack_results = []\n",
"\n", "\n",
" def on_epoch_end(self, n, logs=None):\n", " def on_epoch_end(self, epoch, logs=None):\n",
" epoch = n + 1\n", " epoch = epoch+1\n",
"\n",
" if epoch % self.epochs_per_report != 0:\n", " if epoch % self.epochs_per_report != 0:\n",
" return\n", " return\n",
"\n", "\n",
" print(f\"\\nRunning privacy report for epoch: {epoch}\")\n", " print(f'\\nRunning privacy report for epoch: {epoch}\\n')\n",
" self.epochs.append(epoch)\n",
"\n", "\n",
" logits_train = model.predict(x_train, batch_size=batch_size)\n", " logits_train = self.model.predict(x_train, batch_size=batch_size)\n",
" logits_test = model.predict(x_test, batch_size=batch_size)\n", " logits_test = self.model.predict(x_test, batch_size=batch_size)\n",
"\n", "\n",
" prob_train = special.softmax(logits_train, axis=1)\n", " prob_train = special.softmax(logits_train, axis=1)\n",
" prob_test = special.softmax(logits_test, axis=1)\n", " prob_test = special.softmax(logits_test, axis=1)\n",
"\n", "\n",
" # Add metadata to generate a privacy report.\n", " # Add metadata to generate a privacy report.\n",
" privacy_report_metadata = PrivacyReportMetadata(\n", " privacy_report_metadata = PrivacyReportMetadata(\n",
" accuracy_train=metrics.accuracy_score(y_train_indices,\n", " # Show the validation accuracy on the plot\n",
" np.argmax(prob_train, axis=1)),\n", " # It's what you send to train_accuracy that gets plotted.\n",
" accuracy_test=metrics.accuracy_score(y_test_indices,\n", " accuracy_train=logs['val_accuracy'], \n",
" np.argmax(prob_test, axis=1)),\n", " accuracy_test=logs['val_accuracy'],\n",
" epoch_num=epoch,\n", " epoch_num=epoch,\n",
" model_variant_label=self.model_name)\n", " model_variant_label=self.model_name)\n",
"\n", "\n",
" attack_results = mia.run_attacks(\n", " attack_results = mia.run_attacks(\n",
" AttackInputData(\n", " AttackInputData(\n",
" labels_train=np.asarray([x[0] for x in y_train_indices]),\n", " labels_train=y_train_indices[:, 0],\n",
" labels_test=np.asarray([x[0] for x in y_test_indices]),\n", " labels_test=y_test_indices[:, 0],\n",
" probs_train=prob_train,\n", " probs_train=prob_train,\n",
" probs_test=prob_test),\n", " probs_test=prob_test),\n",
" SlicingSpec(entire_dataset=True, by_class=True),\n", " SlicingSpec(entire_dataset=True, by_class=True),\n",
" attack_types=(AttackType.THRESHOLD_ATTACK,\n", " attack_types=(AttackType.THRESHOLD_ATTACK,\n",
" AttackType.LOGISTIC_REGRESSION),\n", " AttackType.LOGISTIC_REGRESSION),\n",
" privacy_report_metadata=privacy_report_metadata)\n", " privacy_report_metadata=privacy_report_metadata)\n",
"\n",
" self.attack_results.append(attack_results)\n" " self.attack_results.append(attack_results)\n"
] ]
}, },
@ -368,6 +382,17 @@
"The next code block trains the two models. The `all_reports` list is used to collect all the results from all the models' training runs. The individual reports are tagged with the `model_name`, so there's no confusion about which model generated which report." "The next code block trains the two models. The `all_reports` list is used to collect all the results from all the models' training runs. The individual reports are tagged with the `model_name`, so there's no confusion about which model generated which report."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o3U76c2Y4irD"
},
"outputs": [],
"source": [
"all_reports = []"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -376,19 +401,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"all_reports = []\n", "callback = PrivacyMetrics(epochs_per_report, \"2 Layers\")\n",
"\n", "history = model_2layers.fit(\n",
"models = {\n",
" 'two layer model': two_layer_model,\n",
" 'three layer model': three_layer_model,\n",
"}\n",
"\n",
"for model_name, model in models.items():\n",
" print(f\"\\n\\n\\nFitting {model_name}\\n\")\n",
" callback = PrivacyMetrics(epochs_per_report, \n",
" model_name)\n",
"\n",
" model.fit(\n",
" x_train,\n", " x_train,\n",
" y_train,\n", " y_train,\n",
" batch_size=batch_size,\n", " batch_size=batch_size,\n",
@ -397,7 +411,28 @@
" callbacks=[callback],\n", " callbacks=[callback],\n",
" shuffle=True)\n", " shuffle=True)\n",
"\n", "\n",
" all_reports.extend(callback.attack_results)\n" "all_reports.extend(callback.attack_results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "27qLElOR4y_i"
},
"outputs": [],
"source": [
"callback = PrivacyMetrics(epochs_per_report, \"3 Layers\")\n",
"history = model_3layers.fit(\n",
" x_train,\n",
" y_train,\n",
" batch_size=batch_size,\n",
" epochs=total_epochs,\n",
" validation_data=(x_test, y_test),\n",
" callbacks=[callback],\n",
" shuffle=True)\n",
"\n",
"all_reports.extend(callback.attack_results)"
] ]
}, },
{ {
@ -470,7 +505,10 @@
"source": [ "source": [
"privacy_metrics = (PrivacyMetric.AUC, PrivacyMetric.ATTACKER_ADVANTAGE)\n", "privacy_metrics = (PrivacyMetric.AUC, PrivacyMetric.ATTACKER_ADVANTAGE)\n",
"utility_privacy_plot = privacy_report.plot_privacy_vs_accuracy(\n", "utility_privacy_plot = privacy_report.plot_privacy_vs_accuracy(\n",
" results, privacy_metrics=privacy_metrics)" " results, privacy_metrics=privacy_metrics)\n",
"\n",
"for axis in utility_privacy_plot.axes:\n",
" axis.set_xlabel('Validation accuracy')"
] ]
}, },
{ {
@ -490,8 +528,7 @@
"id": "7u3BAg87v3qv" "id": "7u3BAg87v3qv"
}, },
"source": [ "source": [
"This is the end of the colab!\n", "This is the end of the tutorial. Feel free to analyze your own results."
"Feel free to analyze your own results."
] ]
} }
], ],
@ -500,6 +537,7 @@
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [],
"name": "privacy_report.ipynb", "name": "privacy_report.ipynb",
"provenance": [],
"toc_visible": true "toc_visible": true
}, },
"kernelspec": { "kernelspec": {