From 20d0b884bab21010318dd677cd0bf0768c371849 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Tue, 29 Sep 2020 12:15:42 -0700 Subject: [PATCH] Move to new API. PiperOrigin-RevId: 334434385 --- .../membership_inference_attack/README.md | 371 +++++++++--------- .../membership_inference_attack/codelab.ipynb | 4 +- .../membership_inference_attack/example.py | 4 +- .../keras_evaluation.py | 2 +- ..._new.py => membership_inference_attack.py} | 0 ...py => membership_inference_attack_test.py} | 2 +- .../tf_estimator_evaluation.py | 2 +- 7 files changed, 186 insertions(+), 199 deletions(-) rename tensorflow_privacy/privacy/membership_inference_attack/{membership_inference_attack_new.py => membership_inference_attack.py} (100%) rename tensorflow_privacy/privacy/membership_inference_attack/{membership_inference_attack_new_test.py => membership_inference_attack_test.py} (98%) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/README.md b/tensorflow_privacy/privacy/membership_inference_attack/README.md index 9752194..27c95e7 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/README.md +++ b/tensorflow_privacy/privacy/membership_inference_attack/README.md @@ -18,17 +18,6 @@ the model are used (e.g., losses, logits, predictions). Neither model internals ## How to use - -### API revamp note -We're **revamping our attacks API to make it more structured, modular and -extensible**. The docs below refers to the legacy experimental API and will be -updated soon. Stay tuned! - -For a quick preview, you can take a look at `data_structures.py` and `membership_inference_attack_new.py`. - -For now, here's a reference to the legacy API. - - ### Codelab The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb). @@ -39,213 +28,213 @@ For a more detailed overview of the library, please check the sections below. ### Basic usage -On the highest level, there is the `run_all_attacks_and_create_summary` -function, which chooses sane default options to run a host of (fairly simple) -attacks behind the scenes (depending on which data is fed in), computes the most -important measures and returns a summary of the results as a string of english -language (as well as optionally a python dictionary containing all results with -descriptive keys). +The simplest possible usage is + +```python +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData + +# Suppose we have the labels as integers starting from 0 +# labels_train shape: (n_train, ) +# labels_test shape: (n_test, ) + +# Evaluate your model on training and test examples to get +# loss_train shape: (n_train, ) +# loss_test shape: (n_test, ) + +attacks_result = mia.run_attacks( + AttackInputData( + loss_train = loss_train, + loss_test = loss_test, + labels_train = labels_train, + labels_test = labels_test)) +``` + +This example calls `run_attacks` with the default options to run a host of +(fairly simple) attacks behind the scenes (depending on which data is fed in), +and computes the most important measures. > NOTE: The train and test sets are balanced internally, i.e., an equal number > of in-training and out-of-training examples is chosen for the attacks > (whichever has fewer examples). 
These are subsampled uniformly at random
> without replacement from the larger of the two.

-The simplest possible usage is
+Then, we can view the attack results:

 ```python
-from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
-
-# Evaluate your model on training and test examples to get
-# loss_train shape: (n_train, )
-# loss_test shape: (n_test, )
-
-summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, return_dict=True)
-print(results)
-# -> {'auc': 0.7044,
-#     'best_attacker_auc': 'all_thresh_loss_auc',
-#     'advantage': 0.3116,
-#     'best_attacker_auc': 'all_thresh_loss_advantage'}
+print(attacks_result.summary())
+# Example output:
+# -> Best-performing attacks over all slices
+#      THRESHOLD_ATTACK achieved an AUC of 0.60 on slice Entire dataset
+#      THRESHOLD_ATTACK achieved an advantage of 0.22 on slice Entire dataset
 ```

-> NOTE: The keyword argument `return_dict` specified whether in addition to the
-> `summary` the function also returns a python dictionary with the results.
-
-If the model is a classifier, the logits or output probabilities (i.e., the
-softmax of logits) can also be provided to perform stronger attacks.
-
-> NOTE: The `logits_train` and `logits_test` arguments can also be filled with
-> output probabilities per class ("posteriors").
-
-```python
-# logits_train shape: (n_train, n_classes)
-# logits_test shape: (n_test, n_classes)
-
-summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
-                                                          logits_test, return_dict=True)
-print(results)
-# -> {'auc': 0.5382,
-#     'best_attacker_auc': 'all_lr_logits_loss_test_auc',
-#     'advantage': 0.0572,
-#     'best_attacker_auc': 'all_mlp_logits_loss_test_advantage'}
-```
-
-The `summary` will be a string in natural language describing the results in
-more detail, e.g.,
-
-```
-========== AUC ==========
-The best attack (all_lr_logits_loss_test_auc) achieved an auc of 0.5382.
-
-========== ADVANTAGE ==========
-The best attack (all_mlp_logits_loss_test_advantage) achieved an advantage of 0.0572.
-```
-
-Similarly, we can run attacks on the logits alone, without access to losses:
-
-```python
-summary, results = mia.run_all_attacks_and_create_summary(logits_train=logits_train,
-                                                          logits_test=logits_test,
-                                                          return_dict=True)
-print(results)
-# -> {'auc': 0.9278,
-#     'best_attacker_auc': 'all_rf_logits_test_auc',
-#     'advantage': 0.6991,
-#     'best_attacker_auc': 'all_rf_logits_test_advantage'}
-```

 ### Advanced usage

-Finally, if we also have access to the true labels of the training and test
-inputs, we can run the attacks for each class separately. If labels *and* logits
-are provided, attacks only for misclassified (typically uncertain) examples are
-also performed.
+Sometimes we have more information about the data, such as the logits and the
+labels, and we may want finer-grained control of the attack, for example using
+more complicated classifiers instead of the simple threshold attack, or
+looking at the attack results for each class of examples.
+In those cases, we can provide more information to `run_attacks`.
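+
+For this, we need a few more imports: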
```python -summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train, - logits_test, labels_train, labels_test, - return_dict=True) +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType ``` -Here, we now also get as output the class with the maximal vulnerability -according to our metrics (`max_vuln_class_auc`, `max_vuln_class_advantage`) -together with the corresponding values (`class__auc`, -`class__advantage`). The same values exist in the `results` dictionary -for `min` instead of `max`, i.e., the least vulnerable classes. Moreover, the -gap between the maximum and minimum values (`max_class_gap_auc`, -`max_class_gap_advantage`) is also provided. Similarly, the vulnerability -metrics when the attacks are restricted to the misclassified examples -(`misclassified_auc`, `misclassified_advantage`) are also shown. Finally, the -results also contain the number of examples in each of these groups, i.e., -within each of the reported classes as well as the number of misclassified -examples. The final `results` dictionary is of the form - -``` -{'auc': 0.9181, - 'best_attacker_auc': 'all_rf_logits_loss_test_auc', - 'advantage': 0.6915, - 'best_attacker_advantage': 'all_rf_logits_loss_test_advantage', - 'max_class_gap_auc': 0.254, - 'class_5_auc': 0.9512, - 'class_3_auc': 0.6972, - 'max_vuln_class_auc': 5, - 'min_vuln_class_auc': 3, - 'max_class_gap_advantage': 0.5073, - 'class_0_advantage': 0.8086, - 'class_3_advantage': 0.3013, - 'max_vuln_class_advantage': 0, - 'min_vuln_class_advantage': 3, - 'misclassified_n_examples': 4513.0, - 'class_0_n_examples': 899.0, - 'class_1_n_examples': 900.0, - 'class_2_n_examples': 931.0, - 'class_3_n_examples': 893.0, - 'class_4_n_examples': 960.0, - 'class_5_n_examples': 884.0} -``` - -### Setting the precision of the reported results - -Finally, `run_all_attacks_and_create_summary` takes one extra keyword argument -`decimals`, expecting a positive integer. This sets the precision of all result -values as the number of decimals to report. It defaults to 4. - -## Run all attacks and get all outputs - -With the `run_all_attacks` function, one can run all implemented attacks on all -possible subsets of the data (all examples, split by class, split by confidence -deciles, misclassified only). This function returns a relatively large -dictionary with all attack results. This is the most detailed information one -could get about these types of membership inference attacks (besides plots for -each attack, see next section.) This is useful if you know exactly what you're -looking for. - -> NOTE: The `run_all_attacks` function takes as an additional argument which -> trained attackers to run. In the `run_all_attacks_and_create_summary`, only -> logistic regression (`lr`) is trained as a binary classifier to distinguish -> in-training form out-of-training examples. In addition, with the -> `attack_classifiers` argument, one can add multi-layered perceptrons (`mlp`), -> random forests (`rf`), and k-nearest-neighbors (`knn`) or any subset thereof -> for the attack models. Note that these classifiers may not converge. 
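+
+We now walk through the three ingredients of a customized attack: the attack
+input, the slicing specification, and the attack types.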
+First, as before, we specify the input for the attack as an
+`AttackInputData` object:

 ```python
-mia.run_all_attacks(loss_train, loss_test, logits_train, logits_test,
-                    labels_train, labels_test,
-                    attack_classifiers=('lr', 'mlp', 'rf', 'knn'))
+# Evaluate your model on training and test examples to get
+# logits_train shape: (n_train, n_classes)
+# logits_test shape: (n_test, n_classes)
+# loss_train shape: (n_train, )
+# loss_test shape: (n_test, )
+
+attack_input = AttackInputData(
+  logits_train = logits_train,
+  logits_test = logits_test,
+  loss_train = loss_train,
+  loss_test = loss_test,
+  labels_train = labels_train,
+  labels_test = labels_test)
 ```

-Again, `run_all_attacks` can be called on all combinations of losses, logits,
-probabilities, and labels as long as at least either losses or logits
-(probabilities) are provided.
+Instead of `logits`, you can also specify `probs_train` and `probs_test` as
+the predicted probability vectors of each example.

-## Fine grained control over individual attacks and plots
-
-The `run_attack` function exposes the underlying workhorse of the
-`run_all_attacks` and `run_all_attacks_and_create_summary` functionality. It
-allows for fine grained control of which attacks to run individually.
-
-As another key feature, this function also exposes options to store receiver
-operator curve plots for the different attacks as well as histograms of losses
-or the maximum logits/probabilities. Finally, we can also store all results
-(including the values to reproduce the plots) to colossus.
-
-All options are explained in detail in the doc string of the `run_attack`
-function.
-
-For example, to run a simple threshold attack on the losses only and store plots
-and result data to colossus, run
+Then, we specify some details of the attack.
+The first part is how to slice the data: for example, we may want to evaluate
+the attack on the whole dataset, on each class, on percentiles, or on the
+correctness of the model's classification.
+This is specified by a `SlicingSpec` object:

 ```python
-data_path = '/Users/user/Desktop/test/'  # set to None to not store data
-figure_path = '/Users/user/Desktop/test/'  # set to None to not store figures
-
-mia.attack(loss_train=loss_train,
-           loss_test=loss_test,
-           metric='auc',
-           output_directory=data_path,
-           figure_directory=figure_path)
+slicing_spec = SlicingSpec(
+    entire_dataset = True,
+    by_class = True,
+    by_percentiles = False,
+    by_classification_correctness = True)
 ```

-Among other things, the `run_attack` functionality allows to control:
+The second part specifies the classifiers for the attacker to use.
+Currently, our API supports five classifiers:
+`AttackType.THRESHOLD_ATTACK` for the simple threshold attack, as well as
+`AttackType.LOGISTIC_REGRESSION`,
+`AttackType.MULTI_LAYERED_PERCEPTRON`,
+`AttackType.RANDOM_FOREST`, and
+`AttackType.K_NEAREST_NEIGHBORS`,
+which use the corresponding machine learning models.
+For some models, different classifiers can yield quite different results.
+We can put multiple classifiers in a list:
+
+```python
+attack_types = [
+    AttackType.THRESHOLD_ATTACK,
+    AttackType.LOGISTIC_REGRESSION
+]
+```
+
+Now, we can call the `run_attacks` method with all the specifications:
+
+```python
+attacks_result = mia.run_attacks(attack_input=attack_input,
+                                 slicing_spec=slicing_spec,
+                                 attack_types=attack_types)
+```
+
+This returns an object of type `AttackResults`.
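+
+`AttackResults` holds the results of all the attacks that were run, as a list
+in its `single_attack_results` field. As a quick check, we can iterate over
+this list directly; this is a minimal sketch that reuses the fields
+demonstrated in `example.py` (`attack_type` and `roc_curve`):
+
+```python
+for result in attacks_result.single_attack_results:
+  # Each entry corresponds to one attack type run on one slice of the data.
+  print("%s: AUC = %.2f, attacker advantage = %.2f" %
+        (result.attack_type,
+         result.roc_curve.get_auc(),
+         result.roc_curve.get_attacker_advantage()))
+```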
+We can, for example, use the following code to see the attack results per
+slice, since we have requested attacks by class and by the model's
+classification correctness:
+
+```python
+print(attacks_result.summary(by_slices = True))
+# Example output:
+# -> Best-performing attacks over all slices
+#      THRESHOLD_ATTACK achieved an AUC of 0.75 on slice CORRECTLY_CLASSIFIED=False
+#      THRESHOLD_ATTACK achieved an advantage of 0.38 on slice CORRECTLY_CLASSIFIED=False
+#
+#    Best-performing attacks over slice: "Entire dataset"
+#      LOGISTIC_REGRESSION achieved an AUC of 0.61
+#      THRESHOLD_ATTACK achieved an advantage of 0.22
+#
+#    Best-performing attacks over slice: "CLASS=0"
+#      LOGISTIC_REGRESSION achieved an AUC of 0.62
+#      LOGISTIC_REGRESSION achieved an advantage of 0.24
+#
+#    Best-performing attacks over slice: "CLASS=1"
+#      LOGISTIC_REGRESSION achieved an AUC of 0.61
+#      LOGISTIC_REGRESSION achieved an advantage of 0.19
+#
+#    ...
+#
+#    Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"
+#      LOGISTIC_REGRESSION achieved an AUC of 0.53
+#      THRESHOLD_ATTACK achieved an advantage of 0.05
+#
+#    Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=False"
+#      THRESHOLD_ATTACK achieved an AUC of 0.75
+#      THRESHOLD_ATTACK achieved an advantage of 0.38
+```
+
+
+### Viewing and plotting the attack results
+
+We have seen an example of using `summary()` to view the attack results as
+text. We also provide other ways to inspect the attack results.
+
+To get the attacks that achieve the maximum AUC or the maximum attacker
+advantage, we can do
+
+```python
+max_auc_attacker = attacks_result.get_result_with_max_auc()
+max_advantage_attacker = attacks_result.get_result_with_max_attacker_advantage()
+```
+
+Then, for an individual attack such as `max_auc_attacker`, we can check its
+type, attacker advantage, and AUC:
+
+```python
+print("Attack type with max AUC: %s, AUC of %.2f, Attacker advantage of %.2f" %
+      (max_auc_attacker.attack_type,
+       max_auc_attacker.roc_curve.get_auc(),
+       max_auc_attacker.roc_curve.get_attacker_advantage()))
+# Example output:
+# -> Attack type with max AUC: THRESHOLD_ATTACK, AUC of 0.75, Attacker advantage of 0.38
+```
+
+We can also plot its ROC curve:
+
+```python
+import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
+
+figure = plotting.plot_roc_curve(max_auc_attacker.roc_curve)
+```
+
+Additionally, we provide functionality to convert the attack results into a
+Pandas DataFrame:
+
+```python
+import pandas as pd
+
+pd.set_option("display.max_rows", 8, "display.max_columns", None)
+print(attacks_result.calculate_pd_dataframe())
+# Example output:
+#           slice feature slice value attack type  Attacker advantage       AUC
+# 0        entire_dataset              threshold             0.216440  0.600630
+# 1        entire_dataset                     lr             0.212073  0.612989
+# 2                 class           0  threshold             0.226000  0.611669
+# 3                 class           0         lr             0.239452  0.624076
+# ..                  ...         ...        ...                  ...       ...
+# 22  correctly_classfied        True  threshold             0.054907  0.471290
+# 23  correctly_classfied        True         lr             0.046986  0.525194
+# 24  correctly_classfied       False  threshold             0.379465  0.748138
+# 25  correctly_classfied       False         lr             0.370713  0.737148
+```

-* which metrics to output (`metric` argument, using `auc` or `advantage` or
-  both)
-* which classifiers (logistic regression, multi-layered perceptrons, random
-  forests) to train as attackers beyond the simple threshold attacks
-  (`attack_classifiers`)
-* to only attack a specific (set of) classes (`by_class`)
-* to only attack specific percentiles of the data (`by_percentile`).
- Percentiles here are computed by looking at the largest logit or probability - for each example, i.e., how confident the model is in its prediction. -* to only attack the misclassified examples (`only_misclassified`) -* not to balance examples between the in-training and out-of-training examples - using `balance`. By default an equal number of examples from train and test - are selected for the attacks (whichever is smaller). -* the test set size for trained attacks (`test_size`). When a classifier is - trained to distinguish between train and test examples, a train-test split - for that classifier itself is required. -* for the train-test split as well as for the class balancing randomness is - used with a seed specified by `random_state`. ## Contact / Feedback diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb index 1bc1639..e9f4cb6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb +++ b/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb @@ -133,7 +133,7 @@ "source": [ "!pip3 install git+https://github.com/tensorflow/privacy\n", "\n", - "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia" + "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia" ] }, { @@ -163,7 +163,6 @@ "num_classes = 10\n", "num_conv = 3\n", "activation = 'relu'\n", - "optimizer = 'adam'\n", "lr = 0.02\n", "momentum = 0.9\n", "batch_size = 250\n", @@ -218,7 +217,6 @@ "model = small_cnn(\n", " input_shape, num_classes, num_conv=num_conv, activation=activation)\n", "\n", - "print('Optimizer ', optimizer)\n", "print('learning rate %f', lr)\n", "\n", "optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=momentum)\n", diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 21c8884..ddbc0c6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -27,7 +27,7 @@ from sklearn import metrics from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.utils import to_categorical -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults @@ -185,7 +185,7 @@ for attack_result in attack_results.single_attack_results: print("Attacker advantage: %.2f\n" % attack_result.roc_curve.get_attacker_advantage()) -max_auc_attacker = attack_results.get_result_with_max_attacker_advantage() +max_auc_attacker = attack_results.get_result_with_max_auc() print("Attack type with max AUC: %s, AUC of %.2f" % (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py index 9bbab25..5f504cb 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py @@ -21,7 +21,7 
@@ from absl import logging import tensorflow.compat.v1 as tf -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py similarity index 100% rename from tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py rename to tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py similarity index 98% rename from tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py rename to tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index 5e695b9..a56aa3a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -16,7 +16,7 @@ """Tests for tensorflow_privacy.privacy.membership_inference_attack.utils.""" from absl.testing import absltest import numpy as np -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py index 0f5bae8..abf727f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py @@ -20,7 +20,7 @@ from typing import Iterable from absl import logging import numpy as np import tensorflow.compat.v1 as tf -from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics