diff --git a/tensorflow_privacy/privacy/membership_inference_attack/README.md b/tensorflow_privacy/privacy/membership_inference_attack/README.md
new file mode 100644
index 0000000..1d81360
--- /dev/null
+++ b/tensorflow_privacy/privacy/membership_inference_attack/README.md
@@ -0,0 +1,238 @@
+# Membership inference attack functionality
+
+The goal is to provide empirical tests of "how much information a machine
+learning model has remembered about its training data". To this end, only the
+outputs of the model are used (e.g., losses, logits, predictions). From those
+alone, the attacks try to infer whether the corresponding inputs were part of
+the training set.
+
+> NOTE: Only the loss values are needed, for some examples used during training
+> and for some examples that have not been used during training (e.g., examples
+> from the test set). No access to actual input data is needed. In the case of
+> classification models, one can additionally (or instead of losses) provide
+> logits or output probabilities for stronger attacks.
+
+The vulnerability of a model is measured via the area under the ROC-curve
+(`auc`) or via max |fpr - tpr| (`advantage`) of the attack classifier. These
+measures are very closely related.
+
+## Highest level -- get attack summary
+
+### Basic usage
+
+On the highest level, there is the `run_all_attacks_and_create_summary`
+function, which chooses sane default options to run a host of (fairly simple)
+attacks behind the scenes (depending on which data is fed in), computes the
+most important measures, and returns a summary of the results as a string of
+English text (as well as, optionally, a Python dictionary containing all
+results with descriptive keys).
+
+> NOTE: The train and test sets are balanced internally, i.e., an equal number
+> of in-training and out-of-training examples is chosen for the attacks
+> (whichever has fewer examples). These are subsampled uniformly at random
+> without replacement from the larger of the two.
+
+The simplest possible usage is
+
+```python
+from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
+
+# Evaluate your model on training and test examples to get
+# loss_train shape: (n_train, )
+# loss_test shape: (n_test, )
+
+summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, return_dict=True)
+print(results)
+# -> {'auc': 0.7044,
+#     'best_attacker_auc': 'all_thresh_loss_auc',
+#     'advantage': 0.3116,
+#     'best_attacker_advantage': 'all_thresh_loss_advantage'}
+```
+
+> NOTE: The keyword argument `return_dict` specifies whether, in addition to
+> the `summary`, the function also returns a Python dictionary with the
+> results.
+
+If the model is a classifier, the logits or output probabilities (i.e., the
+softmax of logits) can also be provided to perform stronger attacks.
+
+> NOTE: The `logits_train` and `logits_test` arguments can also be filled with
+> output probabilities per class ("posteriors").
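+
+For example, per-example losses and logits (or output probabilities) might be
+obtained from a trained classifier roughly as in the sketch below. This is only
+an illustration: it assumes a `tf.keras` model called `model` with softmax
+outputs and the original `(x_train, y_train)` / `(x_test, y_test)` arrays, none
+of which are part of this module.
+
+```python
+import tensorflow as tf
+
+def get_losses_and_probs(model, x, y):
+  """Returns per-example losses and per-class probabilities (assumed setup)."""
+  probs = model.predict(x)  # shape (n, n_classes), softmax outputs
+  # Per-example cross-entropy loss, shape (n,).
+  loss = tf.keras.losses.sparse_categorical_crossentropy(y, probs).numpy()
+  return loss, probs
+
+loss_train, logits_train = get_losses_and_probs(model, x_train, y_train)
+loss_test, logits_test = get_losses_and_probs(model, x_test, y_test)
+```
+
+These arrays can then be passed to the attack functions directly: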
+
+```python
+# logits_train shape: (n_train, n_classes)
+# logits_test shape: (n_test, n_classes)
+
+summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
+                                                           logits_test, return_dict=True)
+print(results)
+# -> {'auc': 0.5382,
+#     'best_attacker_auc': 'all_lr_logits_loss_test_auc',
+#     'advantage': 0.0572,
+#     'best_attacker_advantage': 'all_mlp_logits_loss_test_advantage'}
+```
+
+The `summary` will be a string in natural language describing the results in
+more detail, e.g.,
+
+```
+========== AUC ==========
+The best attack (all_lr_logits_loss_test_auc) achieved an auc of 0.5382.
+
+========== ADVANTAGE ==========
+The best attack (all_mlp_logits_loss_test_advantage) achieved an advantage of 0.0572.
+```
+
+Similarly, we can run attacks on the logits alone, without access to losses:
+
+```python
+summary, results = mia.run_all_attacks_and_create_summary(logits_train=logits_train,
+                                                           logits_test=logits_test,
+                                                           return_dict=True)
+print(results)
+# -> {'auc': 0.9278,
+#     'best_attacker_auc': 'all_rf_logits_test_auc',
+#     'advantage': 0.6991,
+#     'best_attacker_advantage': 'all_rf_logits_test_advantage'}
+```
+
+### Advanced usage
+
+Finally, if we also have access to the true labels of the training and test
+inputs, we can run the attacks for each class separately. If labels *and*
+logits are provided, attacks restricted to misclassified (typically more
+uncertain) examples are also performed.
+
+```python
+summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
+                                                           logits_test, labels_train, labels_test,
+                                                           return_dict=True)
+```
+
+In this case we also get as output the most vulnerable class according to each
+metric (`max_vuln_class_auc`, `max_vuln_class_advantage`) together with the
+corresponding values (`class_<i>_auc`, `class_<i>_advantage`, where `<i>` is
+the class index). The same values exist in the `results` dictionary for `min`
+instead of `max`, i.e., for the least vulnerable classes. Moreover, the gap
+between the maximum and minimum values (`max_class_gap_auc`,
+`max_class_gap_advantage`) is also provided. Similarly, the vulnerability
+metrics when the attacks are restricted to the misclassified examples
+(`misclassified_auc`, `misclassified_advantage`) are also shown. Finally, the
+results also contain the number of examples in each of these groups, i.e.,
+within each of the reported classes as well as the number of misclassified
+examples. The final `results` dictionary is of the form
+
+```
+{'auc': 0.9181,
+ 'best_attacker_auc': 'all_rf_logits_loss_test_auc',
+ 'advantage': 0.6915,
+ 'best_attacker_advantage': 'all_rf_logits_loss_test_advantage',
+ 'max_class_gap_auc': 0.254,
+ 'class_5_auc': 0.9512,
+ 'class_3_auc': 0.6972,
+ 'max_vuln_class_auc': 5,
+ 'min_vuln_class_auc': 3,
+ 'max_class_gap_advantage': 0.5073,
+ 'class_0_advantage': 0.8086,
+ 'class_3_advantage': 0.3013,
+ 'max_vuln_class_advantage': 0,
+ 'min_vuln_class_advantage': 3,
+ 'misclassified_n_examples': 4513.0,
+ 'class_0_n_examples': 899.0,
+ 'class_1_n_examples': 900.0,
+ 'class_2_n_examples': 931.0,
+ 'class_3_n_examples': 893.0,
+ 'class_4_n_examples': 960.0,
+ 'class_5_n_examples': 884.0}
+```
+
+### Setting the precision of the reported results
+
+Finally, `run_all_attacks_and_create_summary` takes one extra keyword argument
+`decimals`, expecting a positive integer, which sets the number of decimals to
+which all reported values are rounded. It defaults to 4.
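+
+For example, to round all reported values to two decimals instead (reusing the
+inputs from above):
+
+```python
+summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test,
+                                                           return_dict=True, decimals=2)
+```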
+
+## Run all attacks and get all outputs
+
+With the `run_all_attacks` function, one can run all implemented attacks on all
+possible subsets of the data (all examples, split by class, split by confidence
+deciles, misclassified examples only). This function returns a relatively large
+dictionary with all attack results. This is the most detailed information one
+could get about these types of membership inference attacks (besides plots for
+each attack, see the next section). This is useful if you know exactly what
+you're looking for.
+
+> NOTE: The `run_all_attacks` function takes an additional argument specifying
+> which trained attack classifiers to run. In
+> `run_all_attacks_and_create_summary`, only logistic regression (`lr`) is
+> trained as a binary classifier to distinguish in-training from
+> out-of-training examples. In addition, with the `attack_classifiers`
+> argument, one can add multi-layered perceptrons (`mlp`), random forests
+> (`rf`), and k-nearest-neighbors (`knn`) or any subset thereof for the attack
+> models. Note that these classifiers may not converge.
+
+```python
+mia.run_all_attacks(loss_train, loss_test, logits_train, logits_test,
+                    labels_train, labels_test,
+                    attack_classifiers=('lr', 'mlp', 'rf', 'knn'))
+```
+
+Again, `run_all_attacks` can be called on all combinations of losses, logits,
+probabilities, and labels as long as at least either losses or logits
+(probabilities) are provided.
+
+## Fine grained control over individual attacks and plots
+
+The `run_attack` function exposes the underlying workhorse of the
+`run_all_attacks` and `run_all_attacks_and_create_summary` functionality. It
+allows for fine grained control over which attacks to run individually.
+
+As another key feature, this function also exposes options to store receiver
+operating characteristic (ROC) curve plots for the different attacks as well as
+histograms of losses or the maximum logits/probabilities. Finally, all results
+(including the values needed to reproduce the plots) can also be stored to
+disk.
+
+All options are explained in detail in the doc string of the `run_attack`
+function.
+
+For example, to run a simple threshold attack on the losses only and store
+plots and result data to disk, run
+
+```python
+data_path = '/Users/user/Desktop/test/'  # set to None to not store data
+figure_path = '/Users/user/Desktop/test/'  # set to None to not store figures
+
+mia.run_attack(loss_train=loss_train,
+               loss_test=loss_test,
+               metric='auc',
+               output_directory=data_path,
+               figure_directory=figure_path)
+```
+
+Among other things, the `run_attack` function lets you control (see the sketch
+after this list for an example):
+
+*   which metrics to output (`metric` argument, using `auc` or `advantage` or
+    both)
+*   which classifiers (logistic regression, multi-layered perceptrons, random
+    forests) to train as attackers beyond the simple threshold attacks
+    (`attack_classifiers`)
+*   whether to attack only a specific (set of) class(es) (`by_class`)
+*   whether to attack only specific percentiles of the data (`by_percentile`).
+    Percentiles here are computed by looking at the largest logit or
+    probability for each example, i.e., how confident the model is in its
+    prediction.
+*   whether to attack only the misclassified examples (`only_misclassified`)
+*   whether to balance the number of in-training and out-of-training examples
+    (`balance`). By default an equal number of examples from train and test is
+    selected for the attacks (whichever is smaller).
+*   the test set size for trained attacks (`test_size`). When a classifier is
+    trained to distinguish between train and test examples, a train-test split
+    for that classifier itself is required.
+*   the random seed (`random_state`) used for the train-test split as well as
+    for the class balancing.
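+
+For instance, the following call (reusing the `loss_*`, `logits_*`, and
+`labels_*` arrays from above) restricts the attacks to classes 0 and 1, trains
+a logistic regression attacker in addition to the threshold attacks, and fixes
+the random seed:
+
+```python
+results = mia.run_attack(loss_train=loss_train,
+                         loss_test=loss_test,
+                         logits_train=logits_train,
+                         logits_test=logits_test,
+                         labels_train=labels_train,
+                         labels_test=labels_test,
+                         attack_classifiers=('lr',),
+                         by_class=(0, 1),
+                         metric=('auc', 'advantage'),
+                         balance=True,
+                         test_size=0.2,
+                         random_state=7)
+```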
+
+## Contact
+
+Reach out to tf-privacy@google.com and let us know how you’re using this module.
+We’re keen on hearing your stories, feedback, and suggestions!
+
+## Copyright
+
+Copyright 2020 - Google LLC
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/__init__.py b/tensorflow_privacy/privacy/membership_inference_attack/__init__.py
new file mode 100644
index 0000000..2225510
--- /dev/null
+++ b/tensorflow_privacy/privacy/membership_inference_attack/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
new file mode 100644
index 0000000..1f7694a
--- /dev/null
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
@@ -0,0 +1,716 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Lint as: python3 +"""Code that runs membership inference attacks based on the model outputs.""" + +import collections +import io +import os +import re + +from typing import Text, Dict, Iterable, Tuple, Union, Any + +from absl import logging +import numpy as np +from scipy import special + +from tensorflow_privacy.privacy.membership_inference_attack import plotting +from tensorflow_privacy.privacy.membership_inference_attack import trained_attack_models +from tensorflow_privacy.privacy.membership_inference_attack import utils + +from os import mkdir + +ArrayDict = Dict[Text, np.ndarray] +FloatDict = Dict[Text, float] +AnyDict = Dict[Text, Any] +Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] +MetricNames = Union[Text, Iterable[Text]] + + +def _get_vulnerabilities(result: ArrayDict, metrics: MetricNames) -> FloatDict: + """Gets the vulnerabilities according to the chosen metrics for all attacks.""" + vulns = {} + if isinstance(metrics, str): + metrics = [metrics] + for k in result: + for metric in metrics: + if k.endswith(metric.lower()) or k.endswith('n_examples'): + vulns[k] = float(result[k]) + return vulns + + +def _get_maximum_vulnerability( + attack_result: FloatDict, + metrics: MetricNames, + filterby: Text = '') -> Dict[Text, Dict[Text, Union[Text, float]]]: + """Returns the worst vulnerability according to the chosen metrics of all attacks.""" + vulns = {} + if isinstance(metrics, str): + metrics = [metrics] + for metric in metrics: + best_attack_value = -np.inf + for k in attack_result: + if (k.startswith(filterby.lower()) and k.endswith(metric.lower()) and + 'train' not in k): + if float(attack_result[k]) > best_attack_value: + best_attack_value = attack_result[k] + best_attacker = k + if best_attack_value > -np.inf: + newkey = filterby + '-' + metric if filterby else metric + vulns[newkey] = {'value': best_attack_value, 'attacker': best_attacker} + return vulns + + +def _get_maximum_class_gap_or_none(result: FloatDict, + metrics: MetricNames) -> FloatDict: + """Returns the biggest and smallest vulnerability and the gap across classes.""" + gaps = {} + if isinstance(metrics, str): + metrics = [metrics] + for metric in metrics: + hi = -np.inf + lo = np.inf + hi_idx, lo_idx = -1, -1 + for k in result: + if (k.startswith('class') and k.endswith(metric.lower()) and + 'train' not in k): + if float(result[k]) > hi: + hi = float(result[k]) + hi_idx = int(re.findall(r'class_(\d+)_', k)[0]) + if float(result[k]) < lo: + lo = float(result[k]) + lo_idx = int(re.findall(r'class_(\d+)_', k)[0]) + if lo - hi < np.inf: + gaps['max_class_gap_' + metric] = hi - lo + gaps[f'class_{hi_idx}_' + metric] = hi + gaps[f'class_{lo_idx}_' + metric] = lo + gaps['max_vuln_class_' + metric] = hi_idx + gaps['min_vuln_class_' + metric] = lo_idx + return gaps + + +# ------------------------------------------------------------------------------ +# Attacks +# ------------------------------------------------------------------------------ + + +def _run_threshold_loss_attack(features: ArrayDict, + figure_file_prefix: Text = '', + figure_directory: Text = None) -> ArrayDict: + """Runs the threshold attack on the loss.""" + logging.info('Run threshold attack on loss...') + is_train = features['is_train'] + attack_prefix = 'thresh_loss' + tmp_results = utils.compute_performance_metrics(is_train, -features['loss']) + if figure_directory is not None: + figpath = os.path.join(figure_directory, + figure_file_prefix + attack_prefix + '.png') + plotting.save_plot( + 
plotting.plot_curve_with_area( + tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'), + figpath) + figpath = os.path.join(figure_directory, + figure_file_prefix + attack_prefix + '_hist.png') + plotting.save_plot( + plotting.plot_histograms( + features['loss'][is_train == 1], + features['loss'][is_train == 0], + xlabel='loss'), figpath) + return utils.prepend_to_keys(tmp_results, attack_prefix + '_') + + +def _run_threshold_attack_maxlogit(features: ArrayDict, + figure_file_prefix: Text = '', + figure_directory: Text = None) -> ArrayDict: + """Runs the threshold attack on the maximum logit.""" + is_train = features['is_train'] + preds = np.max(features['logits'], axis=-1) + tmp_results = utils.compute_performance_metrics(is_train, preds) + attack_prefix = 'thresh_maxlogit' + if figure_directory is not None: + figpath = os.path.join(figure_directory, + figure_file_prefix + attack_prefix + '.png') + plotting.save_plot( + plotting.plot_curve_with_area( + tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'), + figpath) + figpath = os.path.join(figure_directory, + figure_file_prefix + attack_prefix + '_hist.png') + plotting.save_plot( + plotting.plot_histograms( + preds[is_train == 1], preds[is_train == 0], xlabel='loss'), figpath) + return utils.prepend_to_keys(tmp_results, attack_prefix + '_') + + +def _run_trained_attack(attack_classifier: Text, + data: Dataset, + attack_prefix: Text, + figure_file_prefix: Text = '', + figure_directory: Text = None) -> ArrayDict: + """Train a classifier for attack and evaluate it.""" + # Train the attack classifier + (x_train, y_train), (x_test, y_test) = data + clf_model = trained_attack_models.choose_model(attack_classifier) + clf_model.fit(x_train, y_train) + + # Calculate training set metrics + pred_train = clf_model.predict_proba(x_train)[:, clf_model.classes_ == 1] + results = utils.prepend_to_keys( + utils.compute_performance_metrics(y_train, pred_train), + attack_prefix + 'train_') + + # Calculate test set metrics + pred_test = clf_model.predict_proba(x_test)[:, clf_model.classes_ == 1] + results.update( + utils.prepend_to_keys( + utils.compute_performance_metrics(y_test, pred_test), + attack_prefix + 'test_')) + + if figure_directory is not None: + figpath = os.path.join(figure_directory, + figure_file_prefix + attack_prefix[:-1] + '.png') + plotting.save_plot( + plotting.plot_curve_with_area( + results[attack_prefix + 'test_fpr'], + results[attack_prefix + 'test_tpr'], + xlabel='FPR', + ylabel='TPR'), figpath) + return results + + +def _run_attacks_and_plot(features: ArrayDict, + attacks: Iterable[Text], + attack_classifiers: Iterable[Text], + balance: bool, + test_size: float, + random_state: int, + figure_file_prefix: Text = '', + figure_directory: Text = None) -> ArrayDict: + """Runs the specified attacks on the provided data.""" + if balance: + try: + features = utils.subsample_to_balance(features, random_state) + except RuntimeError: + logging.info('Not enough remaining data for attack: Empty results.') + return {} + + result = {} + # -------------------- Simple threshold attacks + if 'thresh_loss' in attacks: + result.update( + _run_threshold_loss_attack(features, figure_file_prefix, + figure_directory)) + + if 'thresh_maxlogit' in attacks: + result.update( + _run_threshold_attack_maxlogit(features, figure_file_prefix, + figure_directory)) + + # -------------------- Run learned attacks + # TODO(b/157632603): Add a prefix (for example 'trained_') for attacks which + # use classifiers to distinguish from 
threshould attacks. + if 'logits' in attacks: + data = utils.get_train_test_split( + features, add_loss=False, test_size=test_size) + for clf in attack_classifiers: + logging.info('Train %s on %d logits', clf, data[0][0].shape[1]) + attack_prefix = f'{clf}_logits_' + result.update( + _run_trained_attack(clf, data, attack_prefix, figure_file_prefix, + figure_directory)) + + if 'logits_loss' in attacks: + data = utils.get_train_test_split( + features, add_loss=True, test_size=test_size) + for clf in attack_classifiers: + logging.info('Train %s on %d logits + loss', clf, data[0][0].shape[1]) + attack_prefix = f'{clf}_logits_loss_' + result.update( + _run_trained_attack(clf, data, attack_prefix, figure_file_prefix, + figure_directory)) + return result + + +def run_attack(loss_train: np.ndarray = None, + loss_test: np.ndarray = None, + logits_train: np.ndarray = None, + logits_test: np.ndarray = None, + labels_train: np.ndarray = None, + labels_test: np.ndarray = None, + attack_classifiers: Iterable[Text] = None, + only_misclassified: bool = False, + by_class: Union[bool, Iterable[int], int] = False, + by_percentile: Union[bool, Iterable[int], int] = False, + figure_directory: Text = None, + output_directory: Text = None, + metric: MetricNames = 'auc', + balance: bool = True, + test_size: float = 0.2, + random_state: int = 0) -> FloatDict: + """Run membership inference attack(s). + + Based only on specific outputs of a machine learning model on some examples + used for training (train) and some examples not used for training (test), run + membership inference attacks that try to discriminate training from test + inputs based only on the model outputs. + While all inputs are optional, at least one train/test pair is required to run + any attacks (either losses or logits/probabilities). + Note that one can equally provide output probabilities instead of logits in + the logits_train / logits_test arguments. + + We measure the vulnerability of the model via the area under the ROC-curve + (auc) or via max |fpr - tpr| (advantage) of the attack classifier. These + measures are very closely related and may look almost indistinguishable. + + This function provides relatively fine grained control and outputs detailed + results. For a higher-level wrapper with sane internal default settings and + distilled output results, see `run_all_attacks`. + + Via the `figure_directory` argument and the `output_directory` argument more + detailed information as well as roc-curve plots can optionally be stored to + disk. + + If `loss_train` and `loss_test` are provided we run: + - simple threshold attack on the loss + + If `logits_train` and `logits_test` are provided we run: + - simple threshold attack on the top logit + - if `attack_classifiers` is not None and no losses are provided: train the + specified classifiers on the top 10 logits (or all logits if there are + less than 10) + - if `attack_classifiers` is not None and losses are provided: train the + specified classifiers on the top 10 logits (or all logits if there are + less than 10) and the loss + + Args: + loss_train: A 1D array containing the individual scalar losses for examples + used during training. + loss_test: A 1D array containing the individual scalar losses for examples + not used during training. + logits_train: A 2D array (n_train, n_classes) of the individual logits or + output probabilities of examples used during training. 
+    logits_test: A 2D array (n_test, n_classes) of the individual logits or
+      output probabilities of examples not used during training.
+    labels_train: The true labels of the training examples. Labels are only
+      needed when `by_class` is specified (i.e., not False).
+    labels_test: The true labels of the test examples. Labels are only needed
+      when `by_class` is specified (i.e., not False).
+    attack_classifiers: Attack classifiers to train beyond simple thresholding
+      that require training a simple binary ML classifier. This argument is
+      ignored if logits are not provided. Classifiers can be 'lr' for logistic
+      regression, 'mlp' for multi-layered perceptron, 'rf' for random forests,
+      or 'knn' for k-nearest-neighbors. If `None`, don't train classifiers
+      beyond simple thresholding.
+    only_misclassified: Run and evaluate attacks only on misclassified
+      examples. Must specify `labels_train`, `labels_test`, `logits_train` and
+      `logits_test` to use this. If this is True, `by_class` and
+      `by_percentile` are ignored.
+    by_class: This argument determines whether attacks are run on the entire
+      data, or on examples grouped by their class label. If `True`, all attacks
+      are run separately for each class. If `by_class` is a single integer, run
+      attacks for this class only. If `by_class` is an iterable of integers,
+      run all attacks for each of the specified class labels separately. Only
+      used if `labels_train` and `labels_test` are specified. If `by_class` is
+      specified (not False), `by_percentile` is ignored. Ignored if
+      `only_misclassified` is True.
+    by_percentile: This argument determines whether attacks are run on the
+      entire data, or separately for examples where the most likely class
+      prediction is within a given percentile of all maximum predictions. If
+      `True`, all attacks are run separately for the examples with max
+      probabilities within the ten deciles. If `by_percentile` is a single int
+      between 0 and 100, run attacks only for examples with confidence within
+      this percentile. If `by_percentile` is an iterable of ints between 0 and
+      100, run all attacks for each of the specified percentiles separately.
+      Ignored if `by_class` is specified. Ignored if `logits_train` and
+      `logits_test` are not specified. Ignored if `only_misclassified` is True.
+    figure_directory: Where to store ROC-curve plots and histograms. If `None`,
+      don't create plots.
+    output_directory: Where to store detailed result data for all run attacks.
+      If `None`, don't store detailed result data.
+    metric: Available vulnerability metrics are 'auc' or 'advantage' for the
+      area under the ROC curve or the advantage (max |tpr - fpr|). Specify
+      either one of them or both.
+    balance: Whether to use the same number of train and test samples (by
+      randomly subsampling whichever happens to be larger).
+    test_size: The fraction of the input data to use for the evaluation of
+      trained ML attacks. This argument is ignored if either
+      `attack_classifiers` is None or no logits are provided.
+    random_state: Random seed for reproducibility. Only used if attack models
+      are trained.
+
+  Returns:
+    results: Dictionary with the chosen vulnerability metric(s) for all
+      attacks that were run.
+  """
+  attacks = []
+  features = {}
+  # ---------- Check available data ----------
+  if ((loss_train is None or loss_test is None) and
+      (logits_train is None or logits_test is None)):
+    raise ValueError(
+        'Need at least train and test for loss or train and test for logits.')
+
+  # ---------- If losses are provided ----------
+  if loss_train is not None and loss_test is not None:
+    if loss_train.ndim != 1 or loss_test.ndim != 1:
+      raise ValueError('Losses must be 1D arrays.')
+    features['is_train'] = np.concatenate(
+        (np.ones(len(loss_train)), np.zeros(len(loss_test))),
+        axis=0).astype(int)
+    features['loss'] = np.concatenate((loss_train.ravel(), loss_test.ravel()),
+                                      axis=0)
+    attacks.append('thresh_loss')
+
+  # ---------- If logits are provided ----------
+  if logits_train is not None and logits_test is not None:
+    assert logits_train.ndim == 2 and logits_test.ndim == 2, \
+        'Logits must be 2D arrays.'
+    assert logits_train.shape[1] == logits_test.shape[1], \
+        'Train and test logits must agree along axis 1 (number of classes).'
+    if 'is_train' in features:
+      assert (loss_train.shape[0] == logits_train.shape[0] and
+              loss_test.shape[0] == logits_test.shape[0]), \
+          'Number of examples must match between loss and logits.'
+    else:
+      features['is_train'] = np.concatenate(
+          (np.ones(logits_train.shape[0]), np.zeros(logits_test.shape[0])),
+          axis=0).astype(int)
+    attacks.append('thresh_maxlogit')
+    features['logits'] = np.concatenate((logits_train, logits_test), axis=0)
+    if attack_classifiers:
+      attacks.append('logits')
+      if 'loss' in features:
+        attacks.append('logits_loss')
+
+  # ---------- If labels are provided ----------
+  if labels_train is not None and labels_test is not None:
+    if labels_train.ndim != 1 or labels_test.ndim != 1:
+      raise ValueError('Labels must be 1D arrays.')
+    if 'loss' in features:
+      assert (loss_train.shape[0] == labels_train.shape[0] and
+              loss_test.shape[0] == labels_test.shape[0]), \
+          'Number of examples must match between loss and labels.'
+    else:
+      assert (logits_train.shape[0] == labels_train.shape[0] and
+              logits_test.shape[0] == labels_test.shape[0]), \
+          'Number of examples must match between logits and labels.'
+ features['label'] = np.concatenate((labels_train, labels_test), axis=0) + + # ---------- Data subsampling or filtering ---------- + filtertype = None + filtervals = [None] + if only_misclassified: + if (labels_train is None or labels_test is None or logits_train is None or + logits_test is None): + raise ValueError('Must specify labels_train, labels_test, logits_train, ' + 'and logits_test for the only_misclassified option.') + filtertype = 'misclassified' + elif by_class: + if labels_train is None or labels_test is None: + raise ValueError('Must specify labels_train and labels_test when using ' + 'the by_class option.') + if isinstance(by_class, bool): + filtervals = list(set(labels_train) | set(labels_test)) + elif isinstance(by_class, int): + filtervals = [by_class] + elif isinstance(by_class, collections.Iterable): + filtervals = list(by_class) + filtertype = 'class' + elif by_percentile: + if logits_train is None or logits_test is None: + raise ValueError('Must specify logits_train and logits_test when using ' + 'the by_percentile option.') + if isinstance(by_percentile, bool): + filtervals = list(range(10, 101, 10)) + elif isinstance(by_percentile, int): + filtervals = [by_percentile] + elif isinstance(by_percentile, collections.Iterable): + filtervals = [int(percentile) for percentile in by_percentile] + filtertype = 'percentile' + + # ---------- Need to create figure directory? ---------- + if figure_directory is not None: + mkdir(figure_directory) + + # ---------- Actually run attacks and plot if required ---------- + logging.info('Selecting %s with values %s', filtertype, filtervals) + num = None + result = {} + for filterval in filtervals: + if filtertype is None: + tmp_features = features + elif filtertype == 'misclassified': + idx = features['label'] != np.argmax(features['logits'], axis=-1) + tmp_features = utils.select_indices(features, idx) + num = np.sum(idx) + elif filtertype == 'class': + idx = features['label'] == filterval + tmp_features = utils.select_indices(features, idx) + num = np.sum(idx) + elif filtertype == 'percentile': + certainty = np.max(special.softmax(features['logits'], axis=-1), axis=-1) + idx = certainty <= np.percentile(certainty, filterval) + tmp_features = utils.select_indices(features, idx) + + prefix = f'{filtertype}_' if filtertype is not None else '' + prefix += f'{filterval}_' if filterval is not None else '' + tmp_result = _run_attacks_and_plot(tmp_features, attacks, + attack_classifiers, balance, test_size, + random_state, prefix, figure_directory) + if num is not None: + tmp_result['n_examples'] = float(num) + if tmp_result: + result.update(utils.prepend_to_keys(tmp_result, prefix)) + + # ---------- Store data ---------- + if output_directory is not None: + mkdir(output_directory) + resultpath = os.path.join(output_directory, 'attack_results.npz') + logging.info('Store aggregate results at %s.', resultpath) + with open(resultpath, 'wb') as fp: + io_buffer = io.BytesIO() + np.savez(io_buffer, **result) + fp.write(io_buffer.getvalue()) + + return _get_vulnerabilities(result, metric) + + +def run_all_attacks(loss_train: np.ndarray = None, + loss_test: np.ndarray = None, + logits_train: np.ndarray = None, + logits_test: np.ndarray = None, + labels_train: np.ndarray = None, + labels_test: np.ndarray = None, + attack_classifiers: Iterable[Text] = ('lr', 'mlp', 'rf', + 'knn'), + decimals: Union[int, None] = 4) -> FloatDict: + """Runs all possible membership inference attacks. 
+ + Check 'run_attack' for detailed information of how attacks are performed + and evaluated. + + This function internally chooses sane default settings for all attacks and + returns all possible output combinations. + For fine grained control and partial attacks, please see `run_attack`. + + Args: + loss_train: A 1D array containing the individual scalar losses for examples + used during training. + loss_test: A 1D array containing the individual scalar losses for examples + not used during training. + logits_train: A 2D array (n_train, n_classes) of the individual logits or + output probabilities of examples used during training. + logits_test: A 2D array (n_test, n_classes) of the individual logits or + output probabilities of examples not used during training. + labels_train: The true labels of the training examples. Labels are only + needed when `by_class` is specified (i.e., not False). + labels_test: The true labels of the test examples. Labels are only needed + when `by_class` is specified (i.e., not False). + attack_classifiers: Which binary classifiers to train (in addition to simple + threshold attacks). This can include 'lr' (logistic regression), 'mlp' + (multi-layered perceptron), 'rf' (random forests), 'knn' (k-nearest + neighbors), which will be trained with cross validation to determine good + hyperparameters. + decimals: Round all float results to this number of decimals. If decimals is + None, don't round. + + Returns: + result: dictionary with all attack results + """ + metrics = ['auc', 'advantage'] + + # Entire data + result = run_attack( + loss_train, + loss_test, + logits_train, + logits_test, + attack_classifiers=attack_classifiers, + metric=metrics) + result = utils.prepend_to_keys(result, 'all_') + + # Misclassified examples + if (labels_train is not None and labels_test is not None and + logits_train is not None and logits_test is not None): + result.update( + run_attack( + loss_train, + loss_test, + logits_train, + logits_test, + labels_train, + labels_test, + attack_classifiers=attack_classifiers, + only_misclassified=True, + metric=metrics)) + + # Split per class + if labels_train is not None and labels_test is not None: + result.update( + run_attack( + loss_train, + loss_test, + logits_train, + logits_test, + labels_train, + labels_test, + by_class=True, + attack_classifiers=attack_classifiers, + metric=metrics)) + + # Different deciles + if logits_train is not None and logits_test is not None: + result.update( + run_attack( + loss_train, + loss_test, + logits_train, + logits_test, + by_percentile=True, + attack_classifiers=attack_classifiers, + metric=metrics)) + + if decimals is not None: + result = {k: round(v, decimals) for k, v in result.items()} + + return result + + +def run_all_attacks_and_create_summary( + loss_train: np.ndarray = None, + loss_test: np.ndarray = None, + logits_train: np.ndarray = None, + logits_test: np.ndarray = None, + labels_train: np.ndarray = None, + labels_test: np.ndarray = None, + return_dict: bool = True, + decimals: Union[int, None] = 4) -> Union[Text, Tuple[Text, AnyDict]]: + """Runs all possible membership inference attack(s) and distill results. + + Check 'run_attack' for detailed information of how attacks are performed + and evaluated. + + This function internally chooses sane default settings for all attacks and + returns all possible output combinations. + For fine grained control and partial attacks, please see `run_attack`. 
+ + Args: + loss_train: A 1D array containing the individual scalar losses for examples + used during training. + loss_test: A 1D array containing the individual scalar losses for examples + not used during training. + logits_train: A 2D array (n_train, n_classes) of the individual logits or + output probabilities of examples used during training. + logits_test: A 2D array (n_test, n_classes) of the individual logits or + output probabilities of examples not used during training. + labels_train: The true labels of the training examples. Labels are only + needed when `by_class` is specified (i.e., not False). + labels_test: The true labels of the test examples. Labels are only needed + when `by_class` is specified (i.e., not False). + return_dict: Whether to also return a dictionary with the results summarized + in the summary string. + decimals: Round all float results to this number of decimals. If decimals is + None, don't round. + + Returns: + summarystring: A string with natural language summary of the attacks. In the + summary string printed numbers will be rounded to `decimal` decimals if + provided, otherwise will round to 3 diits by default for readability. + result: a dictionary with all the distilled attack information summarized + in the summarystring + """ + summary = [] + metrics = ['auc', 'advantage'] + attack_classifiers = ['lr', 'rf', 'mlp', 'knn'] + results = run_all_attacks( + loss_train, + loss_test, + logits_train, + logits_test, + labels_train, + labels_test, + attack_classifiers=attack_classifiers, + decimals=None) + output = _get_maximum_vulnerability(results, metrics, filterby='all') + + if decimals is not None: + strdec = decimals + else: + strdec = 4 + + for metric in metrics: + summary.append(f'========== {metric.upper()} ==========') + best_value = output['all-' + metric]['value'] + best_attacker = output['all-' + metric]['attacker'] + summary.append(f'The best attack ({best_attacker}) achieved an {metric} of ' + f'{best_value:.{strdec}f}.') + summary.append('') + + classgap = _get_maximum_class_gap_or_none(results, metrics) + if classgap: + output.update(classgap) + for metric in metrics: + summary.append(f'========== {metric.upper()} per class ==========') + hi_idx = output[f'max_vuln_class_{metric}'] + lo_idx = output[f'min_vuln_class_{metric}'] + hi = output[f'class_{hi_idx}_{metric}'] + lo = output[f'class_{lo_idx}_{metric}'] + gap = output[f'max_class_gap_{metric}'] + summary.append(f'The most vulnerable class {hi_idx} has {metric} of ' + f'{hi:.{strdec}f}.') + summary.append(f'The least vulnerable class {lo_idx} has {metric} of ' + f'{lo:.{strdec}f}.') + summary.append(f'=> The maximum gap between class vulnerabilities is ' + f'{gap:.{strdec}f}.') + summary.append('') + + misclassified = _get_maximum_vulnerability( + results, metrics, filterby='misclassified') + if misclassified: + for metric in metrics: + best_attacker = misclassified['misclassified-' + metric]['attacker'] + summary.append(f'========== {metric.upper()} for misclassified ' + '==========') + summary.append('Among misclassified examples, the best attack ' + f'({best_attacker}) achieved an {metric} of ' + f'{best_value:.{strdec}f}.') + summary.append('') + output.update(misclassified) + + n_examples = {k: v for k, v in results.items() if k.endswith('n_examples')} + if n_examples: + output.update(n_examples) + + # Flatten remaining dicts in output + fresh_output = {} + for k, v in output.items(): + if isinstance(v, dict): + if k.startswith('all'): + fresh_output[k[4:]] = v['value'] + 
fresh_output['best_attacker_' + k[4:]] = v['attacker'] + else: + fresh_output[k] = v + output = fresh_output + + if decimals is not None: + for k, v in output.items(): + if isinstance(v, float): + output[k] = round(v, decimals) + + summary = '\n'.join(summary) + if return_dict: + return summary, output + else: + return summary diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py new file mode 100644 index 0000000..8e571ad --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -0,0 +1,307 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils.""" +from absl.testing import absltest + +import numpy as np + +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia + + +def get_result_dict(): + """Get an example result dictionary.""" + return { + 'test_n_examples': np.ones(1), + 'test_examples': np.zeros(1), + 'test_auc': np.ones(1), + 'test_advantage': np.ones(1), + 'all_0-metric': np.array([1]), + 'all_1-metric': np.array([2]), + 'test_2-metric': np.array([3]), + 'test_score': np.array([4]), + } + + +def get_test_inputs(): + """Get example inputs for attacks.""" + n_train = n_test = 500 + rng = np.random.RandomState(4) + loss_train = rng.randn(n_train) - 0.4 + loss_test = rng.randn(n_test) + 0.4 + logits_train = rng.randn(n_train, 5) + 0.2 + logits_test = rng.randn(n_test, 5) - 0.2 + labels_train = np.array([i % 5 for i in range(n_train)]) + labels_test = np.array([(3 * i) % 5 for i in range(n_test)]) + return (loss_train, loss_test, logits_train, logits_test, + labels_train, labels_test) + + +class GetVulnerabilityTest(absltest.TestCase): + + def test_get_vulnerabilities(self): + """Test extraction of vulnerability scores.""" + testdict = get_result_dict() + for key in ['auc', 'advantage']: + res = mia._get_vulnerabilities(testdict, key) + self.assertLen(res, 2) + self.assertEqual(res[f'test_{key}'], 1) + self.assertEqual(res['test_n_examples'], 1) + + res = mia._get_vulnerabilities(testdict, ['auc', 'advantage']) + self.assertLen(res, 3) + self.assertEqual(res['test_auc'], 1) + self.assertEqual(res['test_advantage'], 1) + self.assertEqual(res['test_n_examples'], 1) + + +class GetMaximumVulnerabilityTest(absltest.TestCase): + + def test_get_maximum_vulnerability(self): + """Test extraction of maximum vulnerability score.""" + testdict = get_result_dict() + for i in range(3): + key = f'{i}-metric' + res = mia._get_maximum_vulnerability(testdict, key) + self.assertLen(res, 1) + self.assertEqual(res[key]['value'], i + 1) + if i < 2: + self.assertEqual(res[key]['attacker'], f'all_{i}-metric') + else: + self.assertEqual(res[key]['attacker'], 'test_2-metric') + + res = mia._get_maximum_vulnerability(testdict, 'metric') + 
self.assertLen(res, 1) + self.assertEqual(res['metric']['value'], 3) + + res = mia._get_maximum_vulnerability(testdict, ['metric'], + filterby='all') + self.assertLen(res, 1) + self.assertEqual(res['all-metric']['value'], 2) + + res = mia._get_maximum_vulnerability(testdict, ['metric', 'score']) + self.assertLen(res, 2) + self.assertEqual(res['metric']['value'], 3) + self.assertEqual(res['score']['value'], 4) + self.assertEqual(res['score']['attacker'], 'test_score') + + +class ThresholdAttackLossTest(absltest.TestCase): + + def test_threshold_attack_loss(self): + """Test simple threshold attack on loss.""" + features = { + 'loss': np.zeros(10), + 'is_train': np.concatenate((np.zeros(5), np.ones(5))), + } + res = mia._run_threshold_loss_attack(features) + for k in res: + self.assertStartsWith(k, 'thresh_loss') + self.assertEqual(res['thresh_loss_auc'], 0.5) + self.assertEqual(res['thresh_loss_advantage'], 0.0) + + rng = np.random.RandomState(4) + n_train = 1000 + n_test = 500 + loss_train = rng.randn(n_train) - 0.4 + loss_test = rng.randn(n_test) + 0.4 + features = { + 'loss': np.concatenate((loss_train, loss_test)), + 'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))), + } + res = mia._run_threshold_loss_attack(features) + self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75) + self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35) + + +class ThresholdAttackMaxlogitTest(absltest.TestCase): + + def test_threshold_attack_maxlogits(self): + """Test simple threshold attack on maximum logit.""" + features = { + 'logits': np.eye(10, 14), + 'is_train': np.concatenate((np.zeros(5), np.ones(5))), + } + res = mia._run_threshold_attack_maxlogit(features) + for k in res: + self.assertStartsWith(k, 'thresh_maxlogit') + self.assertEqual(res['thresh_maxlogit_auc'], 0.5) + self.assertEqual(res['thresh_maxlogit_advantage'], 0.0) + + rng = np.random.RandomState(4) + n_train = 1000 + n_test = 500 + logits_train = rng.randn(n_train, 12) + 0.2 + logits_test = rng.randn(n_test, 12) - 0.2 + features = { + 'logits': np.concatenate((logits_train, logits_test), axis=0), + 'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))), + } + res = mia._run_threshold_attack_maxlogit(features) + self.assertBetween(res['thresh_maxlogit_auc'], 0.7, 0.75) + self.assertBetween(res['thresh_maxlogit_advantage'], 0.3, 0.35) + + +class TrainedAttackTrivialTest(absltest.TestCase): + + def test_trained_attack(self): + """Test trained attacks.""" + # Trivially easy problem + x_train, x_test = np.ones((500, 3)), np.ones((20, 3)) + x_train[:200] *= -1 + x_test[:8] *= -1 + y_train, y_test = np.ones(500).astype(int), np.ones(20).astype(int) + y_train[:200] = 0 + y_test[:8] = 0 + data = (x_train, y_train), (x_test, y_test) + for clf in ['lr', 'rf', 'mlp', 'knn']: + res = mia._run_trained_attack(clf, data, attack_prefix='a-') + self.assertEqual(res['a-train_auc'], 1) + self.assertEqual(res['a-test_auc'], 1) + self.assertEqual(res['a-train_advantage'], 1) + self.assertEqual(res['a-test_advantage'], 1) + + +class TrainedAttackRandomFeaturesTest(absltest.TestCase): + + def test_trained_attack(self): + """Test trained attacks.""" + # Random labels and features + rng = np.random.RandomState(4) + x_train, x_test = rng.randn(500, 3), rng.randn(500, 3) + y_train = rng.binomial(1, 0.5, size=(500,)) + y_test = rng.binomial(1, 0.5, size=(500,)) + data = (x_train, y_train), (x_test, y_test) + for clf in ['lr', 'rf', 'mlp', 'knn']: + res = mia._run_trained_attack(clf, data, attack_prefix='a-') + 
self.assertBetween(res['a-train_auc'], 0.5, 1.) + self.assertBetween(res['a-test_auc'], 0.4, 0.6) + self.assertBetween(res['a-train_advantage'], 0., 1.0) + self.assertBetween(res['a-test_advantage'], 0., 0.2) + + +class AttackLossesTest(absltest.TestCase): + + def test_attack(self): + """Test individual attack function.""" + # losses only, both metrics + loss_train, loss_test, _, _, _, _ = get_test_inputs() + res = mia.run_attack(loss_train, loss_test, metric=('auc', 'advantage')) + self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75) + self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35) + + +class AttackLossesLogitsTest(absltest.TestCase): + + def test_attack(self): + """Test individual attack function.""" + # losses and logits, two classifiers, single metric + loss_train, loss_test, logits_train, logits_test, _, _ = get_test_inputs() + res = mia.run_attack( + loss_train, + loss_test, + logits_train, + logits_test, + attack_classifiers=('rf', 'knn'), + metric='auc') + self.assertBetween(res['rf_logits_test_auc'], 0.7, 0.9) + self.assertBetween(res['knn_logits_test_auc'], 0.7, 0.9) + self.assertBetween(res['rf_logits_loss_test_auc'], 0.7, 0.9) + self.assertBetween(res['knn_logits_loss_test_auc'], 0.7, 0.9) + + +class AttackLossesLabelsByClassTest(absltest.TestCase): + + def test_attack(self): + # losses and labels, single metric, split by class + loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs() + n_train = loss_train.shape[0] + n_test = loss_test.shape[0] + res = mia.run_attack( + loss_train, + loss_test, + labels_train=labels_train, + labels_test=labels_test, + by_class=True, + metric='auc') + self.assertLen(res, 10) + for k in res: + self.assertStartsWith(k, 'class_') + if k.endswith('n_examples'): + self.assertEqual(int(res[k]), (n_train + n_test) // 5) + else: + self.assertBetween(res[k], 0.65, 0.75) + + +class AttackLossesLabelsSingleClassTest(absltest.TestCase): + + def test_attack(self): + # losses and labels, both metrics, single class + loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs() + n_train = loss_train.shape[0] + n_test = loss_test.shape[0] + res = mia.run_attack( + loss_train, + loss_test, + labels_train=labels_train, + labels_test=labels_test, + by_class=2, + metric=('auc', 'advantage')) + self.assertLen(res, 3) + for k in res: + self.assertStartsWith(k, 'class_2') + if k.endswith('n_examples'): + self.assertEqual(int(res[k]), (n_train + n_test) // 5) + elif k.endswith('advantage'): + self.assertBetween(res[k], 0.3, 0.5) + elif k.endswith('auc'): + self.assertBetween(res[k], 0.7, 0.75) + + +class AttackLogitsLabelsMisclassifiedTest(absltest.TestCase): + + def test_attack(self): + # logits and labels, single metric, single classifier, misclassified only + (_, _, logits_train, logits_test, + labels_train, labels_test) = get_test_inputs() + res = mia.run_attack( + logits_train=logits_train, + logits_test=logits_test, + labels_train=labels_train, + labels_test=labels_test, + only_misclassified=True, + attack_classifiers=('lr',), + metric='advantage') + self.assertBetween(res['misclassified_lr_logits_test_advantage'], 0.3, 0.8) + self.assertEqual(res['misclassified_n_examples'], 802) + + +class AttackLogitsByPrecentileTest(absltest.TestCase): + + def test_attack(self): + # only logits, single metric, no classifiers, split by deciles + _, _, logits_train, logits_test, _, _ = get_test_inputs() + res = mia.run_attack( + logits_train=logits_train, + logits_test=logits_test, + by_percentile=True, + metric='auc') + for k in 
res: + self.assertStartsWith(k, 'percentile') + self.assertBetween(res[k], 0.60, 0.75) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/plotting.py b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py new file mode 100644 index 0000000..ad415c2 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py @@ -0,0 +1,80 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Plotting functionality for membership inference attack analysis. + +Functions to plot ROC curves and histograms as well as functionality to store +figures to colossus. +""" + +from typing import Text, Iterable + +import matplotlib.pyplot as plt +import numpy as np +from sklearn import metrics + + +def save_plot(figure: plt.Figure, path: Text, outformat='png'): + """Store a figure to disk.""" + if path is not None: + with open(path, 'wb') as f: + figure.savefig(f, bbox_inches='tight', format=outformat) + plt.close(figure) + + +def plot_curve_with_area(x: Iterable[float], + y: Iterable[float], + xlabel: Text = 'x', + ylabel: Text = 'y') -> plt.Figure: + """Plot the curve defined by inputs and the area under the curve. + + All entries of x and y are required to lie between 0 and 1. + For example, x could be recall and y precision, or x is fpr and y is tpr. + + Args: + x: Values on x-axis (1d) + y: Values on y-axis (must be same length as x) + xlabel: Label for x axis + ylabel: Label for y axis + + Returns: + The matplotlib figure handle + """ + fig = plt.figure() + plt.plot([0, 1], [0, 1], 'k', lw=1.0) + plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}') + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.legend() + return fig + + +def plot_histograms(train: Iterable[float], + test: Iterable[float], + xlabel: Text = 'x', + thresh: float = None) -> plt.Figure: + """Plot histograms of training versus test metrics.""" + xmin = min(np.min(train), np.min(test)) + xmax = max(np.max(train), np.max(test)) + bins = np.linspace(xmin, xmax, 100) + fig = plt.figure() + plt.hist(test, bins=bins, density=True, alpha=0.5, label='test', log='y') + plt.hist(train, bins=bins, density=True, alpha=0.5, label='train', log='y') + if thresh is not None: + plt.axvline(thresh, c='r', label=f'threshold = {thresh:.3f}') + plt.xlabel(xlabel) + plt.ylabel('normalized counts (density)') + plt.legend() + return fig diff --git a/tensorflow_privacy/privacy/membership_inference_attack/run_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/run_attack.py new file mode 100644 index 0000000..b57aa77 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/run_attack.py @@ -0,0 +1,217 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+r"""This module contains code to run attacks on previous model outputs.
+
+Provided a path to a dataset of model outputs (logits, output probabilities,
+losses, labels, predictions, membership indicators (is train or not)), we train
+supervised binary classifiers using variable sets of features to distinguish
+training from testing examples. We also provide threshold attacks, i.e., simply
+thresholding losses or the maximum probability/logit to obtain binary
+predictions.
+
+The input data is assumed to be a tf.example proto stored with RecordIO (.rio).
+For example, outputs in an accepted format are typically produced by the
+`extract` script in the `extract` directory.
+
+We run various attacks on each of the full datasets, split by class, split by
+percentile of the most certain prediction, and on misclassified examples only,
+and record the area under the receiver operating characteristic curve as well
+as the attack advantage (i.e., max |tpr - fpr|) as vulnerability metrics. For
+all metrics recorded, see the doc string of
+`membership_inference_attack.run_all_attacks`.
+In addition, we record the overall training and test accuracy and loss of the
+original image classifier. All these results are collected in a single
+dictionary with descriptive keys. If there exist multiple model checkpoints (at
+different training epochs), the results for each checkpoint are concatenated,
+such that the dictionary keys stay the same, but the values contain arrays (the
+size being the number of checkpoints). This overall result dictionary is then
+stored as a binary (and compressed) numpy file: .npz.
+The file is stored at the provided output path. If that is the empty string, it
+is stored on the same level as `inputdir` under the chosen name
+(`attack_results.npz` by default).
+
+Example usage:
+
+  python run_attack.py --dataset=cifar10 --inputdir="attack_data"
+
+The results are then stored at ./attack_data
+"""
+
+import io
+import os
+import re
+
+from typing import Text, Dict
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow.google as tf
+import tensorflow_datasets as tfds
+
+from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
+from tensorflow_privacy.privacy.membership_inference_attack import utils
+
+from glob import glob
+
+Result = Dict[Text, np.ndarray]
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_float('test_size', 0.2,
+                   'Fraction of attack data used for the test set.')
+flags.DEFINE_string('dataset', 'cifar10', 'The dataset to use.')
+flags.DEFINE_string(
+    'output', '', 'The path where to store the results. 
' + 'If empty string, store on same level as `inputdir` using ' + 'the name specified in the result_name flag.') +flags.DEFINE_string('result_name', 'attack_results.npz', + 'The name of the output npz file with the attack results.') +flags.DEFINE_string( + 'inputdir', + 'attack_data', + 'The input directory containing the attack datasets.') +flags.DEFINE_integer('seed', 43, 'Random seed to ensure same data splits.') + +# ------------------------------------------------------------------------------ +# Load and select features for attacks +# ------------------------------------------------------------------------------ + + +def load_all_features(data_path: Text) -> Result: + """Extract the selected features from a given dataset.""" + if FLAGS.dataset == 'cifar100': + num_classes = 100 + elif FLAGS.dataset in ['cifar10', 'mnist']: + num_classes = 10 + else: + raise ValueError(f'Unknown dataset {FLAGS.dataset}') + + features = { + 'logits': tf.FixedLenFeature((num_classes,), tf.float32), + 'prob': tf.FixedLenFeature((num_classes,), tf.float32), + 'loss': tf.FixedLenFeature([], tf.float32), + 'is_train': tf.FixedLenFeature([], tf.int64), + 'label': tf.FixedLenFeature([], tf.int64), + 'prediction': tf.FixedLenFeature([], tf.int64), + } + + dataset = tf.data.RecordIODataset(data_path) + + results = {k: [] for k in features} + ds = dataset.map(lambda x: tf.parse_single_example(x, features)) + for example in tfds.as_numpy(ds): + for k in results: + results[k].append(example[k]) + return utils.to_numpy(results) + + +# ------------------------------------------------------------------------------ +# Run attacks +# ------------------------------------------------------------------------------ + + +def run_all_attacks(data_path: Text): + """Train all possible attacks on the data at the given path.""" + logging.info('Load all features from %s...', data_path) + features = load_all_features(data_path) + + for k, v in features.items(): + logging.info('%s: %s', k, v.shape) + + logging.info('Compute original train/test accuracy and loss...') + train_idx = features['is_train'] == 1 + test_idx = np.logical_not(train_idx) + correct = features['label'] == features['prediction'] + result = { + 'original_train_loss': np.mean(features['loss'][train_idx]), + 'original_test_loss': np.mean(features['loss'][test_idx]), + 'original_train_acc': np.mean(correct[train_idx]), + 'original_test_acc': np.mean(correct[test_idx]), + } + + result.update( + mia.run_all_attacks( + loss_train=features['loss'][train_idx], + loss_test=features['loss'][test_idx], + logits_train=features['logits'][train_idx], + logits_test=features['logits'][test_idx], + labels_train=features['label'][train_idx], + labels_test=features['label'][test_idx], + attack_classifiers=('lr', 'mlp', 'rf', 'knn'), + decimals=None)) + result = utils.ensure_1d(result) + + logging.info('Finished training and evaluating attacks.') + return result + + +def attacking(): + """Load data and model and extract relevant outputs.""" + # ---------- Set result path ---------- + if FLAGS.output: + resultpath = FLAGS.output + else: + resultdir = FLAGS.inputdir + if resultdir[-1] == '/': + resultdir = resultdir[:-1] + resultdir = '/'.join(resultdir.split('/')[:-1]) + resultpath = os.path.join(resultdir, FLAGS.result_name) + + # ---------- Glob attack training sets ---------- + logging.info('Glob attack data paths...') + data_paths = sorted(glob(os.path.join(FLAGS.inputdir, '*'))) + logging.info('Found %d data paths', len(data_paths)) + + # ---------- Iterate over attack dataset 
and train attacks ---------- + epochs = [] + results = [] + for i, datapath in enumerate(data_paths): + logging.info('=' * 80) + logging.info('Attack model %d / %d', i + 1, len(data_paths)) + logging.info('=' * 80) + basename = os.path.basename(datapath) + found_ints = re.findall(r'(\d+)', basename) + if len(found_ints) == 1: + epoch = int(found_ints[0]) + logging.info('Found integer %d in pathname, interpret as epoch', epoch) + else: + epoch = np.nan + tmp_res = run_all_attacks(datapath) + if tmp_res is not None: + results.append(tmp_res) + epochs.append(epoch) + + # ---------- Aggregate and save results ---------- + logging.info('Aggregate and combine all results over epochs...') + results = utils.merge_dictionaries(results) + results['epochs'] = np.array(epochs) + logging.info('Store aggregate results at %s.', resultpath) + with open(resultpath, 'wb') as fp: + io_buffer = io.BytesIO() + np.savez(io_buffer, **results) + fp.write(io_buffer.getvalue()) + + logging.info('Finished attacks.') + + +def main(argv): + del argv # Unused + attacking() + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/trained_attack_models.py b/tensorflow_privacy/privacy/membership_inference_attack/trained_attack_models.py new file mode 100644 index 0000000..f03c035 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/trained_attack_models.py @@ -0,0 +1,106 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +r"""A collection of sklearn models for binary classification. + +This module contains some sklearn pipelines for finding models for binary +classification from a variable number of numerical input features. +These models are used to train binary classifiers for membership inference. 
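+
+A minimal usage sketch (the arrays `x_train`, `y_train` and `x_test` are
+hypothetical: per-example attack features and binary in/out-of-training
+labels, e.g. as produced by `utils.get_train_test_split`):
+
+  model = choose_model('lr')  # cross-validated sklearn pipeline
+  model.fit(x_train, y_train)
+  membership_scores = model.predict_proba(x_test)[:, 1]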
+""" + +from typing import Text + +import numpy as np +from sklearn import ensemble +from sklearn import linear_model +from sklearn import model_selection +from sklearn import neighbors +from sklearn import neural_network + + +def choose_model(attack_classifier: Text): + """Choose a model based on a string classifier.""" + if attack_classifier == 'lr': + return logistic_regression() + elif attack_classifier == 'mlp': + return mlp() + elif attack_classifier == 'rf': + return random_forest() + elif attack_classifier == 'knn': + return knn() + else: + raise ValueError(f'Unknown attack classifier {attack_classifier}.') + + +def logistic_regression(verbose: int = 0, n_jobs: int = 1): + """Setup a logistic regression pipeline with cross-validation.""" + lr = linear_model.LogisticRegression(solver='lbfgs') + param_grid = { + 'C': np.logspace(-4, 2, 10), + } + pipe = model_selection.GridSearchCV( + lr, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False, + verbose=verbose) + return pipe + + +def random_forest(verbose: int = 0, n_jobs: int = 1): + """Setup a random forest pipeline with cross-validation.""" + rf = ensemble.RandomForestClassifier() + + n_estimators = [100] + max_features = ['auto', 'sqrt'] + max_depth = [5, 10, 20] + max_depth.append(None) + min_samples_split = [2, 5, 10] + min_samples_leaf = [1, 2, 4] + random_grid = {'n_estimators': n_estimators, + 'max_features': max_features, + 'max_depth': max_depth, + 'min_samples_split': min_samples_split, + 'min_samples_leaf': min_samples_leaf} + + pipe = model_selection.RandomizedSearchCV( + rf, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs, + iid=False, verbose=verbose) + return pipe + + +def mlp(verbose: int = 0, n_jobs: int = 1): + """Setup a MLP pipeline with cross-validation.""" + mlpmodel = neural_network.MLPClassifier() + + param_grid = { + 'hidden_layer_sizes': [(64,), (32, 32)], + 'solver': ['adam'], + 'alpha': [0.0001, 0.001, 0.01], + } + pipe = model_selection.GridSearchCV( + mlpmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False, + verbose=verbose) + return pipe + + +def knn(verbose: int = 0, n_jobs: int = 1): + """Setup a k-nearest neighbors pipeline with cross-validation.""" + knnmodel = neighbors.KNeighborsClassifier() + + param_grid = { + 'n_neighbors': [3, 5, 7], + } + pipe = model_selection.GridSearchCV( + knnmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False, + verbose=verbose) + return pipe diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/membership_inference_attack/utils.py new file mode 100644 index 0000000..08cf017 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils.py @@ -0,0 +1,218 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Lint as: python3 +"""Utility functions for membership inference attacks.""" + +from typing import Text, Dict, Union, List, Any, Tuple + +import numpy as np +from sklearn import metrics + +ArrayDict = Dict[Text, np.ndarray] +Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] + +# ------------------------------------------------------------------------------ +# Utilities for managing result dictionaries +# ------------------------------------------------------------------------------ + + +def to_numpy(in_dict: Dict[Text, Any]) -> ArrayDict: + """Convert values of dict to numpy arrays. + + Warning: This may fail if the values cannot be converted to numpy arrays. + + Args: + in_dict: A dictionary mapping Text keys to values where the values must be + something that can be converted to a numpy array. + + Returns: + a dictionary with the same keys as input with all values converted to numpy + arrays + """ + return {k: np.array(v) for k, v in in_dict.items()} + + +def ensure_1d(in_dict: Dict[Text, Union[int, float, np.ndarray]]) -> ArrayDict: + """Ensure all values of a dictionary are at least 1D numpy arrays. + + Args: + in_dict: The input dictionary mapping Text keys to numpy arrays or numbers. + + Returns: + dictionary with same keys as in_dict and values converted to numpy arrays + with at least one dimension (i.e., pack scalars into arrays) + """ + return {k: np.atleast_1d(v) for k, v in in_dict.items()} + + +def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]: + """Prepend a prefix to all keys of a dictionary. + + Args: + in_dict: The input dictionary mapping Text keys to numpy arrays. + prefix: Text which to prepend to each key in in_dict + + Returns: + dictionary with same values as in_dict and all keys having prefix prepended + to them + """ + return {prefix + k: v for k, v in in_dict.items()} + + +# ------------------------------------------------------------------------------ +# Subsampling and data selection functionality +# ------------------------------------------------------------------------------ + + +def select_indices(in_dict: ArrayDict, indices: np.ndarray) -> ArrayDict: + """Subsample all values in the dictionary by the provided indices. + + Args: + in_dict: The input dictionary mapping Text keys to numpy array values. + indices: A numpy which can be used to index other arrays, specifying the + indices to subsample from in_dict values. + + Returns: + dictionary with same keys as in_dict and subsampled values + """ + return {k: v[indices] for k, v in in_dict.items()} + + +def merge_dictionaries(res: List[ArrayDict]) -> ArrayDict: + """Convert iterable of dicts to dict of iterables.""" + output = {k: np.empty(0) for k in res[0]} + for k in output: + output[k] = np.concatenate([r[k] for r in res if k in r], axis=0) + return output + + +def get_features(features: ArrayDict, + feature_name: Text, + top_k: int, + add_loss: bool = False) -> np.ndarray: + """Combine the specified features into one array. + + Args: + features: A dictionary containing all possible features. + feature_name: Which feature to use (logits or prob). + top_k: The number of the top features (of feature_name) to select. + add_loss: Whether to also add the loss as a feature. 
+
+  Returns:
+    combined numpy array with the selected features (n_examples, n_features)
+  """
+  if top_k < 1:
+    raise ValueError('Must select at least one feature.')
+  # Keep the top_k largest values per example (np.sort sorts in ascending
+  # order, so the largest entries are at the end).
+  feats = np.sort(features[feature_name], axis=-1)[:, -top_k:]
+  if add_loss:
+    feats = np.concatenate((feats, features['loss'][:, np.newaxis]), axis=-1)
+  return feats
+
+
+def subsample_to_balance(features: ArrayDict, random_state: int) -> ArrayDict:
+  """Subsample the larger group if necessary to balance membership labels."""
+  train_idx = features['is_train'] == 1
+  test_idx = np.logical_not(train_idx)
+  n0 = np.sum(test_idx)
+  n1 = np.sum(train_idx)
+
+  if n0 < 20 or n1 < 20:
+    raise RuntimeError('Need at least 20 examples from training and test set.')
+
+  np.random.seed(random_state)
+
+  if n0 > n1:
+    use_idx = np.random.choice(np.where(test_idx)[0], n1, replace=False)
+    use_idx = np.concatenate((use_idx, np.where(train_idx)[0]))
+    features = {k: v[use_idx] for k, v in features.items()}
+  elif n0 < n1:
+    use_idx = np.random.choice(np.where(train_idx)[0], n0, replace=False)
+    use_idx = np.concatenate((use_idx, np.where(test_idx)[0]))
+    features = {k: v[use_idx] for k, v in features.items()}
+
+  return features
+
+
+def get_train_test_split(features: ArrayDict, add_loss: bool,
+                         test_size: float) -> Dataset:
+  """Get training and test data split."""
+  y = features['is_train']
+  n_total = len(y)
+  n_test = int(test_size * n_total)
+  perm = np.random.permutation(len(y))
+  test_idx = perm[:n_test]
+  train_idx = perm[n_test:]
+  y_train = y[train_idx]
+  y_test = y[test_idx]
+
+  # Use the top 10 logits as a default if there are more than 10 classes.
+  # Typically, there is no significant amount of weight in more than 10 logits.
+  n_logits = min(features['logits'].shape[1], 10)
+  x = get_features(features, 'logits', n_logits, add_loss)
+
+  x_train, x_test = x[train_idx], x[test_idx]
+  return (x_train, y_train), (x_test, y_test)
+
+
+# ------------------------------------------------------------------------------
+# Computation of the attack metrics
+# ------------------------------------------------------------------------------
+
+
+def compute_performance_metrics(true_labels: np.ndarray,
+                                predictions: np.ndarray,
+                                threshold: float = None) -> ArrayDict:
+  """Compute relevant classification performance metrics.
+
+  The output metrics are:
+  1. arrays of thresholds and the corresponding false and true positive rates
+     (fpr, tpr).
+  2. auc: the area under the fpr-tpr (ROC) curve.
+  3. advantage: the maximum difference between tpr and fpr.
+  4. precision/recall/accuracy/f1_score, if the `threshold` argument is given.
+
+  Args:
+    true_labels: True labels.
+    predictions: Predicted probabilities/scores.
+    threshold: The threshold to apply to `predictions` for binary
+      classification.
+
+  Returns:
+    A dictionary with relevant metrics, which are fully described by their keys.
+ """ + results = {} + if threshold is not None: + results.update({ + 'precision': + metrics.precision_score(true_labels, predictions > threshold), + 'recall': + metrics.recall_score(true_labels, predictions > threshold), + 'accuracy': + metrics.accuracy_score(true_labels, predictions > threshold), + 'f1_score': + metrics.f1_score(true_labels, predictions > threshold), + }) + + fpr, tpr, thresholds = metrics.roc_curve(true_labels, predictions) + auc = metrics.auc(fpr, tpr) + advantage = np.max(np.abs(tpr - fpr)) + + results.update({ + 'fpr': fpr, + 'tpr': tpr, + 'thresholds': thresholds, + 'auc': auc, + 'advantage': advantage, + }) + return ensure_1d(results) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py new file mode 100644 index 0000000..0b918a7 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py @@ -0,0 +1,105 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils.""" +from absl.testing import absltest + +import numpy as np + +from tensorflow_privacy.privacy.membership_inference_attack import utils + + +class UtilsTest(absltest.TestCase): + + def __init__(self, methodname): + """Initialize the test class.""" + super().__init__(methodname) + rng = np.random.RandomState(33) + logits = rng.uniform(low=0, high=1, size=(1000, 14)) + loss = rng.uniform(low=0, high=1, size=(1000,)) + is_train = rng.binomial(1, 0.7, size=(1000,)) + self.mydict = {'logits': logits, 'loss': loss, 'is_train': is_train} + + def test_compute_metrics(self): + """Test computation of attack metrics.""" + true = np.array([0, 0, 0, 1, 1, 1]) + pred = np.array([0.6, 0.9, 0.4, 0.8, 0.7, 0.2]) + + results = utils.compute_performance_metrics(true, pred, threshold=0.5) + + for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr', + 'thresholds', 'auc', 'advantage']: + self.assertIn(k, results) + + np.testing.assert_almost_equal(results['accuracy'], 1. / 2.) + np.testing.assert_almost_equal(results['precision'], 2. / (2. + 2.)) + np.testing.assert_almost_equal(results['recall'], 2. / (2. 
+ 1.)) + + def test_prepend_to_keys(self): + """Test prepending of text to keys of a dictionary.""" + mydict = utils.prepend_to_keys(self.mydict, 'test') + for k in mydict: + self.assertTrue(k.startswith('test')) + + def test_select_indices(self): + """Test selecting indices from dictionary with array values.""" + mydict = {'a': np.arange(10), 'b': np.linspace(0, 1, 10)} + + idx = np.arange(5) + mydictidx = utils.select_indices(mydict, idx) + np.testing.assert_allclose(mydictidx['a'], np.arange(5)) + np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[:5]) + + idx = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) > 0.5 + mydictidx = utils.select_indices(mydict, idx) + np.testing.assert_allclose(mydictidx['a'], np.arange(0, 10, 2)) + np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[0:10:2]) + + def test_get_features(self): + """Test extraction of features.""" + for k in [1, 5, 10, 15]: + for add_loss in [True, False]: + feats = utils.get_features( + self.mydict, 'logits', top_k=k, add_loss=add_loss) + k_selected = min(k, 14) + self.assertEqual(feats.shape, (1000, k_selected + int(add_loss))) + + def test_subsample_to_balance(self): + """Test subsampling of two arrays.""" + feats = utils.subsample_to_balance(self.mydict, random_state=23) + + train = np.sum(self.mydict['is_train']) + test = 1000 - train + n_chosen = min(train, test) + self.assertEqual(feats['logits'].shape, (2 * n_chosen, 14)) + self.assertEqual(feats['loss'].shape, (2 * n_chosen,)) + self.assertEqual(np.sum(feats['is_train']), n_chosen) + self.assertEqual(np.sum(1 - feats['is_train']), n_chosen) + + def test_get_data(self): + """Test train test split data generation.""" + for test_size in [0.2, 0.5, 0.8, 0.55555]: + (x_train, y_train), (x_test, y_test) = utils.get_train_test_split( + self.mydict, add_loss=True, test_size=test_size) + n_test = int(test_size * 1000) + n_train = 1000 - n_test + self.assertEqual(x_train.shape, (n_train, 11)) + self.assertEqual(y_train.shape, (n_train,)) + self.assertEqual(x_test.shape, (n_test, 11)) + self.assertEqual(y_test.shape, (n_test,)) + + +if __name__ == '__main__': + absltest.main()