Open sourcing membership inference attack.
PiperOrigin-RevId: 317958055
This commit is contained in:
parent
1fb9b80d90
commit
88dd8771bf
9 changed files with 2000 additions and 0 deletions
238
tensorflow_privacy/privacy/membership_inference_attack/README.md
Normal file
238
tensorflow_privacy/privacy/membership_inference_attack/README.md
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
# Membership inference attack functionality
|
||||||
|
|
||||||
|
The goal is to provide empirical tests of "how much information a machine
|
||||||
|
learning model has remembered about its training data". To this end, only the
|
||||||
|
outputs of the model are used (e.g., losses, logits, predictions). From those
|
||||||
|
alone, the attacks try to infer whether the corresponding inputs were part of
|
||||||
|
the training set.
|
||||||
|
|
||||||
|
> NOTE: Only the loss values are needed for some examples used during training
|
||||||
|
> and some examples that have not been used during training (e.g., some examples
|
||||||
|
> from the test set). No access to actual input data is needed. In case of
|
||||||
|
> classification models, one can additionally (or instead of losses) provide
|
||||||
|
> logits or output probabilities for stronger attacks.
|
||||||
|
|
||||||
|
The vulnerability of a model is measured via the area under the ROC-curve
|
||||||
|
(`auc`) or via max{|fpr - tpr|} (`advantage`) of the attack classifier. These
|
||||||
|
measures are very closely related.
|
||||||
|
|
||||||
|
## Highest level -- get attack summary
|
||||||
|
|
||||||
|
### Basic usage
|
||||||
|
|
||||||
|
On the highest level, there is the `run_all_attacks_and_create_summary`
|
||||||
|
function, which chooses sane default options to run a host of (fairly simple)
|
||||||
|
attacks behind the scenes (depending on which data is fed in), computes the most
|
||||||
|
important measures and returns a summary of the results as a string of english
|
||||||
|
language (as well as optionally a python dictionary containing all results with
|
||||||
|
descriptive keys).
|
||||||
|
|
||||||
|
> NOTE: The train and test sets are balanced internally, i.e., an equal number
|
||||||
|
> of in-training and out-of-training examples is chosen for the attacks
|
||||||
|
> (whichever has fewer examples). These are subsampled uniformly at random
|
||||||
|
> without replacement from the larger of the two.
|
||||||
|
|
||||||
|
The simplest possible usage is
|
||||||
|
|
||||||
|
```python
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
|
||||||
|
# Evaluate your model on training and test examples to get
|
||||||
|
# loss_train shape: (n_train, )
|
||||||
|
# loss_test shape: (n_test, )
|
||||||
|
|
||||||
|
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, return_dict=True)
|
||||||
|
print(results)
|
||||||
|
# -> {'auc': 0.7044,
|
||||||
|
# 'best_attacker_auc': 'all_thresh_loss_auc',
|
||||||
|
# 'advantage': 0.3116,
|
||||||
|
# 'best_attacker_auc': 'all_thresh_loss_advantage'}
|
||||||
|
```
|
||||||
|
|
||||||
|
> NOTE: The keyword argument `return_dict` specified whether in addition to the
|
||||||
|
> `summary` the function also returns a python dictionary with the results.
|
||||||
|
|
||||||
|
If the model is a classifier, the logits or output probabilities (i.e., the
|
||||||
|
softmax of logits) can also be provided to perform stronger attacks.
|
||||||
|
|
||||||
|
> NOTE: The `logits_train` and `logits_test` arguments can also be filled with
|
||||||
|
> output probabilities per class ("posteriors").
|
||||||
|
|
||||||
|
```python
|
||||||
|
# logits_train shape: (n_train, n_classes)
|
||||||
|
# logits_test shape: (n_test, n_classes)
|
||||||
|
|
||||||
|
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
|
||||||
|
logits_test, return_dict=True)
|
||||||
|
print(results)
|
||||||
|
# -> {'auc': 0.5382,
|
||||||
|
# 'best_attacker_auc': 'all_lr_logits_loss_test_auc',
|
||||||
|
# 'advantage': 0.0572,
|
||||||
|
# 'best_attacker_auc': 'all_mlp_logits_loss_test_advantage'}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `summary` will be a string in natural language describing the results in
|
||||||
|
more detail, e.g.,
|
||||||
|
|
||||||
|
```
|
||||||
|
========== AUC ==========
|
||||||
|
The best attack (all_lr_logits_loss_test_auc) achieved an auc of 0.5382.
|
||||||
|
|
||||||
|
========== ADVANTAGE ==========
|
||||||
|
The best attack (all_mlp_logits_loss_test_advantage) achieved an advantage of 0.0572.
|
||||||
|
```
|
||||||
|
|
||||||
|
Similarly, we can run attacks on the logits alone, without access to losses:
|
||||||
|
|
||||||
|
```python
|
||||||
|
summary, results = mia.run_all_attacks_and_create_summary(logits_train=logits_train,
|
||||||
|
logits_test=logits_test,
|
||||||
|
return_dict=True)
|
||||||
|
print(results)
|
||||||
|
# -> {'auc': 0.9278,
|
||||||
|
# 'best_attacker_auc': 'all_rf_logits_test_auc',
|
||||||
|
# 'advantage': 0.6991,
|
||||||
|
# 'best_attacker_auc': 'all_rf_logits_test_advantage'}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Advanced usage
|
||||||
|
|
||||||
|
Finally, if we also have access to the true labels of the training and test
|
||||||
|
inputs, we can run the attacks for each class separately. If labels *and* logits
|
||||||
|
are provided, attacks only for misclassified (typically uncertain) examples are
|
||||||
|
also performed.
|
||||||
|
|
||||||
|
```python
|
||||||
|
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
|
||||||
|
logits_test, labels_train, labels_test,
|
||||||
|
return_dict=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Here, we now also get as output the class with the maximal vulnerability
|
||||||
|
according to our metrics (`max_vuln_class_auc`, `max_vuln_class_advantage`)
|
||||||
|
together with the corresponding values (`class_<CLASS>_auc`,
|
||||||
|
`class_<CLASS>_advantage`). The same values exist in the `results` dictionary
|
||||||
|
for `min` instead of `max`, i.e., the least vulnerable classes. Moreover, the
|
||||||
|
gap between the maximum and minimum values (`max_class_gap_auc`,
|
||||||
|
`max_class_gap_advantage`) is also provided. Similarly, the vulnerability
|
||||||
|
metrics when the attacks are restricted to the misclassified examples
|
||||||
|
(`misclassified_auc`, `misclassified_advantage`) are also shown. Finally, the
|
||||||
|
results also contain the number of examples in each of these groups, i.e.,
|
||||||
|
within each of the reported classes as well as the number of misclassified
|
||||||
|
examples. The final `results` dictionary is of the form
|
||||||
|
|
||||||
|
```
|
||||||
|
{'auc': 0.9181,
|
||||||
|
'best_attacker_auc': 'all_rf_logits_loss_test_auc',
|
||||||
|
'advantage': 0.6915,
|
||||||
|
'best_attacker_advantage': 'all_rf_logits_loss_test_advantage',
|
||||||
|
'max_class_gap_auc': 0.254,
|
||||||
|
'class_5_auc': 0.9512,
|
||||||
|
'class_3_auc': 0.6972,
|
||||||
|
'max_vuln_class_auc': 5,
|
||||||
|
'min_vuln_class_auc': 3,
|
||||||
|
'max_class_gap_advantage': 0.5073,
|
||||||
|
'class_0_advantage': 0.8086,
|
||||||
|
'class_3_advantage': 0.3013,
|
||||||
|
'max_vuln_class_advantage': 0,
|
||||||
|
'min_vuln_class_advantage': 3,
|
||||||
|
'misclassified_n_examples': 4513.0,
|
||||||
|
'class_0_n_examples': 899.0,
|
||||||
|
'class_1_n_examples': 900.0,
|
||||||
|
'class_2_n_examples': 931.0,
|
||||||
|
'class_3_n_examples': 893.0,
|
||||||
|
'class_4_n_examples': 960.0,
|
||||||
|
'class_5_n_examples': 884.0}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Setting the precision of the reported results
|
||||||
|
|
||||||
|
Finally, `run_all_attacks_and_create_summary` takes one extra keyword argument
|
||||||
|
`decimals`, expecting a positive integer. This sets the precision of all result
|
||||||
|
values as the number of decimals to report. It defaults to 4.
|
||||||
|
|
||||||
|
## Run all attacks and get all outputs
|
||||||
|
|
||||||
|
With the `run_all_attacks` function, one can run all implemented attacks on all
|
||||||
|
possible subsets of the data (all examples, split by class, split by confidence
|
||||||
|
deciles, misclassified only). This function returns a relatively large
|
||||||
|
dictionary with all attack results. This is the most detailed information one
|
||||||
|
could get about these types of membership inference attacks (besides plots for
|
||||||
|
each attack, see next section.) This is useful if you know exactly what you're
|
||||||
|
looking for.
|
||||||
|
|
||||||
|
> NOTE: The `run_all_attacks` function takes as an additional argument which
|
||||||
|
> trained attackers to run. In the `run_all_attacks_and_create_summary`, only
|
||||||
|
> logistic regression (`lr`) is trained as a binary classifier to distinguish
|
||||||
|
> in-training form out-of-training examples. In addition, with the
|
||||||
|
> `attack_classifiers` argument, one can add multi-layered perceptrons (`mlp`),
|
||||||
|
> random forests (`rf`), and k-nearest-neighbors (`knn`) or any subset thereof
|
||||||
|
> for the attack models. Note that these classifiers may not converge.
|
||||||
|
|
||||||
|
```python
|
||||||
|
mia.run_all_attacks(loss_train, loss_test, logits_train, logits_test,
|
||||||
|
labels_train, labels_test,
|
||||||
|
attack_classifiers=('lr', 'mlp', 'rf', 'knn'))
|
||||||
|
```
|
||||||
|
|
||||||
|
Again, `run_all_attacks` can be called on all combinations of losses, logits,
|
||||||
|
probabilities, and labels as long as at least either losses or logits
|
||||||
|
(probabilities) are provided.
|
||||||
|
|
||||||
|
## Fine grained control over individual attacks and plots
|
||||||
|
|
||||||
|
The `run_attack` function exposes the underlying workhorse of the
|
||||||
|
`run_all_attacks` and `run_all_attacks_and_create_summary` functionality. It
|
||||||
|
allows for fine grained control of which attacks to run individually.
|
||||||
|
|
||||||
|
As another key feature, this function also exposes options to store receiver
|
||||||
|
operator curve plots for the different attacks as well as histograms of losses
|
||||||
|
or the maximum logits/probabilities. Finally, we can also store all results
|
||||||
|
(including the values to reproduce the plots) to colossus.
|
||||||
|
|
||||||
|
All options are explained in detail in the doc string of the `run_attack`
|
||||||
|
function.
|
||||||
|
|
||||||
|
For example, to run a simple threshold attack on the losses only and store plots
|
||||||
|
and result data to colossus, run
|
||||||
|
|
||||||
|
```python
|
||||||
|
data_path = '/Users/user/Desktop/test/' # set to None to not store data
|
||||||
|
figure_path = '/Users/user/Desktop/test/' # set to None to not store figures
|
||||||
|
|
||||||
|
mia.attack(loss_train=loss_train,
|
||||||
|
loss_test=loss_test,
|
||||||
|
metric='auc',
|
||||||
|
output_directory=data_path,
|
||||||
|
figure_directory=figure_path)
|
||||||
|
```
|
||||||
|
|
||||||
|
Among other things, the `run_attack` functionality allows to control:
|
||||||
|
|
||||||
|
* which metrics to output (`metric` argument, using `auc` or `advantage` or
|
||||||
|
both)
|
||||||
|
* which classifiers (logistic regression, multi-layered perceptrons, random
|
||||||
|
forests) to train as attackers beyond the simple threshold attacks
|
||||||
|
(`attack_classifiers`)
|
||||||
|
* to only attack a specific (set of) classes (`by_class`)
|
||||||
|
* to only attack specific percentiles of the data (`by_percentile`).
|
||||||
|
Percentiles here are computed by looking at the largest logit or probability
|
||||||
|
for each example, i.e., how confident the model is in its prediction.
|
||||||
|
* to only attack the misclassified examples (`only_misclassified`)
|
||||||
|
* not to balance examples between the in-training and out-of-training examples
|
||||||
|
using `balance`. By default an equal number of examples from train and test
|
||||||
|
are selected for the attacks (whichever is smaller).
|
||||||
|
* the test set size for trained attacks (`test_size`). When a classifier is
|
||||||
|
trained to distinguish between train and test examples, a train-test split
|
||||||
|
for that classifier itself is required.
|
||||||
|
* for the train-test split as well as for the class balancing randomness is
|
||||||
|
used with a seed specified by `random_state`.
|
||||||
|
|
||||||
|
## Contact
|
||||||
|
|
||||||
|
Reach out to tf-privacy@google.com and let us know how you’re using this module.
|
||||||
|
We’re keen on hearing your stories, feedback, and suggestions!
|
||||||
|
|
||||||
|
## Copyright
|
||||||
|
|
||||||
|
Copyright 2020 - Google LLC
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
|
@ -0,0 +1,716 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Code that runs membership inference attacks based on the model outputs."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from typing import Text, Dict, Iterable, Tuple, Union, Any
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
|
import numpy as np
|
||||||
|
from scipy import special
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import plotting
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import trained_attack_models
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import utils
|
||||||
|
|
||||||
|
from os import mkdir
|
||||||
|
|
||||||
|
ArrayDict = Dict[Text, np.ndarray]
|
||||||
|
FloatDict = Dict[Text, float]
|
||||||
|
AnyDict = Dict[Text, Any]
|
||||||
|
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||||
|
MetricNames = Union[Text, Iterable[Text]]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vulnerabilities(result: ArrayDict, metrics: MetricNames) -> FloatDict:
|
||||||
|
"""Gets the vulnerabilities according to the chosen metrics for all attacks."""
|
||||||
|
vulns = {}
|
||||||
|
if isinstance(metrics, str):
|
||||||
|
metrics = [metrics]
|
||||||
|
for k in result:
|
||||||
|
for metric in metrics:
|
||||||
|
if k.endswith(metric.lower()) or k.endswith('n_examples'):
|
||||||
|
vulns[k] = float(result[k])
|
||||||
|
return vulns
|
||||||
|
|
||||||
|
|
||||||
|
def _get_maximum_vulnerability(
|
||||||
|
attack_result: FloatDict,
|
||||||
|
metrics: MetricNames,
|
||||||
|
filterby: Text = '') -> Dict[Text, Dict[Text, Union[Text, float]]]:
|
||||||
|
"""Returns the worst vulnerability according to the chosen metrics of all attacks."""
|
||||||
|
vulns = {}
|
||||||
|
if isinstance(metrics, str):
|
||||||
|
metrics = [metrics]
|
||||||
|
for metric in metrics:
|
||||||
|
best_attack_value = -np.inf
|
||||||
|
for k in attack_result:
|
||||||
|
if (k.startswith(filterby.lower()) and k.endswith(metric.lower()) and
|
||||||
|
'train' not in k):
|
||||||
|
if float(attack_result[k]) > best_attack_value:
|
||||||
|
best_attack_value = attack_result[k]
|
||||||
|
best_attacker = k
|
||||||
|
if best_attack_value > -np.inf:
|
||||||
|
newkey = filterby + '-' + metric if filterby else metric
|
||||||
|
vulns[newkey] = {'value': best_attack_value, 'attacker': best_attacker}
|
||||||
|
return vulns
|
||||||
|
|
||||||
|
|
||||||
|
def _get_maximum_class_gap_or_none(result: FloatDict,
|
||||||
|
metrics: MetricNames) -> FloatDict:
|
||||||
|
"""Returns the biggest and smallest vulnerability and the gap across classes."""
|
||||||
|
gaps = {}
|
||||||
|
if isinstance(metrics, str):
|
||||||
|
metrics = [metrics]
|
||||||
|
for metric in metrics:
|
||||||
|
hi = -np.inf
|
||||||
|
lo = np.inf
|
||||||
|
hi_idx, lo_idx = -1, -1
|
||||||
|
for k in result:
|
||||||
|
if (k.startswith('class') and k.endswith(metric.lower()) and
|
||||||
|
'train' not in k):
|
||||||
|
if float(result[k]) > hi:
|
||||||
|
hi = float(result[k])
|
||||||
|
hi_idx = int(re.findall(r'class_(\d+)_', k)[0])
|
||||||
|
if float(result[k]) < lo:
|
||||||
|
lo = float(result[k])
|
||||||
|
lo_idx = int(re.findall(r'class_(\d+)_', k)[0])
|
||||||
|
if lo - hi < np.inf:
|
||||||
|
gaps['max_class_gap_' + metric] = hi - lo
|
||||||
|
gaps[f'class_{hi_idx}_' + metric] = hi
|
||||||
|
gaps[f'class_{lo_idx}_' + metric] = lo
|
||||||
|
gaps['max_vuln_class_' + metric] = hi_idx
|
||||||
|
gaps['min_vuln_class_' + metric] = lo_idx
|
||||||
|
return gaps
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Attacks
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _run_threshold_loss_attack(features: ArrayDict,
|
||||||
|
figure_file_prefix: Text = '',
|
||||||
|
figure_directory: Text = None) -> ArrayDict:
|
||||||
|
"""Runs the threshold attack on the loss."""
|
||||||
|
logging.info('Run threshold attack on loss...')
|
||||||
|
is_train = features['is_train']
|
||||||
|
attack_prefix = 'thresh_loss'
|
||||||
|
tmp_results = utils.compute_performance_metrics(is_train, -features['loss'])
|
||||||
|
if figure_directory is not None:
|
||||||
|
figpath = os.path.join(figure_directory,
|
||||||
|
figure_file_prefix + attack_prefix + '.png')
|
||||||
|
plotting.save_plot(
|
||||||
|
plotting.plot_curve_with_area(
|
||||||
|
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
|
||||||
|
figpath)
|
||||||
|
figpath = os.path.join(figure_directory,
|
||||||
|
figure_file_prefix + attack_prefix + '_hist.png')
|
||||||
|
plotting.save_plot(
|
||||||
|
plotting.plot_histograms(
|
||||||
|
features['loss'][is_train == 1],
|
||||||
|
features['loss'][is_train == 0],
|
||||||
|
xlabel='loss'), figpath)
|
||||||
|
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
|
||||||
|
|
||||||
|
|
||||||
|
def _run_threshold_attack_maxlogit(features: ArrayDict,
|
||||||
|
figure_file_prefix: Text = '',
|
||||||
|
figure_directory: Text = None) -> ArrayDict:
|
||||||
|
"""Runs the threshold attack on the maximum logit."""
|
||||||
|
is_train = features['is_train']
|
||||||
|
preds = np.max(features['logits'], axis=-1)
|
||||||
|
tmp_results = utils.compute_performance_metrics(is_train, preds)
|
||||||
|
attack_prefix = 'thresh_maxlogit'
|
||||||
|
if figure_directory is not None:
|
||||||
|
figpath = os.path.join(figure_directory,
|
||||||
|
figure_file_prefix + attack_prefix + '.png')
|
||||||
|
plotting.save_plot(
|
||||||
|
plotting.plot_curve_with_area(
|
||||||
|
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
|
||||||
|
figpath)
|
||||||
|
figpath = os.path.join(figure_directory,
|
||||||
|
figure_file_prefix + attack_prefix + '_hist.png')
|
||||||
|
plotting.save_plot(
|
||||||
|
plotting.plot_histograms(
|
||||||
|
preds[is_train == 1], preds[is_train == 0], xlabel='loss'), figpath)
|
||||||
|
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
|
||||||
|
|
||||||
|
|
||||||
|
def _run_trained_attack(attack_classifier: Text,
|
||||||
|
data: Dataset,
|
||||||
|
attack_prefix: Text,
|
||||||
|
figure_file_prefix: Text = '',
|
||||||
|
figure_directory: Text = None) -> ArrayDict:
|
||||||
|
"""Train a classifier for attack and evaluate it."""
|
||||||
|
# Train the attack classifier
|
||||||
|
(x_train, y_train), (x_test, y_test) = data
|
||||||
|
clf_model = trained_attack_models.choose_model(attack_classifier)
|
||||||
|
clf_model.fit(x_train, y_train)
|
||||||
|
|
||||||
|
# Calculate training set metrics
|
||||||
|
pred_train = clf_model.predict_proba(x_train)[:, clf_model.classes_ == 1]
|
||||||
|
results = utils.prepend_to_keys(
|
||||||
|
utils.compute_performance_metrics(y_train, pred_train),
|
||||||
|
attack_prefix + 'train_')
|
||||||
|
|
||||||
|
# Calculate test set metrics
|
||||||
|
pred_test = clf_model.predict_proba(x_test)[:, clf_model.classes_ == 1]
|
||||||
|
results.update(
|
||||||
|
utils.prepend_to_keys(
|
||||||
|
utils.compute_performance_metrics(y_test, pred_test),
|
||||||
|
attack_prefix + 'test_'))
|
||||||
|
|
||||||
|
if figure_directory is not None:
|
||||||
|
figpath = os.path.join(figure_directory,
|
||||||
|
figure_file_prefix + attack_prefix[:-1] + '.png')
|
||||||
|
plotting.save_plot(
|
||||||
|
plotting.plot_curve_with_area(
|
||||||
|
results[attack_prefix + 'test_fpr'],
|
||||||
|
results[attack_prefix + 'test_tpr'],
|
||||||
|
xlabel='FPR',
|
||||||
|
ylabel='TPR'), figpath)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _run_attacks_and_plot(features: ArrayDict,
|
||||||
|
attacks: Iterable[Text],
|
||||||
|
attack_classifiers: Iterable[Text],
|
||||||
|
balance: bool,
|
||||||
|
test_size: float,
|
||||||
|
random_state: int,
|
||||||
|
figure_file_prefix: Text = '',
|
||||||
|
figure_directory: Text = None) -> ArrayDict:
|
||||||
|
"""Runs the specified attacks on the provided data."""
|
||||||
|
if balance:
|
||||||
|
try:
|
||||||
|
features = utils.subsample_to_balance(features, random_state)
|
||||||
|
except RuntimeError:
|
||||||
|
logging.info('Not enough remaining data for attack: Empty results.')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
# -------------------- Simple threshold attacks
|
||||||
|
if 'thresh_loss' in attacks:
|
||||||
|
result.update(
|
||||||
|
_run_threshold_loss_attack(features, figure_file_prefix,
|
||||||
|
figure_directory))
|
||||||
|
|
||||||
|
if 'thresh_maxlogit' in attacks:
|
||||||
|
result.update(
|
||||||
|
_run_threshold_attack_maxlogit(features, figure_file_prefix,
|
||||||
|
figure_directory))
|
||||||
|
|
||||||
|
# -------------------- Run learned attacks
|
||||||
|
# TODO(b/157632603): Add a prefix (for example 'trained_') for attacks which
|
||||||
|
# use classifiers to distinguish from threshould attacks.
|
||||||
|
if 'logits' in attacks:
|
||||||
|
data = utils.get_train_test_split(
|
||||||
|
features, add_loss=False, test_size=test_size)
|
||||||
|
for clf in attack_classifiers:
|
||||||
|
logging.info('Train %s on %d logits', clf, data[0][0].shape[1])
|
||||||
|
attack_prefix = f'{clf}_logits_'
|
||||||
|
result.update(
|
||||||
|
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
|
||||||
|
figure_directory))
|
||||||
|
|
||||||
|
if 'logits_loss' in attacks:
|
||||||
|
data = utils.get_train_test_split(
|
||||||
|
features, add_loss=True, test_size=test_size)
|
||||||
|
for clf in attack_classifiers:
|
||||||
|
logging.info('Train %s on %d logits + loss', clf, data[0][0].shape[1])
|
||||||
|
attack_prefix = f'{clf}_logits_loss_'
|
||||||
|
result.update(
|
||||||
|
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
|
||||||
|
figure_directory))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_attack(loss_train: np.ndarray = None,
|
||||||
|
loss_test: np.ndarray = None,
|
||||||
|
logits_train: np.ndarray = None,
|
||||||
|
logits_test: np.ndarray = None,
|
||||||
|
labels_train: np.ndarray = None,
|
||||||
|
labels_test: np.ndarray = None,
|
||||||
|
attack_classifiers: Iterable[Text] = None,
|
||||||
|
only_misclassified: bool = False,
|
||||||
|
by_class: Union[bool, Iterable[int], int] = False,
|
||||||
|
by_percentile: Union[bool, Iterable[int], int] = False,
|
||||||
|
figure_directory: Text = None,
|
||||||
|
output_directory: Text = None,
|
||||||
|
metric: MetricNames = 'auc',
|
||||||
|
balance: bool = True,
|
||||||
|
test_size: float = 0.2,
|
||||||
|
random_state: int = 0) -> FloatDict:
|
||||||
|
"""Run membership inference attack(s).
|
||||||
|
|
||||||
|
Based only on specific outputs of a machine learning model on some examples
|
||||||
|
used for training (train) and some examples not used for training (test), run
|
||||||
|
membership inference attacks that try to discriminate training from test
|
||||||
|
inputs based only on the model outputs.
|
||||||
|
While all inputs are optional, at least one train/test pair is required to run
|
||||||
|
any attacks (either losses or logits/probabilities).
|
||||||
|
Note that one can equally provide output probabilities instead of logits in
|
||||||
|
the logits_train / logits_test arguments.
|
||||||
|
|
||||||
|
We measure the vulnerability of the model via the area under the ROC-curve
|
||||||
|
(auc) or via max |fpr - tpr| (advantage) of the attack classifier. These
|
||||||
|
measures are very closely related and may look almost indistinguishable.
|
||||||
|
|
||||||
|
This function provides relatively fine grained control and outputs detailed
|
||||||
|
results. For a higher-level wrapper with sane internal default settings and
|
||||||
|
distilled output results, see `run_all_attacks`.
|
||||||
|
|
||||||
|
Via the `figure_directory` argument and the `output_directory` argument more
|
||||||
|
detailed information as well as roc-curve plots can optionally be stored to
|
||||||
|
disk.
|
||||||
|
|
||||||
|
If `loss_train` and `loss_test` are provided we run:
|
||||||
|
- simple threshold attack on the loss
|
||||||
|
|
||||||
|
If `logits_train` and `logits_test` are provided we run:
|
||||||
|
- simple threshold attack on the top logit
|
||||||
|
- if `attack_classifiers` is not None and no losses are provided: train the
|
||||||
|
specified classifiers on the top 10 logits (or all logits if there are
|
||||||
|
less than 10)
|
||||||
|
- if `attack_classifiers` is not None and losses are provided: train the
|
||||||
|
specified classifiers on the top 10 logits (or all logits if there are
|
||||||
|
less than 10) and the loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_train: A 1D array containing the individual scalar losses for examples
|
||||||
|
used during training.
|
||||||
|
loss_test: A 1D array containing the individual scalar losses for examples
|
||||||
|
not used during training.
|
||||||
|
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples used during training.
|
||||||
|
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples not used during training.
|
||||||
|
labels_train: The true labels of the training examples. Labels are only
|
||||||
|
needed when `by_class` is specified (i.e., not False).
|
||||||
|
labels_test: The true labels of the test examples. Labels are only needed
|
||||||
|
when `by_class` is specified (i.e., not False).
|
||||||
|
attack_classifiers: Attack classifiers to train beyond simple thresholding
|
||||||
|
that require training a simple binary ML classifier. This argument is
|
||||||
|
ignored if logits are not provided. Classifiers can be 'lr' for logistic
|
||||||
|
regression, 'mlp' for multi-layered perceptron, 'rf' for random forests,
|
||||||
|
or 'knn' for k-nearest-neighbors. If 'None', don't train classifiers
|
||||||
|
beyond simple thresholding.
|
||||||
|
only_misclassified: Run and evaluate attacks only on misclassified examples.
|
||||||
|
Must specify `labels_train`, `labels_test`, `logits_train` and
|
||||||
|
`logits_test` to use this. If this is True, `by_class` and `by_percentile`
|
||||||
|
are ignored.
|
||||||
|
by_class: This argument determines whether attacks are run on the entire
|
||||||
|
data, or on examples grouped by their class label. If `True`, all attacks
|
||||||
|
are run separately for each class. If `by_class` is a single integer, run
|
||||||
|
attacks for this class only. If `by_class` is an iterable of integers, run
|
||||||
|
all attacks for each of the specified class labels separately. Only used
|
||||||
|
if `labels_train` and `labels_test` are specified. If `by_class` is
|
||||||
|
specified (not False), `by_percentile` is ignored. Ignored if
|
||||||
|
`only_misclassified` is True.
|
||||||
|
by_percentile: This argument determines whether attacks are run on the
|
||||||
|
entire data, or separately for examples where the most likely class
|
||||||
|
prediction is within a given percentile of all maximum predicitons. If
|
||||||
|
`True`, all attacks are run separately for the examples with max
|
||||||
|
probabilities within the ten deciles. If `by_precentile` is a single int
|
||||||
|
between 0 and 100, run attacks only for examples with confidence within
|
||||||
|
this percentile. If `by_percentile` is an iterable of ints between 0 and
|
||||||
|
100, run all attacks for each of the specified percentiles separately.
|
||||||
|
Ignored if `by_class` is specified. Ignored if `logits_train` and
|
||||||
|
`logits_test` are not specified. Ignored if `only_misclassified` is True.
|
||||||
|
figure_directory: Where to store ROC-curve plots and histograms. If `None`,
|
||||||
|
don't create plots.
|
||||||
|
output_directory: Where to store detailed result data for all run attacks.
|
||||||
|
If `None`, don't store detailed result data.
|
||||||
|
metric: Available vulnerability metrics are 'auc' or 'advantage' for the
|
||||||
|
area under the ROC curve or the advantage (max |tpr - fpr|). Specify
|
||||||
|
either one of them or both.
|
||||||
|
balance: Whether to use the same number of train and test samples (by
|
||||||
|
randomly subsampling whichever happens to be larger).
|
||||||
|
test_size: The fraction of the input data to use for the evaluation of
|
||||||
|
trained ML attacks. This argument is ignored, if either attack_classifiers
|
||||||
|
is None, or no logits are provided.
|
||||||
|
random_state: Random seed for reproducibility. Only used if attack models
|
||||||
|
are trained.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
results: Dictionary with the chosen vulnerability metric(s) for all ran
|
||||||
|
attacks.
|
||||||
|
"""
|
||||||
|
attacks = []
|
||||||
|
features = {}
|
||||||
|
# ---------- Check available data ----------
|
||||||
|
if ((loss_train is None or loss_test is None) and
|
||||||
|
(logits_train is None or logits_test is None)):
|
||||||
|
raise ValueError(
|
||||||
|
'Need at least train and test for loss or train and test for logits.')
|
||||||
|
|
||||||
|
# ---------- If losses are provided ----------
|
||||||
|
if loss_train is not None and loss_test is not None:
|
||||||
|
if loss_train.ndim != 1 or loss_test.ndim != 1:
|
||||||
|
raise ValueError('Losses must be 1D arrays.')
|
||||||
|
features['is_train'] = np.concatenate(
|
||||||
|
(np.ones(len(loss_train)), np.zeros(len(loss_test))),
|
||||||
|
axis=0).astype(int)
|
||||||
|
features['loss'] = np.concatenate((loss_train.ravel(), loss_test.ravel()),
|
||||||
|
axis=0)
|
||||||
|
attacks.append('thresh_loss')
|
||||||
|
|
||||||
|
# ---------- If logits are provided ----------
|
||||||
|
if logits_train is not None and logits_test is not None:
|
||||||
|
assert logits_train.ndim == 2 and logits_test.ndim == 2, \
|
||||||
|
'Logits must be 2D arrays.'
|
||||||
|
assert logits_train.shape[1] == logits_test.shape[1], \
|
||||||
|
'Train and test logits must agree along axis 1 (number of classes).'
|
||||||
|
if 'is_train' in features:
|
||||||
|
assert (loss_train.shape[0] == logits_train.shape[0] and
|
||||||
|
loss_test.shape[0] == logits_test.shape[0]), \
|
||||||
|
'Number of examples must match between loss and logits.'
|
||||||
|
else:
|
||||||
|
features['is_train'] = np.concatenate(
|
||||||
|
(np.ones(logits_train.shape[0]), np.zeros(logits_test.shape[0])),
|
||||||
|
axis=0).astype(int)
|
||||||
|
attacks.append('thresh_maxlogit')
|
||||||
|
features['logits'] = np.concatenate((logits_train, logits_test), axis=0)
|
||||||
|
if attack_classifiers:
|
||||||
|
attacks.append('logits')
|
||||||
|
if 'loss' in features:
|
||||||
|
attacks.append('logits_loss')
|
||||||
|
|
||||||
|
# ---------- If labels are provided ----------
|
||||||
|
if labels_train is not None and labels_test is not None:
|
||||||
|
if labels_train.ndim != 1 or labels_test.ndim != 1:
|
||||||
|
raise ValueError('Losses must be 1D arrays.')
|
||||||
|
if 'loss' in features:
|
||||||
|
assert (loss_train.shape[0] == labels_train.shape[0] and
|
||||||
|
loss_test.shape[0] == labels_test.shape[0]), \
|
||||||
|
'Number of examples must match between loss and labels.'
|
||||||
|
else:
|
||||||
|
assert (logits_train.shape[0] == labels_train.shape[0] and
|
||||||
|
logits_test.shape[0] == labels_test.shape[0]), \
|
||||||
|
'Number of examples must match between logits and labels.'
|
||||||
|
features['label'] = np.concatenate((labels_train, labels_test), axis=0)
|
||||||
|
|
||||||
|
# ---------- Data subsampling or filtering ----------
|
||||||
|
filtertype = None
|
||||||
|
filtervals = [None]
|
||||||
|
if only_misclassified:
|
||||||
|
if (labels_train is None or labels_test is None or logits_train is None or
|
||||||
|
logits_test is None):
|
||||||
|
raise ValueError('Must specify labels_train, labels_test, logits_train, '
|
||||||
|
'and logits_test for the only_misclassified option.')
|
||||||
|
filtertype = 'misclassified'
|
||||||
|
elif by_class:
|
||||||
|
if labels_train is None or labels_test is None:
|
||||||
|
raise ValueError('Must specify labels_train and labels_test when using '
|
||||||
|
'the by_class option.')
|
||||||
|
if isinstance(by_class, bool):
|
||||||
|
filtervals = list(set(labels_train) | set(labels_test))
|
||||||
|
elif isinstance(by_class, int):
|
||||||
|
filtervals = [by_class]
|
||||||
|
elif isinstance(by_class, collections.Iterable):
|
||||||
|
filtervals = list(by_class)
|
||||||
|
filtertype = 'class'
|
||||||
|
elif by_percentile:
|
||||||
|
if logits_train is None or logits_test is None:
|
||||||
|
raise ValueError('Must specify logits_train and logits_test when using '
|
||||||
|
'the by_percentile option.')
|
||||||
|
if isinstance(by_percentile, bool):
|
||||||
|
filtervals = list(range(10, 101, 10))
|
||||||
|
elif isinstance(by_percentile, int):
|
||||||
|
filtervals = [by_percentile]
|
||||||
|
elif isinstance(by_percentile, collections.Iterable):
|
||||||
|
filtervals = [int(percentile) for percentile in by_percentile]
|
||||||
|
filtertype = 'percentile'
|
||||||
|
|
||||||
|
# ---------- Need to create figure directory? ----------
|
||||||
|
if figure_directory is not None:
|
||||||
|
mkdir(figure_directory)
|
||||||
|
|
||||||
|
# ---------- Actually run attacks and plot if required ----------
|
||||||
|
logging.info('Selecting %s with values %s', filtertype, filtervals)
|
||||||
|
num = None
|
||||||
|
result = {}
|
||||||
|
for filterval in filtervals:
|
||||||
|
if filtertype is None:
|
||||||
|
tmp_features = features
|
||||||
|
elif filtertype == 'misclassified':
|
||||||
|
idx = features['label'] != np.argmax(features['logits'], axis=-1)
|
||||||
|
tmp_features = utils.select_indices(features, idx)
|
||||||
|
num = np.sum(idx)
|
||||||
|
elif filtertype == 'class':
|
||||||
|
idx = features['label'] == filterval
|
||||||
|
tmp_features = utils.select_indices(features, idx)
|
||||||
|
num = np.sum(idx)
|
||||||
|
elif filtertype == 'percentile':
|
||||||
|
certainty = np.max(special.softmax(features['logits'], axis=-1), axis=-1)
|
||||||
|
idx = certainty <= np.percentile(certainty, filterval)
|
||||||
|
tmp_features = utils.select_indices(features, idx)
|
||||||
|
|
||||||
|
prefix = f'{filtertype}_' if filtertype is not None else ''
|
||||||
|
prefix += f'{filterval}_' if filterval is not None else ''
|
||||||
|
tmp_result = _run_attacks_and_plot(tmp_features, attacks,
|
||||||
|
attack_classifiers, balance, test_size,
|
||||||
|
random_state, prefix, figure_directory)
|
||||||
|
if num is not None:
|
||||||
|
tmp_result['n_examples'] = float(num)
|
||||||
|
if tmp_result:
|
||||||
|
result.update(utils.prepend_to_keys(tmp_result, prefix))
|
||||||
|
|
||||||
|
# ---------- Store data ----------
|
||||||
|
if output_directory is not None:
|
||||||
|
mkdir(output_directory)
|
||||||
|
resultpath = os.path.join(output_directory, 'attack_results.npz')
|
||||||
|
logging.info('Store aggregate results at %s.', resultpath)
|
||||||
|
with open(resultpath, 'wb') as fp:
|
||||||
|
io_buffer = io.BytesIO()
|
||||||
|
np.savez(io_buffer, **result)
|
||||||
|
fp.write(io_buffer.getvalue())
|
||||||
|
|
||||||
|
return _get_vulnerabilities(result, metric)
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_attacks(loss_train: np.ndarray = None,
|
||||||
|
loss_test: np.ndarray = None,
|
||||||
|
logits_train: np.ndarray = None,
|
||||||
|
logits_test: np.ndarray = None,
|
||||||
|
labels_train: np.ndarray = None,
|
||||||
|
labels_test: np.ndarray = None,
|
||||||
|
attack_classifiers: Iterable[Text] = ('lr', 'mlp', 'rf',
|
||||||
|
'knn'),
|
||||||
|
decimals: Union[int, None] = 4) -> FloatDict:
|
||||||
|
"""Runs all possible membership inference attacks.
|
||||||
|
|
||||||
|
Check 'run_attack' for detailed information of how attacks are performed
|
||||||
|
and evaluated.
|
||||||
|
|
||||||
|
This function internally chooses sane default settings for all attacks and
|
||||||
|
returns all possible output combinations.
|
||||||
|
For fine grained control and partial attacks, please see `run_attack`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_train: A 1D array containing the individual scalar losses for examples
|
||||||
|
used during training.
|
||||||
|
loss_test: A 1D array containing the individual scalar losses for examples
|
||||||
|
not used during training.
|
||||||
|
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples used during training.
|
||||||
|
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples not used during training.
|
||||||
|
labels_train: The true labels of the training examples. Labels are only
|
||||||
|
needed when `by_class` is specified (i.e., not False).
|
||||||
|
labels_test: The true labels of the test examples. Labels are only needed
|
||||||
|
when `by_class` is specified (i.e., not False).
|
||||||
|
attack_classifiers: Which binary classifiers to train (in addition to simple
|
||||||
|
threshold attacks). This can include 'lr' (logistic regression), 'mlp'
|
||||||
|
(multi-layered perceptron), 'rf' (random forests), 'knn' (k-nearest
|
||||||
|
neighbors), which will be trained with cross validation to determine good
|
||||||
|
hyperparameters.
|
||||||
|
decimals: Round all float results to this number of decimals. If decimals is
|
||||||
|
None, don't round.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result: dictionary with all attack results
|
||||||
|
"""
|
||||||
|
metrics = ['auc', 'advantage']
|
||||||
|
|
||||||
|
# Entire data
|
||||||
|
result = run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
attack_classifiers=attack_classifiers,
|
||||||
|
metric=metrics)
|
||||||
|
result = utils.prepend_to_keys(result, 'all_')
|
||||||
|
|
||||||
|
# Misclassified examples
|
||||||
|
if (labels_train is not None and labels_test is not None and
|
||||||
|
logits_train is not None and logits_test is not None):
|
||||||
|
result.update(
|
||||||
|
run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
labels_train,
|
||||||
|
labels_test,
|
||||||
|
attack_classifiers=attack_classifiers,
|
||||||
|
only_misclassified=True,
|
||||||
|
metric=metrics))
|
||||||
|
|
||||||
|
# Split per class
|
||||||
|
if labels_train is not None and labels_test is not None:
|
||||||
|
result.update(
|
||||||
|
run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
labels_train,
|
||||||
|
labels_test,
|
||||||
|
by_class=True,
|
||||||
|
attack_classifiers=attack_classifiers,
|
||||||
|
metric=metrics))
|
||||||
|
|
||||||
|
# Different deciles
|
||||||
|
if logits_train is not None and logits_test is not None:
|
||||||
|
result.update(
|
||||||
|
run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
by_percentile=True,
|
||||||
|
attack_classifiers=attack_classifiers,
|
||||||
|
metric=metrics))
|
||||||
|
|
||||||
|
if decimals is not None:
|
||||||
|
result = {k: round(v, decimals) for k, v in result.items()}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_attacks_and_create_summary(
|
||||||
|
loss_train: np.ndarray = None,
|
||||||
|
loss_test: np.ndarray = None,
|
||||||
|
logits_train: np.ndarray = None,
|
||||||
|
logits_test: np.ndarray = None,
|
||||||
|
labels_train: np.ndarray = None,
|
||||||
|
labels_test: np.ndarray = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
decimals: Union[int, None] = 4) -> Union[Text, Tuple[Text, AnyDict]]:
|
||||||
|
"""Runs all possible membership inference attack(s) and distill results.
|
||||||
|
|
||||||
|
Check 'run_attack' for detailed information of how attacks are performed
|
||||||
|
and evaluated.
|
||||||
|
|
||||||
|
This function internally chooses sane default settings for all attacks and
|
||||||
|
returns all possible output combinations.
|
||||||
|
For fine grained control and partial attacks, please see `run_attack`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_train: A 1D array containing the individual scalar losses for examples
|
||||||
|
used during training.
|
||||||
|
loss_test: A 1D array containing the individual scalar losses for examples
|
||||||
|
not used during training.
|
||||||
|
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples used during training.
|
||||||
|
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
||||||
|
output probabilities of examples not used during training.
|
||||||
|
labels_train: The true labels of the training examples. Labels are only
|
||||||
|
needed when `by_class` is specified (i.e., not False).
|
||||||
|
labels_test: The true labels of the test examples. Labels are only needed
|
||||||
|
when `by_class` is specified (i.e., not False).
|
||||||
|
return_dict: Whether to also return a dictionary with the results summarized
|
||||||
|
in the summary string.
|
||||||
|
decimals: Round all float results to this number of decimals. If decimals is
|
||||||
|
None, don't round.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
summarystring: A string with natural language summary of the attacks. In the
|
||||||
|
summary string printed numbers will be rounded to `decimal` decimals if
|
||||||
|
provided, otherwise will round to 3 diits by default for readability.
|
||||||
|
result: a dictionary with all the distilled attack information summarized
|
||||||
|
in the summarystring
|
||||||
|
"""
|
||||||
|
summary = []
|
||||||
|
metrics = ['auc', 'advantage']
|
||||||
|
attack_classifiers = ['lr', 'rf', 'mlp', 'knn']
|
||||||
|
results = run_all_attacks(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
labels_train,
|
||||||
|
labels_test,
|
||||||
|
attack_classifiers=attack_classifiers,
|
||||||
|
decimals=None)
|
||||||
|
output = _get_maximum_vulnerability(results, metrics, filterby='all')
|
||||||
|
|
||||||
|
if decimals is not None:
|
||||||
|
strdec = decimals
|
||||||
|
else:
|
||||||
|
strdec = 4
|
||||||
|
|
||||||
|
for metric in metrics:
|
||||||
|
summary.append(f'========== {metric.upper()} ==========')
|
||||||
|
best_value = output['all-' + metric]['value']
|
||||||
|
best_attacker = output['all-' + metric]['attacker']
|
||||||
|
summary.append(f'The best attack ({best_attacker}) achieved an {metric} of '
|
||||||
|
f'{best_value:.{strdec}f}.')
|
||||||
|
summary.append('')
|
||||||
|
|
||||||
|
classgap = _get_maximum_class_gap_or_none(results, metrics)
|
||||||
|
if classgap:
|
||||||
|
output.update(classgap)
|
||||||
|
for metric in metrics:
|
||||||
|
summary.append(f'========== {metric.upper()} per class ==========')
|
||||||
|
hi_idx = output[f'max_vuln_class_{metric}']
|
||||||
|
lo_idx = output[f'min_vuln_class_{metric}']
|
||||||
|
hi = output[f'class_{hi_idx}_{metric}']
|
||||||
|
lo = output[f'class_{lo_idx}_{metric}']
|
||||||
|
gap = output[f'max_class_gap_{metric}']
|
||||||
|
summary.append(f'The most vulnerable class {hi_idx} has {metric} of '
|
||||||
|
f'{hi:.{strdec}f}.')
|
||||||
|
summary.append(f'The least vulnerable class {lo_idx} has {metric} of '
|
||||||
|
f'{lo:.{strdec}f}.')
|
||||||
|
summary.append(f'=> The maximum gap between class vulnerabilities is '
|
||||||
|
f'{gap:.{strdec}f}.')
|
||||||
|
summary.append('')
|
||||||
|
|
||||||
|
misclassified = _get_maximum_vulnerability(
|
||||||
|
results, metrics, filterby='misclassified')
|
||||||
|
if misclassified:
|
||||||
|
for metric in metrics:
|
||||||
|
best_attacker = misclassified['misclassified-' + metric]['attacker']
|
||||||
|
summary.append(f'========== {metric.upper()} for misclassified '
|
||||||
|
'==========')
|
||||||
|
summary.append('Among misclassified examples, the best attack '
|
||||||
|
f'({best_attacker}) achieved an {metric} of '
|
||||||
|
f'{best_value:.{strdec}f}.')
|
||||||
|
summary.append('')
|
||||||
|
output.update(misclassified)
|
||||||
|
|
||||||
|
n_examples = {k: v for k, v in results.items() if k.endswith('n_examples')}
|
||||||
|
if n_examples:
|
||||||
|
output.update(n_examples)
|
||||||
|
|
||||||
|
# Flatten remaining dicts in output
|
||||||
|
fresh_output = {}
|
||||||
|
for k, v in output.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
if k.startswith('all'):
|
||||||
|
fresh_output[k[4:]] = v['value']
|
||||||
|
fresh_output['best_attacker_' + k[4:]] = v['attacker']
|
||||||
|
else:
|
||||||
|
fresh_output[k] = v
|
||||||
|
output = fresh_output
|
||||||
|
|
||||||
|
if decimals is not None:
|
||||||
|
for k, v in output.items():
|
||||||
|
if isinstance(v, float):
|
||||||
|
output[k] = round(v, decimals)
|
||||||
|
|
||||||
|
summary = '\n'.join(summary)
|
||||||
|
if return_dict:
|
||||||
|
return summary, output
|
||||||
|
else:
|
||||||
|
return summary
|
|
@ -0,0 +1,307 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
|
||||||
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
|
||||||
|
|
||||||
|
def get_result_dict():
|
||||||
|
"""Get an example result dictionary."""
|
||||||
|
return {
|
||||||
|
'test_n_examples': np.ones(1),
|
||||||
|
'test_examples': np.zeros(1),
|
||||||
|
'test_auc': np.ones(1),
|
||||||
|
'test_advantage': np.ones(1),
|
||||||
|
'all_0-metric': np.array([1]),
|
||||||
|
'all_1-metric': np.array([2]),
|
||||||
|
'test_2-metric': np.array([3]),
|
||||||
|
'test_score': np.array([4]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_inputs():
|
||||||
|
"""Get example inputs for attacks."""
|
||||||
|
n_train = n_test = 500
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
loss_train = rng.randn(n_train) - 0.4
|
||||||
|
loss_test = rng.randn(n_test) + 0.4
|
||||||
|
logits_train = rng.randn(n_train, 5) + 0.2
|
||||||
|
logits_test = rng.randn(n_test, 5) - 0.2
|
||||||
|
labels_train = np.array([i % 5 for i in range(n_train)])
|
||||||
|
labels_test = np.array([(3 * i) % 5 for i in range(n_test)])
|
||||||
|
return (loss_train, loss_test, logits_train, logits_test,
|
||||||
|
labels_train, labels_test)
|
||||||
|
|
||||||
|
|
||||||
|
class GetVulnerabilityTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_get_vulnerabilities(self):
|
||||||
|
"""Test extraction of vulnerability scores."""
|
||||||
|
testdict = get_result_dict()
|
||||||
|
for key in ['auc', 'advantage']:
|
||||||
|
res = mia._get_vulnerabilities(testdict, key)
|
||||||
|
self.assertLen(res, 2)
|
||||||
|
self.assertEqual(res[f'test_{key}'], 1)
|
||||||
|
self.assertEqual(res['test_n_examples'], 1)
|
||||||
|
|
||||||
|
res = mia._get_vulnerabilities(testdict, ['auc', 'advantage'])
|
||||||
|
self.assertLen(res, 3)
|
||||||
|
self.assertEqual(res['test_auc'], 1)
|
||||||
|
self.assertEqual(res['test_advantage'], 1)
|
||||||
|
self.assertEqual(res['test_n_examples'], 1)
|
||||||
|
|
||||||
|
|
||||||
|
class GetMaximumVulnerabilityTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_get_maximum_vulnerability(self):
|
||||||
|
"""Test extraction of maximum vulnerability score."""
|
||||||
|
testdict = get_result_dict()
|
||||||
|
for i in range(3):
|
||||||
|
key = f'{i}-metric'
|
||||||
|
res = mia._get_maximum_vulnerability(testdict, key)
|
||||||
|
self.assertLen(res, 1)
|
||||||
|
self.assertEqual(res[key]['value'], i + 1)
|
||||||
|
if i < 2:
|
||||||
|
self.assertEqual(res[key]['attacker'], f'all_{i}-metric')
|
||||||
|
else:
|
||||||
|
self.assertEqual(res[key]['attacker'], 'test_2-metric')
|
||||||
|
|
||||||
|
res = mia._get_maximum_vulnerability(testdict, 'metric')
|
||||||
|
self.assertLen(res, 1)
|
||||||
|
self.assertEqual(res['metric']['value'], 3)
|
||||||
|
|
||||||
|
res = mia._get_maximum_vulnerability(testdict, ['metric'],
|
||||||
|
filterby='all')
|
||||||
|
self.assertLen(res, 1)
|
||||||
|
self.assertEqual(res['all-metric']['value'], 2)
|
||||||
|
|
||||||
|
res = mia._get_maximum_vulnerability(testdict, ['metric', 'score'])
|
||||||
|
self.assertLen(res, 2)
|
||||||
|
self.assertEqual(res['metric']['value'], 3)
|
||||||
|
self.assertEqual(res['score']['value'], 4)
|
||||||
|
self.assertEqual(res['score']['attacker'], 'test_score')
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdAttackLossTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_threshold_attack_loss(self):
|
||||||
|
"""Test simple threshold attack on loss."""
|
||||||
|
features = {
|
||||||
|
'loss': np.zeros(10),
|
||||||
|
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
|
||||||
|
}
|
||||||
|
res = mia._run_threshold_loss_attack(features)
|
||||||
|
for k in res:
|
||||||
|
self.assertStartsWith(k, 'thresh_loss')
|
||||||
|
self.assertEqual(res['thresh_loss_auc'], 0.5)
|
||||||
|
self.assertEqual(res['thresh_loss_advantage'], 0.0)
|
||||||
|
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
n_train = 1000
|
||||||
|
n_test = 500
|
||||||
|
loss_train = rng.randn(n_train) - 0.4
|
||||||
|
loss_test = rng.randn(n_test) + 0.4
|
||||||
|
features = {
|
||||||
|
'loss': np.concatenate((loss_train, loss_test)),
|
||||||
|
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
|
||||||
|
}
|
||||||
|
res = mia._run_threshold_loss_attack(features)
|
||||||
|
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
|
||||||
|
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdAttackMaxlogitTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_threshold_attack_maxlogits(self):
|
||||||
|
"""Test simple threshold attack on maximum logit."""
|
||||||
|
features = {
|
||||||
|
'logits': np.eye(10, 14),
|
||||||
|
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
|
||||||
|
}
|
||||||
|
res = mia._run_threshold_attack_maxlogit(features)
|
||||||
|
for k in res:
|
||||||
|
self.assertStartsWith(k, 'thresh_maxlogit')
|
||||||
|
self.assertEqual(res['thresh_maxlogit_auc'], 0.5)
|
||||||
|
self.assertEqual(res['thresh_maxlogit_advantage'], 0.0)
|
||||||
|
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
n_train = 1000
|
||||||
|
n_test = 500
|
||||||
|
logits_train = rng.randn(n_train, 12) + 0.2
|
||||||
|
logits_test = rng.randn(n_test, 12) - 0.2
|
||||||
|
features = {
|
||||||
|
'logits': np.concatenate((logits_train, logits_test), axis=0),
|
||||||
|
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
|
||||||
|
}
|
||||||
|
res = mia._run_threshold_attack_maxlogit(features)
|
||||||
|
self.assertBetween(res['thresh_maxlogit_auc'], 0.7, 0.75)
|
||||||
|
self.assertBetween(res['thresh_maxlogit_advantage'], 0.3, 0.35)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainedAttackTrivialTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_trained_attack(self):
|
||||||
|
"""Test trained attacks."""
|
||||||
|
# Trivially easy problem
|
||||||
|
x_train, x_test = np.ones((500, 3)), np.ones((20, 3))
|
||||||
|
x_train[:200] *= -1
|
||||||
|
x_test[:8] *= -1
|
||||||
|
y_train, y_test = np.ones(500).astype(int), np.ones(20).astype(int)
|
||||||
|
y_train[:200] = 0
|
||||||
|
y_test[:8] = 0
|
||||||
|
data = (x_train, y_train), (x_test, y_test)
|
||||||
|
for clf in ['lr', 'rf', 'mlp', 'knn']:
|
||||||
|
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
|
||||||
|
self.assertEqual(res['a-train_auc'], 1)
|
||||||
|
self.assertEqual(res['a-test_auc'], 1)
|
||||||
|
self.assertEqual(res['a-train_advantage'], 1)
|
||||||
|
self.assertEqual(res['a-test_advantage'], 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainedAttackRandomFeaturesTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_trained_attack(self):
|
||||||
|
"""Test trained attacks."""
|
||||||
|
# Random labels and features
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
x_train, x_test = rng.randn(500, 3), rng.randn(500, 3)
|
||||||
|
y_train = rng.binomial(1, 0.5, size=(500,))
|
||||||
|
y_test = rng.binomial(1, 0.5, size=(500,))
|
||||||
|
data = (x_train, y_train), (x_test, y_test)
|
||||||
|
for clf in ['lr', 'rf', 'mlp', 'knn']:
|
||||||
|
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
|
||||||
|
self.assertBetween(res['a-train_auc'], 0.5, 1.)
|
||||||
|
self.assertBetween(res['a-test_auc'], 0.4, 0.6)
|
||||||
|
self.assertBetween(res['a-train_advantage'], 0., 1.0)
|
||||||
|
self.assertBetween(res['a-test_advantage'], 0., 0.2)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLossesTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
"""Test individual attack function."""
|
||||||
|
# losses only, both metrics
|
||||||
|
loss_train, loss_test, _, _, _, _ = get_test_inputs()
|
||||||
|
res = mia.run_attack(loss_train, loss_test, metric=('auc', 'advantage'))
|
||||||
|
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
|
||||||
|
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLossesLogitsTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
"""Test individual attack function."""
|
||||||
|
# losses and logits, two classifiers, single metric
|
||||||
|
loss_train, loss_test, logits_train, logits_test, _, _ = get_test_inputs()
|
||||||
|
res = mia.run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
logits_train,
|
||||||
|
logits_test,
|
||||||
|
attack_classifiers=('rf', 'knn'),
|
||||||
|
metric='auc')
|
||||||
|
self.assertBetween(res['rf_logits_test_auc'], 0.7, 0.9)
|
||||||
|
self.assertBetween(res['knn_logits_test_auc'], 0.7, 0.9)
|
||||||
|
self.assertBetween(res['rf_logits_loss_test_auc'], 0.7, 0.9)
|
||||||
|
self.assertBetween(res['knn_logits_loss_test_auc'], 0.7, 0.9)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLossesLabelsByClassTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
# losses and labels, single metric, split by class
|
||||||
|
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
|
||||||
|
n_train = loss_train.shape[0]
|
||||||
|
n_test = loss_test.shape[0]
|
||||||
|
res = mia.run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
labels_train=labels_train,
|
||||||
|
labels_test=labels_test,
|
||||||
|
by_class=True,
|
||||||
|
metric='auc')
|
||||||
|
self.assertLen(res, 10)
|
||||||
|
for k in res:
|
||||||
|
self.assertStartsWith(k, 'class_')
|
||||||
|
if k.endswith('n_examples'):
|
||||||
|
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
|
||||||
|
else:
|
||||||
|
self.assertBetween(res[k], 0.65, 0.75)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLossesLabelsSingleClassTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
# losses and labels, both metrics, single class
|
||||||
|
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
|
||||||
|
n_train = loss_train.shape[0]
|
||||||
|
n_test = loss_test.shape[0]
|
||||||
|
res = mia.run_attack(
|
||||||
|
loss_train,
|
||||||
|
loss_test,
|
||||||
|
labels_train=labels_train,
|
||||||
|
labels_test=labels_test,
|
||||||
|
by_class=2,
|
||||||
|
metric=('auc', 'advantage'))
|
||||||
|
self.assertLen(res, 3)
|
||||||
|
for k in res:
|
||||||
|
self.assertStartsWith(k, 'class_2')
|
||||||
|
if k.endswith('n_examples'):
|
||||||
|
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
|
||||||
|
elif k.endswith('advantage'):
|
||||||
|
self.assertBetween(res[k], 0.3, 0.5)
|
||||||
|
elif k.endswith('auc'):
|
||||||
|
self.assertBetween(res[k], 0.7, 0.75)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLogitsLabelsMisclassifiedTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
# logits and labels, single metric, single classifier, misclassified only
|
||||||
|
(_, _, logits_train, logits_test,
|
||||||
|
labels_train, labels_test) = get_test_inputs()
|
||||||
|
res = mia.run_attack(
|
||||||
|
logits_train=logits_train,
|
||||||
|
logits_test=logits_test,
|
||||||
|
labels_train=labels_train,
|
||||||
|
labels_test=labels_test,
|
||||||
|
only_misclassified=True,
|
||||||
|
attack_classifiers=('lr',),
|
||||||
|
metric='advantage')
|
||||||
|
self.assertBetween(res['misclassified_lr_logits_test_advantage'], 0.3, 0.8)
|
||||||
|
self.assertEqual(res['misclassified_n_examples'], 802)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackLogitsByPrecentileTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_attack(self):
|
||||||
|
# only logits, single metric, no classifiers, split by deciles
|
||||||
|
_, _, logits_train, logits_test, _, _ = get_test_inputs()
|
||||||
|
res = mia.run_attack(
|
||||||
|
logits_train=logits_train,
|
||||||
|
logits_test=logits_test,
|
||||||
|
by_percentile=True,
|
||||||
|
metric='auc')
|
||||||
|
for k in res:
|
||||||
|
self.assertStartsWith(k, 'percentile')
|
||||||
|
self.assertBetween(res[k], 0.60, 0.75)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Plotting functionality for membership inference attack analysis.
|
||||||
|
|
||||||
|
Functions to plot ROC curves and histograms as well as functionality to store
|
||||||
|
figures to colossus.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Text, Iterable
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
|
||||||
|
def save_plot(figure: plt.Figure, path: Text, outformat='png'):
|
||||||
|
"""Store a figure to disk."""
|
||||||
|
if path is not None:
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
figure.savefig(f, bbox_inches='tight', format=outformat)
|
||||||
|
plt.close(figure)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_curve_with_area(x: Iterable[float],
|
||||||
|
y: Iterable[float],
|
||||||
|
xlabel: Text = 'x',
|
||||||
|
ylabel: Text = 'y') -> plt.Figure:
|
||||||
|
"""Plot the curve defined by inputs and the area under the curve.
|
||||||
|
|
||||||
|
All entries of x and y are required to lie between 0 and 1.
|
||||||
|
For example, x could be recall and y precision, or x is fpr and y is tpr.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Values on x-axis (1d)
|
||||||
|
y: Values on y-axis (must be same length as x)
|
||||||
|
xlabel: Label for x axis
|
||||||
|
ylabel: Label for y axis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The matplotlib figure handle
|
||||||
|
"""
|
||||||
|
fig = plt.figure()
|
||||||
|
plt.plot([0, 1], [0, 1], 'k', lw=1.0)
|
||||||
|
plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}')
|
||||||
|
plt.xlabel(xlabel)
|
||||||
|
plt.ylabel(ylabel)
|
||||||
|
plt.legend()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_histograms(train: Iterable[float],
|
||||||
|
test: Iterable[float],
|
||||||
|
xlabel: Text = 'x',
|
||||||
|
thresh: float = None) -> plt.Figure:
|
||||||
|
"""Plot histograms of training versus test metrics."""
|
||||||
|
xmin = min(np.min(train), np.min(test))
|
||||||
|
xmax = max(np.max(train), np.max(test))
|
||||||
|
bins = np.linspace(xmin, xmax, 100)
|
||||||
|
fig = plt.figure()
|
||||||
|
plt.hist(test, bins=bins, density=True, alpha=0.5, label='test', log='y')
|
||||||
|
plt.hist(train, bins=bins, density=True, alpha=0.5, label='train', log='y')
|
||||||
|
if thresh is not None:
|
||||||
|
plt.axvline(thresh, c='r', label=f'threshold = {thresh:.3f}')
|
||||||
|
plt.xlabel(xlabel)
|
||||||
|
plt.ylabel('normalized counts (density)')
|
||||||
|
plt.legend()
|
||||||
|
return fig
|
|
@ -0,0 +1,217 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
r"""This module contains code to run attacks on previous model outputs.
|
||||||
|
|
||||||
|
Provided a path to a dataset of model outputs (logits, output probabilities,
|
||||||
|
losses, labels, predictions, membership indicators (is train or not)), we train
|
||||||
|
supervised binary classifiers using variable sets of features to distinguish
|
||||||
|
training from testing examples. We also provide threshold attacks, i.e., simply
|
||||||
|
thresholing losses or the maximum probability/logit to obtain binary
|
||||||
|
predictions.
|
||||||
|
|
||||||
|
The input data is assumed to be a tf.example proto stored with RecordIO (.rio).
|
||||||
|
For example, outputs in an accepted format are typically produced by the
|
||||||
|
`extract` script in the `extract` directory.
|
||||||
|
|
||||||
|
We run various attacks on each of the full datasets, split by class, split by
|
||||||
|
percentile of the most certain prediction and only on misclassified examples and
|
||||||
|
record the area under the receiver operator curve as well as the attack
|
||||||
|
advantage (i.e., max |tpr - fpr|) as vulnerability metrics. For all metrics
|
||||||
|
recorded, see the doc string of `membership_inference_attack.all_attacks`.
|
||||||
|
In addition, we record the overall training and test accuracy and loss of the
|
||||||
|
original image classifier. All these results are collected in a single
|
||||||
|
dictionary with descriptive keys. If there exist multiple model checkpoints (at
|
||||||
|
different training epochs), the results for each checkpoint are concatenated,
|
||||||
|
such that the dictionary keys stay the same, but the values contain arrays (the
|
||||||
|
size being the number of checkpoints). This overall result dicitonary is then
|
||||||
|
stored as a binary (and compressed) numpy file: .npz.
|
||||||
|
This file is stored either in the provided output path. If that is the empty
|
||||||
|
string, it is stored on the same level as the inputdir with the chosen name.
|
||||||
|
Using `attack_results.npz` by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
|
||||||
|
python run_attack.py --dataset=cifar10 --inputdir="attack_data"
|
||||||
|
The results are then stored at ./attack_data
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from typing import Text, Dict
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.google as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import utils
|
||||||
|
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
Result = Dict[Text, np.ndarray]
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_float('test_size', 0.2,
|
||||||
|
'Fraction of attack data used for the test set.')
|
||||||
|
flags.DEFINE_string('dataset', 'cifar10', 'The dataset to use.')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'output', '', 'The path where to store the results. '
|
||||||
|
'If empty string, store on same level as `inputdir` using '
|
||||||
|
'the name specified in the result_name flag.')
|
||||||
|
flags.DEFINE_string('result_name', 'attack_results.npz',
|
||||||
|
'The name of the output npz file with the attack results.')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'inputdir',
|
||||||
|
'attack_data',
|
||||||
|
'The input directory containing the attack datasets.')
|
||||||
|
flags.DEFINE_integer('seed', 43, 'Random seed to ensure same data splits.')
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Load and select features for attacks
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_features(data_path: Text) -> Result:
|
||||||
|
"""Extract the selected features from a given dataset."""
|
||||||
|
if FLAGS.dataset == 'cifar100':
|
||||||
|
num_classes = 100
|
||||||
|
elif FLAGS.dataset in ['cifar10', 'mnist']:
|
||||||
|
num_classes = 10
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown dataset {FLAGS.dataset}')
|
||||||
|
|
||||||
|
features = {
|
||||||
|
'logits': tf.FixedLenFeature((num_classes,), tf.float32),
|
||||||
|
'prob': tf.FixedLenFeature((num_classes,), tf.float32),
|
||||||
|
'loss': tf.FixedLenFeature([], tf.float32),
|
||||||
|
'is_train': tf.FixedLenFeature([], tf.int64),
|
||||||
|
'label': tf.FixedLenFeature([], tf.int64),
|
||||||
|
'prediction': tf.FixedLenFeature([], tf.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = tf.data.RecordIODataset(data_path)
|
||||||
|
|
||||||
|
results = {k: [] for k in features}
|
||||||
|
ds = dataset.map(lambda x: tf.parse_single_example(x, features))
|
||||||
|
for example in tfds.as_numpy(ds):
|
||||||
|
for k in results:
|
||||||
|
results[k].append(example[k])
|
||||||
|
return utils.to_numpy(results)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Run attacks
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_attacks(data_path: Text):
|
||||||
|
"""Train all possible attacks on the data at the given path."""
|
||||||
|
logging.info('Load all features from %s...', data_path)
|
||||||
|
features = load_all_features(data_path)
|
||||||
|
|
||||||
|
for k, v in features.items():
|
||||||
|
logging.info('%s: %s', k, v.shape)
|
||||||
|
|
||||||
|
logging.info('Compute original train/test accuracy and loss...')
|
||||||
|
train_idx = features['is_train'] == 1
|
||||||
|
test_idx = np.logical_not(train_idx)
|
||||||
|
correct = features['label'] == features['prediction']
|
||||||
|
result = {
|
||||||
|
'original_train_loss': np.mean(features['loss'][train_idx]),
|
||||||
|
'original_test_loss': np.mean(features['loss'][test_idx]),
|
||||||
|
'original_train_acc': np.mean(correct[train_idx]),
|
||||||
|
'original_test_acc': np.mean(correct[test_idx]),
|
||||||
|
}
|
||||||
|
|
||||||
|
result.update(
|
||||||
|
mia.run_all_attacks(
|
||||||
|
loss_train=features['loss'][train_idx],
|
||||||
|
loss_test=features['loss'][test_idx],
|
||||||
|
logits_train=features['logits'][train_idx],
|
||||||
|
logits_test=features['logits'][test_idx],
|
||||||
|
labels_train=features['label'][train_idx],
|
||||||
|
labels_test=features['label'][test_idx],
|
||||||
|
attack_classifiers=('lr', 'mlp', 'rf', 'knn'),
|
||||||
|
decimals=None))
|
||||||
|
result = utils.ensure_1d(result)
|
||||||
|
|
||||||
|
logging.info('Finished training and evaluating attacks.')
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def attacking():
|
||||||
|
"""Load data and model and extract relevant outputs."""
|
||||||
|
# ---------- Set result path ----------
|
||||||
|
if FLAGS.output:
|
||||||
|
resultpath = FLAGS.output
|
||||||
|
else:
|
||||||
|
resultdir = FLAGS.inputdir
|
||||||
|
if resultdir[-1] == '/':
|
||||||
|
resultdir = resultdir[:-1]
|
||||||
|
resultdir = '/'.join(resultdir.split('/')[:-1])
|
||||||
|
resultpath = os.path.join(resultdir, FLAGS.result_name)
|
||||||
|
|
||||||
|
# ---------- Glob attack training sets ----------
|
||||||
|
logging.info('Glob attack data paths...')
|
||||||
|
data_paths = sorted(glob(os.path.join(FLAGS.inputdir, '*')))
|
||||||
|
logging.info('Found %d data paths', len(data_paths))
|
||||||
|
|
||||||
|
# ---------- Iterate over attack dataset and train attacks ----------
|
||||||
|
epochs = []
|
||||||
|
results = []
|
||||||
|
for i, datapath in enumerate(data_paths):
|
||||||
|
logging.info('=' * 80)
|
||||||
|
logging.info('Attack model %d / %d', i + 1, len(data_paths))
|
||||||
|
logging.info('=' * 80)
|
||||||
|
basename = os.path.basename(datapath)
|
||||||
|
found_ints = re.findall(r'(\d+)', basename)
|
||||||
|
if len(found_ints) == 1:
|
||||||
|
epoch = int(found_ints[0])
|
||||||
|
logging.info('Found integer %d in pathname, interpret as epoch', epoch)
|
||||||
|
else:
|
||||||
|
epoch = np.nan
|
||||||
|
tmp_res = run_all_attacks(datapath)
|
||||||
|
if tmp_res is not None:
|
||||||
|
results.append(tmp_res)
|
||||||
|
epochs.append(epoch)
|
||||||
|
|
||||||
|
# ---------- Aggregate and save results ----------
|
||||||
|
logging.info('Aggregate and combine all results over epochs...')
|
||||||
|
results = utils.merge_dictionaries(results)
|
||||||
|
results['epochs'] = np.array(epochs)
|
||||||
|
logging.info('Store aggregate results at %s.', resultpath)
|
||||||
|
with open(resultpath, 'wb') as fp:
|
||||||
|
io_buffer = io.BytesIO()
|
||||||
|
np.savez(io_buffer, **results)
|
||||||
|
fp.write(io_buffer.getvalue())
|
||||||
|
|
||||||
|
logging.info('Finished attacks.')
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv # Unused
|
||||||
|
attacking()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
r"""A collection of sklearn models for binary classification.
|
||||||
|
|
||||||
|
This module contains some sklearn pipelines for finding models for binary
|
||||||
|
classification from a variable number of numerical input features.
|
||||||
|
These models are used to train binary classifiers for membership inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Text
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import ensemble
|
||||||
|
from sklearn import linear_model
|
||||||
|
from sklearn import model_selection
|
||||||
|
from sklearn import neighbors
|
||||||
|
from sklearn import neural_network
|
||||||
|
|
||||||
|
|
||||||
|
def choose_model(attack_classifier: Text):
|
||||||
|
"""Choose a model based on a string classifier."""
|
||||||
|
if attack_classifier == 'lr':
|
||||||
|
return logistic_regression()
|
||||||
|
elif attack_classifier == 'mlp':
|
||||||
|
return mlp()
|
||||||
|
elif attack_classifier == 'rf':
|
||||||
|
return random_forest()
|
||||||
|
elif attack_classifier == 'knn':
|
||||||
|
return knn()
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown attack classifier {attack_classifier}.')
|
||||||
|
|
||||||
|
|
||||||
|
def logistic_regression(verbose: int = 0, n_jobs: int = 1):
|
||||||
|
"""Setup a logistic regression pipeline with cross-validation."""
|
||||||
|
lr = linear_model.LogisticRegression(solver='lbfgs')
|
||||||
|
param_grid = {
|
||||||
|
'C': np.logspace(-4, 2, 10),
|
||||||
|
}
|
||||||
|
pipe = model_selection.GridSearchCV(
|
||||||
|
lr, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
||||||
|
verbose=verbose)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def random_forest(verbose: int = 0, n_jobs: int = 1):
|
||||||
|
"""Setup a random forest pipeline with cross-validation."""
|
||||||
|
rf = ensemble.RandomForestClassifier()
|
||||||
|
|
||||||
|
n_estimators = [100]
|
||||||
|
max_features = ['auto', 'sqrt']
|
||||||
|
max_depth = [5, 10, 20]
|
||||||
|
max_depth.append(None)
|
||||||
|
min_samples_split = [2, 5, 10]
|
||||||
|
min_samples_leaf = [1, 2, 4]
|
||||||
|
random_grid = {'n_estimators': n_estimators,
|
||||||
|
'max_features': max_features,
|
||||||
|
'max_depth': max_depth,
|
||||||
|
'min_samples_split': min_samples_split,
|
||||||
|
'min_samples_leaf': min_samples_leaf}
|
||||||
|
|
||||||
|
pipe = model_selection.RandomizedSearchCV(
|
||||||
|
rf, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs,
|
||||||
|
iid=False, verbose=verbose)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def mlp(verbose: int = 0, n_jobs: int = 1):
|
||||||
|
"""Setup a MLP pipeline with cross-validation."""
|
||||||
|
mlpmodel = neural_network.MLPClassifier()
|
||||||
|
|
||||||
|
param_grid = {
|
||||||
|
'hidden_layer_sizes': [(64,), (32, 32)],
|
||||||
|
'solver': ['adam'],
|
||||||
|
'alpha': [0.0001, 0.001, 0.01],
|
||||||
|
}
|
||||||
|
pipe = model_selection.GridSearchCV(
|
||||||
|
mlpmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
||||||
|
verbose=verbose)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def knn(verbose: int = 0, n_jobs: int = 1):
|
||||||
|
"""Setup a k-nearest neighbors pipeline with cross-validation."""
|
||||||
|
knnmodel = neighbors.KNeighborsClassifier()
|
||||||
|
|
||||||
|
param_grid = {
|
||||||
|
'n_neighbors': [3, 5, 7],
|
||||||
|
}
|
||||||
|
pipe = model_selection.GridSearchCV(
|
||||||
|
knnmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
||||||
|
verbose=verbose)
|
||||||
|
return pipe
|
218
tensorflow_privacy/privacy/membership_inference_attack/utils.py
Normal file
218
tensorflow_privacy/privacy/membership_inference_attack/utils.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Utility functions for membership inference attacks."""
|
||||||
|
|
||||||
|
from typing import Text, Dict, Union, List, Any, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
|
ArrayDict = Dict[Text, np.ndarray]
|
||||||
|
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Utilities for managing result dictionaries
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def to_numpy(in_dict: Dict[Text, Any]) -> ArrayDict:
|
||||||
|
"""Convert values of dict to numpy arrays.
|
||||||
|
|
||||||
|
Warning: This may fail if the values cannot be converted to numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_dict: A dictionary mapping Text keys to values where the values must be
|
||||||
|
something that can be converted to a numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dictionary with the same keys as input with all values converted to numpy
|
||||||
|
arrays
|
||||||
|
"""
|
||||||
|
return {k: np.array(v) for k, v in in_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_1d(in_dict: Dict[Text, Union[int, float, np.ndarray]]) -> ArrayDict:
|
||||||
|
"""Ensure all values of a dictionary are at least 1D numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_dict: The input dictionary mapping Text keys to numpy arrays or numbers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary with same keys as in_dict and values converted to numpy arrays
|
||||||
|
with at least one dimension (i.e., pack scalars into arrays)
|
||||||
|
"""
|
||||||
|
return {k: np.atleast_1d(v) for k, v in in_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
|
||||||
|
"""Prepend a prefix to all keys of a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_dict: The input dictionary mapping Text keys to numpy arrays.
|
||||||
|
prefix: Text which to prepend to each key in in_dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary with same values as in_dict and all keys having prefix prepended
|
||||||
|
to them
|
||||||
|
"""
|
||||||
|
return {prefix + k: v for k, v in in_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Subsampling and data selection functionality
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def select_indices(in_dict: ArrayDict, indices: np.ndarray) -> ArrayDict:
|
||||||
|
"""Subsample all values in the dictionary by the provided indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_dict: The input dictionary mapping Text keys to numpy array values.
|
||||||
|
indices: A numpy which can be used to index other arrays, specifying the
|
||||||
|
indices to subsample from in_dict values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary with same keys as in_dict and subsampled values
|
||||||
|
"""
|
||||||
|
return {k: v[indices] for k, v in in_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dictionaries(res: List[ArrayDict]) -> ArrayDict:
|
||||||
|
"""Convert iterable of dicts to dict of iterables."""
|
||||||
|
output = {k: np.empty(0) for k in res[0]}
|
||||||
|
for k in output:
|
||||||
|
output[k] = np.concatenate([r[k] for r in res if k in r], axis=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_features(features: ArrayDict,
|
||||||
|
feature_name: Text,
|
||||||
|
top_k: int,
|
||||||
|
add_loss: bool = False) -> np.ndarray:
|
||||||
|
"""Combine the specified features into one array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: A dictionary containing all possible features.
|
||||||
|
feature_name: Which feature to use (logits or prob).
|
||||||
|
top_k: The number of the top features (of feature_name) to select.
|
||||||
|
add_loss: Whether to also add the loss as a feature.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
combined numpy array with the selected features (n_examples, n_features)
|
||||||
|
"""
|
||||||
|
if top_k < 1:
|
||||||
|
raise ValueError('Must select at least one feature.')
|
||||||
|
feats = np.sort(features[feature_name], axis=-1)[:, :top_k]
|
||||||
|
if add_loss:
|
||||||
|
feats = np.concatenate((feats, features['loss'][:, np.newaxis]), axis=-1)
|
||||||
|
return feats
|
||||||
|
|
||||||
|
|
||||||
|
def subsample_to_balance(features: ArrayDict, random_state: int) -> ArrayDict:
|
||||||
|
"""Subsample if necessary to balance labels."""
|
||||||
|
train_idx = features['is_train'] == 1
|
||||||
|
test_idx = np.logical_not(train_idx)
|
||||||
|
n0 = np.sum(test_idx)
|
||||||
|
n1 = np.sum(train_idx)
|
||||||
|
|
||||||
|
if n0 < 20 or n1 < 20:
|
||||||
|
raise RuntimeError('Need at least 20 examples from training and test set.')
|
||||||
|
|
||||||
|
np.random.seed(random_state)
|
||||||
|
|
||||||
|
if n0 > n1:
|
||||||
|
use_idx = np.random.choice(np.where(test_idx)[0], n1, replace=False)
|
||||||
|
use_idx = np.concatenate((use_idx, np.where(train_idx)[0]))
|
||||||
|
features = {k: v[use_idx] for k, v in features.items()}
|
||||||
|
elif n0 < n1:
|
||||||
|
use_idx = np.random.choice(np.where(train_idx)[0], n0, replace=False)
|
||||||
|
use_idx = np.concatenate((use_idx, np.where(test_idx)[0]))
|
||||||
|
features = {k: v[use_idx] for k, v in features.items()}
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_test_split(features: ArrayDict, add_loss: bool,
|
||||||
|
test_size: float) -> Dataset:
|
||||||
|
"""Get training and test data split."""
|
||||||
|
y = features['is_train']
|
||||||
|
n_total = len(y)
|
||||||
|
n_test = int(test_size * n_total)
|
||||||
|
perm = np.random.permutation(len(y))
|
||||||
|
test_idx = perm[:n_test]
|
||||||
|
train_idx = perm[n_test:]
|
||||||
|
y_train = y[train_idx]
|
||||||
|
y_test = y[test_idx]
|
||||||
|
|
||||||
|
# We are using 10 top logits as a good default value if there are more than 10
|
||||||
|
# classes. Typically, there is no significant amount of weight in more than
|
||||||
|
# 10 logits.
|
||||||
|
n_logits = min(features['logits'].shape[1], 10)
|
||||||
|
x = get_features(features, 'logits', n_logits, add_loss)
|
||||||
|
|
||||||
|
x_train, x_test = x[train_idx], x[test_idx]
|
||||||
|
return (x_train, y_train), (x_test, y_test)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Computation of the attack metrics
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def compute_performance_metrics(true_labels: np.ndarray,
|
||||||
|
predictions: np.ndarray,
|
||||||
|
threshold: float = None) -> ArrayDict:
|
||||||
|
"""Compute relevant classification performance metrics.
|
||||||
|
|
||||||
|
The outout metrics are
|
||||||
|
1.arrays of thresholds and corresponding true and false positives (fpr, tpr).
|
||||||
|
2.auc area under fpr-tpr curve.
|
||||||
|
3.advantage max difference between tpr and fpr.
|
||||||
|
4.precision/recall/accuracy/f1_score if threshold arg is given.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
true_labels: True labels.
|
||||||
|
predictions: Predicted probabilities/scores.
|
||||||
|
threshold: The threshold to use on `predictions` binary classification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with relevant metrics which are fully described by their key.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
if threshold is not None:
|
||||||
|
results.update({
|
||||||
|
'precision':
|
||||||
|
metrics.precision_score(true_labels, predictions > threshold),
|
||||||
|
'recall':
|
||||||
|
metrics.recall_score(true_labels, predictions > threshold),
|
||||||
|
'accuracy':
|
||||||
|
metrics.accuracy_score(true_labels, predictions > threshold),
|
||||||
|
'f1_score':
|
||||||
|
metrics.f1_score(true_labels, predictions > threshold),
|
||||||
|
})
|
||||||
|
|
||||||
|
fpr, tpr, thresholds = metrics.roc_curve(true_labels, predictions)
|
||||||
|
auc = metrics.auc(fpr, tpr)
|
||||||
|
advantage = np.max(np.abs(tpr - fpr))
|
||||||
|
|
||||||
|
results.update({
|
||||||
|
'fpr': fpr,
|
||||||
|
'tpr': tpr,
|
||||||
|
'thresholds': thresholds,
|
||||||
|
'auc': auc,
|
||||||
|
'advantage': advantage,
|
||||||
|
})
|
||||||
|
return ensure_1d(results)
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
|
||||||
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import utils
|
||||||
|
|
||||||
|
|
||||||
|
class UtilsTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, methodname):
|
||||||
|
"""Initialize the test class."""
|
||||||
|
super().__init__(methodname)
|
||||||
|
rng = np.random.RandomState(33)
|
||||||
|
logits = rng.uniform(low=0, high=1, size=(1000, 14))
|
||||||
|
loss = rng.uniform(low=0, high=1, size=(1000,))
|
||||||
|
is_train = rng.binomial(1, 0.7, size=(1000,))
|
||||||
|
self.mydict = {'logits': logits, 'loss': loss, 'is_train': is_train}
|
||||||
|
|
||||||
|
def test_compute_metrics(self):
|
||||||
|
"""Test computation of attack metrics."""
|
||||||
|
true = np.array([0, 0, 0, 1, 1, 1])
|
||||||
|
pred = np.array([0.6, 0.9, 0.4, 0.8, 0.7, 0.2])
|
||||||
|
|
||||||
|
results = utils.compute_performance_metrics(true, pred, threshold=0.5)
|
||||||
|
|
||||||
|
for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
|
||||||
|
'thresholds', 'auc', 'advantage']:
|
||||||
|
self.assertIn(k, results)
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(results['accuracy'], 1. / 2.)
|
||||||
|
np.testing.assert_almost_equal(results['precision'], 2. / (2. + 2.))
|
||||||
|
np.testing.assert_almost_equal(results['recall'], 2. / (2. + 1.))
|
||||||
|
|
||||||
|
def test_prepend_to_keys(self):
|
||||||
|
"""Test prepending of text to keys of a dictionary."""
|
||||||
|
mydict = utils.prepend_to_keys(self.mydict, 'test')
|
||||||
|
for k in mydict:
|
||||||
|
self.assertTrue(k.startswith('test'))
|
||||||
|
|
||||||
|
def test_select_indices(self):
|
||||||
|
"""Test selecting indices from dictionary with array values."""
|
||||||
|
mydict = {'a': np.arange(10), 'b': np.linspace(0, 1, 10)}
|
||||||
|
|
||||||
|
idx = np.arange(5)
|
||||||
|
mydictidx = utils.select_indices(mydict, idx)
|
||||||
|
np.testing.assert_allclose(mydictidx['a'], np.arange(5))
|
||||||
|
np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[:5])
|
||||||
|
|
||||||
|
idx = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) > 0.5
|
||||||
|
mydictidx = utils.select_indices(mydict, idx)
|
||||||
|
np.testing.assert_allclose(mydictidx['a'], np.arange(0, 10, 2))
|
||||||
|
np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[0:10:2])
|
||||||
|
|
||||||
|
def test_get_features(self):
|
||||||
|
"""Test extraction of features."""
|
||||||
|
for k in [1, 5, 10, 15]:
|
||||||
|
for add_loss in [True, False]:
|
||||||
|
feats = utils.get_features(
|
||||||
|
self.mydict, 'logits', top_k=k, add_loss=add_loss)
|
||||||
|
k_selected = min(k, 14)
|
||||||
|
self.assertEqual(feats.shape, (1000, k_selected + int(add_loss)))
|
||||||
|
|
||||||
|
def test_subsample_to_balance(self):
|
||||||
|
"""Test subsampling of two arrays."""
|
||||||
|
feats = utils.subsample_to_balance(self.mydict, random_state=23)
|
||||||
|
|
||||||
|
train = np.sum(self.mydict['is_train'])
|
||||||
|
test = 1000 - train
|
||||||
|
n_chosen = min(train, test)
|
||||||
|
self.assertEqual(feats['logits'].shape, (2 * n_chosen, 14))
|
||||||
|
self.assertEqual(feats['loss'].shape, (2 * n_chosen,))
|
||||||
|
self.assertEqual(np.sum(feats['is_train']), n_chosen)
|
||||||
|
self.assertEqual(np.sum(1 - feats['is_train']), n_chosen)
|
||||||
|
|
||||||
|
def test_get_data(self):
|
||||||
|
"""Test train test split data generation."""
|
||||||
|
for test_size in [0.2, 0.5, 0.8, 0.55555]:
|
||||||
|
(x_train, y_train), (x_test, y_test) = utils.get_train_test_split(
|
||||||
|
self.mydict, add_loss=True, test_size=test_size)
|
||||||
|
n_test = int(test_size * 1000)
|
||||||
|
n_train = 1000 - n_test
|
||||||
|
self.assertEqual(x_train.shape, (n_train, 11))
|
||||||
|
self.assertEqual(y_train.shape, (n_train,))
|
||||||
|
self.assertEqual(x_test.shape, (n_test, 11))
|
||||||
|
self.assertEqual(y_test.shape, (n_test,))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
Loading…
Reference in a new issue