Open sourcing membership inference attack.

PiperOrigin-RevId: 317958055
A. Unique TensorFlower 2020-06-23 16:11:40 -07:00
parent 1fb9b80d90
commit 88dd8771bf
9 changed files with 2000 additions and 0 deletions

@@ -0,0 +1,238 @@
# Membership inference attack functionality
The goal is to provide empirical tests of "how much information a machine
learning model has remembered about its training data". To this end, only the
outputs of the model are used (e.g., losses, logits, predictions). From those
alone, the attacks try to infer whether the corresponding inputs were part of
the training set.
> NOTE: At a minimum, only loss values are needed: some for examples used
> during training and some for examples that have not been used during training
> (e.g., examples from the test set). No access to the actual input data is
> required. For classification models, one can additionally (or instead of
> losses) provide logits or output probabilities for stronger attacks.
The vulnerability of a model is measured via the area under the ROC-curve
(`auc`) or via max{|fpr - tpr|} (`advantage`) of the attack classifier. These
measures are very closely related.
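For intuition, both metrics can be computed directly from the ROC curve of the
attack classifier. A minimal sketch with scikit-learn, using illustrative toy
scores (the arrays here are not part of the library):
```python
import numpy as np
from sklearn import metrics

# Toy membership scores: members (is_train == 1) tend to score higher.
rng = np.random.RandomState(0)
is_train = np.concatenate((np.ones(100), np.zeros(100)))
scores = np.concatenate((rng.randn(100) + 0.5, rng.randn(100)))

fpr, tpr, _ = metrics.roc_curve(is_train, scores)
auc = metrics.auc(fpr, tpr)            # area under the ROC curve
advantage = np.max(np.abs(tpr - fpr))  # max |fpr - tpr|
```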
## Highest level -- get attack summary
### Basic usage
On the highest level, there is the `run_all_attacks_and_create_summary`
function, which chooses sane default options to run a host of (fairly simple)
attacks behind the scenes (depending on which data is fed in), computes the
most important measures, and returns a summary of the results as a string of
English (as well as, optionally, a Python dictionary containing all results
with descriptive keys).
> NOTE: The train and test sets are balanced internally, i.e., an equal number
> of in-training and out-of-training examples is chosen for the attacks
> (whichever has fewer examples). These are subsampled uniformly at random
> without replacement from the larger of the two.
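For intuition, the balancing step amounts to something like the following
sketch (the library does this internally; the helper below is purely
illustrative):
```python
import numpy as np


def balance(loss_train, loss_test, seed=0):
  """Subsample the larger of the two arrays so both have the same size."""
  n = min(len(loss_train), len(loss_test))
  rng = np.random.RandomState(seed)
  idx_train = rng.choice(len(loss_train), n, replace=False)
  idx_test = rng.choice(len(loss_test), n, replace=False)
  return loss_train[idx_train], loss_test[idx_test]
```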
The simplest possible usage is
```python
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
# Evaluate your model on training and test examples to get
# loss_train shape: (n_train, )
# loss_test shape: (n_test, )
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, return_dict=True)
print(results)
# -> {'auc': 0.7044,
# 'best_attacker_auc': 'all_thresh_loss_auc',
# 'advantage': 0.3116,
#     'best_attacker_advantage': 'all_thresh_loss_advantage'}
```
> NOTE: The keyword argument `return_dict` specifies whether, in addition to
> the `summary`, the function also returns a Python dictionary with the results.
If the model is a classifier, the logits or output probabilities (i.e., the
softmax of logits) can also be provided to perform stronger attacks.
> NOTE: The `logits_train` and `logits_test` arguments can also be filled with
> output probabilities per class ("posteriors").
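If your model produces logits, output probabilities can be obtained with a
softmax. A short sketch (using `scipy`, which this library also relies on; the
`logits_train`/`logits_test` arrays are those from the example below):
```python
from scipy import special

# Per-class output probabilities ("posteriors") from logits.
probs_train = special.softmax(logits_train, axis=-1)
probs_test = special.softmax(logits_test, axis=-1)
```
Passing logits (or probabilities) together with losses then looks like: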
```python
# logits_train shape: (n_train, n_classes)
# logits_test shape: (n_test, n_classes)
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
logits_test, return_dict=True)
print(results)
# -> {'auc': 0.5382,
# 'best_attacker_auc': 'all_lr_logits_loss_test_auc',
# 'advantage': 0.0572,
#     'best_attacker_advantage': 'all_mlp_logits_loss_test_advantage'}
```
The `summary` will be a string in natural language describing the results in
more detail, e.g.,
```
========== AUC ==========
The best attack (all_lr_logits_loss_test_auc) achieved an auc of 0.5382.
========== ADVANTAGE ==========
The best attack (all_mlp_logits_loss_test_advantage) achieved an advantage of 0.0572.
```
Similarly, we can run attacks on the logits alone, without access to losses:
```python
summary, results = mia.run_all_attacks_and_create_summary(logits_train=logits_train,
logits_test=logits_test,
return_dict=True)
print(results)
# -> {'auc': 0.9278,
# 'best_attacker_auc': 'all_rf_logits_test_auc',
# 'advantage': 0.6991,
#     'best_attacker_advantage': 'all_rf_logits_test_advantage'}
```
### Advanced usage
Finally, if we also have access to the true labels of the training and test
inputs, we can run the attacks for each class separately. If labels *and* logits
are provided, attacks on only the misclassified (typically uncertain) examples
are also performed.
```python
summary, results = mia.run_all_attacks_and_create_summary(loss_train, loss_test, logits_train,
logits_test, labels_train, labels_test,
return_dict=True)
```
Here, we now also get as output the class with the maximal vulnerability
according to our metrics (`max_vuln_class_auc`, `max_vuln_class_advantage`)
together with the corresponding values (`class_<CLASS>_auc`,
`class_<CLASS>_advantage`). The same values exist in the `results` dictionary
for `min` instead of `max`, i.e., the least vulnerable classes. Moreover, the
gap between the maximum and minimum values (`max_class_gap_auc`,
`max_class_gap_advantage`) is also provided. Similarly, the vulnerability
metrics when the attacks are restricted to the misclassified examples
(`misclassified_auc`, `misclassified_advantage`) are also shown. Finally, the
results also contain the number of examples in each of these groups, i.e.,
within each of the reported classes as well as the number of misclassified
examples. The final `results` dictionary is of the form
```
{'auc': 0.9181,
'best_attacker_auc': 'all_rf_logits_loss_test_auc',
'advantage': 0.6915,
'best_attacker_advantage': 'all_rf_logits_loss_test_advantage',
'max_class_gap_auc': 0.254,
'class_5_auc': 0.9512,
'class_3_auc': 0.6972,
'max_vuln_class_auc': 5,
'min_vuln_class_auc': 3,
'max_class_gap_advantage': 0.5073,
'class_0_advantage': 0.8086,
'class_3_advantage': 0.3013,
'max_vuln_class_advantage': 0,
'min_vuln_class_advantage': 3,
'misclassified_n_examples': 4513.0,
'class_0_n_examples': 899.0,
'class_1_n_examples': 900.0,
'class_2_n_examples': 931.0,
'class_3_n_examples': 893.0,
'class_4_n_examples': 960.0,
'class_5_n_examples': 884.0}
```
### Setting the precision of the reported results
Finally, `run_all_attacks_and_create_summary` takes one extra keyword argument
`decimals`, expecting a positive integer. This sets the precision of all result
values as the number of decimals to report. It defaults to 4.
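For example (reusing the arrays from above; the value `2` is illustrative):
```python
summary, results = mia.run_all_attacks_and_create_summary(
    loss_train, loss_test, return_dict=True, decimals=2)
```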
## Run all attacks and get all outputs
With the `run_all_attacks` function, one can run all implemented attacks on all
possible subsets of the data (all examples, split by class, split by confidence
deciles, misclassified only). This function returns a relatively large
dictionary with all attack results. This is the most detailed information one
could get about these types of membership inference attacks (besides plots for
each attack; see the next section). This is useful if you know exactly what
you're
looking for.
> NOTE: The `run_all_attacks` function takes an additional argument specifying
> which trained attackers to run. In `run_all_attacks_and_create_summary`, only
> logistic regression (`lr`) is trained as a binary classifier to distinguish
> in-training from out-of-training examples. In addition, with the
> `attack_classifiers` argument, one can add multi-layered perceptrons (`mlp`),
> random forests (`rf`), and k-nearest-neighbors (`knn`) or any subset thereof
> for the attack models. Note that these classifiers may not converge.
```python
mia.run_all_attacks(loss_train, loss_test, logits_train, logits_test,
labels_train, labels_test,
attack_classifiers=('lr', 'mlp', 'rf', 'knn'))
```
Again, `run_all_attacks` can be called on all combinations of losses, logits,
probabilities, and labels as long as at least either losses or logits
(probabilities) are provided.
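For instance, a losses-only call could look as follows (trained attack
classifiers are skipped automatically when no logits are provided):
```python
results = mia.run_all_attacks(loss_train, loss_test)
```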
## Fine grained control over individual attacks and plots
The `run_attack` function exposes the underlying workhorse of the
`run_all_attacks` and `run_all_attacks_and_create_summary` functionality. It
allows for fine grained control of which attacks to run individually.
As another key feature, this function also exposes options to store receiver
operating characteristic (ROC) curve plots for the different attacks as well as
histograms of losses or the maximum logits/probabilities. Finally, we can also
store all results (including the values to reproduce the plots) to disk.
All options are explained in detail in the doc string of the `run_attack`
function.
For example, to run a simple threshold attack on the losses only and store plots
and result data to disk, run
```python
data_path = '/Users/user/Desktop/test/' # set to None to not store data
figure_path = '/Users/user/Desktop/test/' # set to None to not store figures
mia.run_attack(loss_train=loss_train,
               loss_test=loss_test,
               metric='auc',
               output_directory=data_path,
               figure_directory=figure_path)
```
Among other things, the `run_attack` functionality allows you to control (a
combined example follows this list):
* which metrics to output (`metric` argument, using `auc` or `advantage` or
both)
* which classifiers (logistic regression, multi-layered perceptrons, random
forests) to train as attackers beyond the simple threshold attacks
(`attack_classifiers`)
* to only attack a specific (set of) classes (`by_class`)
* to only attack specific percentiles of the data (`by_percentile`).
Percentiles here are computed by looking at the largest logit or probability
for each example, i.e., how confident the model is in its prediction.
* to only attack the misclassified examples (`only_misclassified`)
* whether to balance the number of in-training and out-of-training examples
  (`balance`). By default, an equal number of examples from train and test is
  selected for the attacks (the size of whichever set is smaller).
* the test set size for trained attacks (`test_size`). When a classifier is
trained to distinguish between train and test examples, a train-test split
for that classifier itself is required.
* the randomness for the train-test split as well as for the class balancing,
  via the seed specified by `random_state`.
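Putting several of these options together, a combined call might look like the
following sketch (all argument values here are illustrative):
```python
results = mia.run_attack(loss_train=loss_train,
                         loss_test=loss_test,
                         logits_train=logits_train,
                         logits_test=logits_test,
                         labels_train=labels_train,
                         labels_test=labels_test,
                         attack_classifiers=('lr',),
                         by_class=(0, 1),
                         metric=('auc', 'advantage'),
                         balance=True,
                         test_size=0.2,
                         random_state=0)
```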
## Contact
Reach out to tf-privacy@google.com and let us know how you're using this module.
We're keen on hearing your stories, feedback, and suggestions!
## Copyright
Copyright 2020 - Google LLC

@@ -0,0 +1,13 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,716 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Code that runs membership inference attacks based on the model outputs."""
import collections
import io
import os
import re
from typing import Text, Dict, Iterable, Tuple, Union, Any
from absl import logging
import numpy as np
from scipy import special
from tensorflow_privacy.privacy.membership_inference_attack import plotting
from tensorflow_privacy.privacy.membership_inference_attack import trained_attack_models
from tensorflow_privacy.privacy.membership_inference_attack import utils
ArrayDict = Dict[Text, np.ndarray]
FloatDict = Dict[Text, float]
AnyDict = Dict[Text, Any]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
MetricNames = Union[Text, Iterable[Text]]
def _get_vulnerabilities(result: ArrayDict, metrics: MetricNames) -> FloatDict:
"""Gets the vulnerabilities according to the chosen metrics for all attacks."""
vulns = {}
if isinstance(metrics, str):
metrics = [metrics]
for k in result:
for metric in metrics:
if k.endswith(metric.lower()) or k.endswith('n_examples'):
vulns[k] = float(result[k])
return vulns
def _get_maximum_vulnerability(
attack_result: FloatDict,
metrics: MetricNames,
filterby: Text = '') -> Dict[Text, Dict[Text, Union[Text, float]]]:
"""Returns the worst vulnerability according to the chosen metrics of all attacks."""
vulns = {}
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
best_attack_value = -np.inf
for k in attack_result:
if (k.startswith(filterby.lower()) and k.endswith(metric.lower()) and
'train' not in k):
if float(attack_result[k]) > best_attack_value:
best_attack_value = attack_result[k]
best_attacker = k
if best_attack_value > -np.inf:
newkey = filterby + '-' + metric if filterby else metric
vulns[newkey] = {'value': best_attack_value, 'attacker': best_attacker}
return vulns
def _get_maximum_class_gap_or_none(result: FloatDict,
metrics: MetricNames) -> FloatDict:
"""Returns the biggest and smallest vulnerability and the gap across classes."""
gaps = {}
if isinstance(metrics, str):
metrics = [metrics]
for metric in metrics:
hi = -np.inf
lo = np.inf
hi_idx, lo_idx = -1, -1
for k in result:
if (k.startswith('class') and k.endswith(metric.lower()) and
'train' not in k):
if float(result[k]) > hi:
hi = float(result[k])
hi_idx = int(re.findall(r'class_(\d+)_', k)[0])
if float(result[k]) < lo:
lo = float(result[k])
lo_idx = int(re.findall(r'class_(\d+)_', k)[0])
if lo - hi < np.inf:
gaps['max_class_gap_' + metric] = hi - lo
gaps[f'class_{hi_idx}_' + metric] = hi
gaps[f'class_{lo_idx}_' + metric] = lo
gaps['max_vuln_class_' + metric] = hi_idx
gaps['min_vuln_class_' + metric] = lo_idx
return gaps
# ------------------------------------------------------------------------------
# Attacks
# ------------------------------------------------------------------------------
def _run_threshold_loss_attack(features: ArrayDict,
figure_file_prefix: Text = '',
figure_directory: Text = None) -> ArrayDict:
"""Runs the threshold attack on the loss."""
logging.info('Run threshold attack on loss...')
is_train = features['is_train']
attack_prefix = 'thresh_loss'
tmp_results = utils.compute_performance_metrics(is_train, -features['loss'])
if figure_directory is not None:
figpath = os.path.join(figure_directory,
figure_file_prefix + attack_prefix + '.png')
plotting.save_plot(
plotting.plot_curve_with_area(
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
figpath)
figpath = os.path.join(figure_directory,
figure_file_prefix + attack_prefix + '_hist.png')
plotting.save_plot(
plotting.plot_histograms(
features['loss'][is_train == 1],
features['loss'][is_train == 0],
xlabel='loss'), figpath)
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
def _run_threshold_attack_maxlogit(features: ArrayDict,
figure_file_prefix: Text = '',
figure_directory: Text = None) -> ArrayDict:
"""Runs the threshold attack on the maximum logit."""
is_train = features['is_train']
preds = np.max(features['logits'], axis=-1)
tmp_results = utils.compute_performance_metrics(is_train, preds)
attack_prefix = 'thresh_maxlogit'
if figure_directory is not None:
figpath = os.path.join(figure_directory,
figure_file_prefix + attack_prefix + '.png')
plotting.save_plot(
plotting.plot_curve_with_area(
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
figpath)
figpath = os.path.join(figure_directory,
figure_file_prefix + attack_prefix + '_hist.png')
plotting.save_plot(
plotting.plot_histograms(
            preds[is_train == 1], preds[is_train == 0], xlabel='max logit'), figpath)
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
def _run_trained_attack(attack_classifier: Text,
data: Dataset,
attack_prefix: Text,
figure_file_prefix: Text = '',
figure_directory: Text = None) -> ArrayDict:
"""Train a classifier for attack and evaluate it."""
# Train the attack classifier
(x_train, y_train), (x_test, y_test) = data
clf_model = trained_attack_models.choose_model(attack_classifier)
clf_model.fit(x_train, y_train)
# Calculate training set metrics
pred_train = clf_model.predict_proba(x_train)[:, clf_model.classes_ == 1]
results = utils.prepend_to_keys(
utils.compute_performance_metrics(y_train, pred_train),
attack_prefix + 'train_')
# Calculate test set metrics
pred_test = clf_model.predict_proba(x_test)[:, clf_model.classes_ == 1]
results.update(
utils.prepend_to_keys(
utils.compute_performance_metrics(y_test, pred_test),
attack_prefix + 'test_'))
if figure_directory is not None:
figpath = os.path.join(figure_directory,
figure_file_prefix + attack_prefix[:-1] + '.png')
plotting.save_plot(
plotting.plot_curve_with_area(
results[attack_prefix + 'test_fpr'],
results[attack_prefix + 'test_tpr'],
xlabel='FPR',
ylabel='TPR'), figpath)
return results
def _run_attacks_and_plot(features: ArrayDict,
attacks: Iterable[Text],
attack_classifiers: Iterable[Text],
balance: bool,
test_size: float,
random_state: int,
figure_file_prefix: Text = '',
figure_directory: Text = None) -> ArrayDict:
"""Runs the specified attacks on the provided data."""
if balance:
try:
features = utils.subsample_to_balance(features, random_state)
except RuntimeError:
logging.info('Not enough remaining data for attack: Empty results.')
return {}
result = {}
# -------------------- Simple threshold attacks
if 'thresh_loss' in attacks:
result.update(
_run_threshold_loss_attack(features, figure_file_prefix,
figure_directory))
if 'thresh_maxlogit' in attacks:
result.update(
_run_threshold_attack_maxlogit(features, figure_file_prefix,
figure_directory))
# -------------------- Run learned attacks
# TODO(b/157632603): Add a prefix (for example 'trained_') for attacks which
  # use classifiers to distinguish from threshold attacks.
if 'logits' in attacks:
data = utils.get_train_test_split(
features, add_loss=False, test_size=test_size)
for clf in attack_classifiers:
logging.info('Train %s on %d logits', clf, data[0][0].shape[1])
attack_prefix = f'{clf}_logits_'
result.update(
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
figure_directory))
if 'logits_loss' in attacks:
data = utils.get_train_test_split(
features, add_loss=True, test_size=test_size)
for clf in attack_classifiers:
logging.info('Train %s on %d logits + loss', clf, data[0][0].shape[1])
attack_prefix = f'{clf}_logits_loss_'
result.update(
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
figure_directory))
return result
def run_attack(loss_train: np.ndarray = None,
loss_test: np.ndarray = None,
logits_train: np.ndarray = None,
logits_test: np.ndarray = None,
labels_train: np.ndarray = None,
labels_test: np.ndarray = None,
attack_classifiers: Iterable[Text] = None,
only_misclassified: bool = False,
by_class: Union[bool, Iterable[int], int] = False,
by_percentile: Union[bool, Iterable[int], int] = False,
figure_directory: Text = None,
output_directory: Text = None,
metric: MetricNames = 'auc',
balance: bool = True,
test_size: float = 0.2,
random_state: int = 0) -> FloatDict:
"""Run membership inference attack(s).
Based only on specific outputs of a machine learning model on some examples
used for training (train) and some examples not used for training (test), run
membership inference attacks that try to discriminate training from test
inputs based only on the model outputs.
While all inputs are optional, at least one train/test pair is required to run
any attacks (either losses or logits/probabilities).
Note that one can equally provide output probabilities instead of logits in
the logits_train / logits_test arguments.
We measure the vulnerability of the model via the area under the ROC-curve
(auc) or via max |fpr - tpr| (advantage) of the attack classifier. These
measures are very closely related and may look almost indistinguishable.
This function provides relatively fine grained control and outputs detailed
results. For a higher-level wrapper with sane internal default settings and
distilled output results, see `run_all_attacks`.
Via the `figure_directory` argument and the `output_directory` argument more
detailed information as well as roc-curve plots can optionally be stored to
disk.
If `loss_train` and `loss_test` are provided we run:
- simple threshold attack on the loss
If `logits_train` and `logits_test` are provided we run:
- simple threshold attack on the top logit
- if `attack_classifiers` is not None and no losses are provided: train the
specified classifiers on the top 10 logits (or all logits if there are
less than 10)
- if `attack_classifiers` is not None and losses are provided: train the
specified classifiers on the top 10 logits (or all logits if there are
less than 10) and the loss
Args:
loss_train: A 1D array containing the individual scalar losses for examples
used during training.
loss_test: A 1D array containing the individual scalar losses for examples
not used during training.
logits_train: A 2D array (n_train, n_classes) of the individual logits or
output probabilities of examples used during training.
logits_test: A 2D array (n_test, n_classes) of the individual logits or
output probabilities of examples not used during training.
labels_train: The true labels of the training examples. Labels are only
needed when `by_class` is specified (i.e., not False).
labels_test: The true labels of the test examples. Labels are only needed
when `by_class` is specified (i.e., not False).
attack_classifiers: Attack classifiers to train beyond simple thresholding
that require training a simple binary ML classifier. This argument is
ignored if logits are not provided. Classifiers can be 'lr' for logistic
regression, 'mlp' for multi-layered perceptron, 'rf' for random forests,
or 'knn' for k-nearest-neighbors. If 'None', don't train classifiers
beyond simple thresholding.
only_misclassified: Run and evaluate attacks only on misclassified examples.
Must specify `labels_train`, `labels_test`, `logits_train` and
`logits_test` to use this. If this is True, `by_class` and `by_percentile`
are ignored.
by_class: This argument determines whether attacks are run on the entire
data, or on examples grouped by their class label. If `True`, all attacks
are run separately for each class. If `by_class` is a single integer, run
attacks for this class only. If `by_class` is an iterable of integers, run
all attacks for each of the specified class labels separately. Only used
if `labels_train` and `labels_test` are specified. If `by_class` is
specified (not False), `by_percentile` is ignored. Ignored if
`only_misclassified` is True.
by_percentile: This argument determines whether attacks are run on the
entire data, or separately for examples where the most likely class
      prediction is within a given percentile of all maximum predictions. If
      `True`, all attacks are run separately for the examples with max
      probabilities within the ten deciles. If `by_percentile` is a single int
between 0 and 100, run attacks only for examples with confidence within
this percentile. If `by_percentile` is an iterable of ints between 0 and
100, run all attacks for each of the specified percentiles separately.
Ignored if `by_class` is specified. Ignored if `logits_train` and
`logits_test` are not specified. Ignored if `only_misclassified` is True.
figure_directory: Where to store ROC-curve plots and histograms. If `None`,
don't create plots.
output_directory: Where to store detailed result data for all run attacks.
If `None`, don't store detailed result data.
metric: Available vulnerability metrics are 'auc' or 'advantage' for the
area under the ROC curve or the advantage (max |tpr - fpr|). Specify
either one of them or both.
balance: Whether to use the same number of train and test samples (by
randomly subsampling whichever happens to be larger).
test_size: The fraction of the input data to use for the evaluation of
trained ML attacks. This argument is ignored, if either attack_classifiers
is None, or no logits are provided.
random_state: Random seed for reproducibility. Only used if attack models
are trained.
Returns:
    results: Dictionary with the chosen vulnerability metric(s) for all
      attacks that were run.
"""
attacks = []
features = {}
# ---------- Check available data ----------
if ((loss_train is None or loss_test is None) and
(logits_train is None or logits_test is None)):
raise ValueError(
'Need at least train and test for loss or train and test for logits.')
# ---------- If losses are provided ----------
if loss_train is not None and loss_test is not None:
if loss_train.ndim != 1 or loss_test.ndim != 1:
raise ValueError('Losses must be 1D arrays.')
features['is_train'] = np.concatenate(
(np.ones(len(loss_train)), np.zeros(len(loss_test))),
axis=0).astype(int)
features['loss'] = np.concatenate((loss_train.ravel(), loss_test.ravel()),
axis=0)
attacks.append('thresh_loss')
# ---------- If logits are provided ----------
if logits_train is not None and logits_test is not None:
assert logits_train.ndim == 2 and logits_test.ndim == 2, \
'Logits must be 2D arrays.'
assert logits_train.shape[1] == logits_test.shape[1], \
'Train and test logits must agree along axis 1 (number of classes).'
if 'is_train' in features:
assert (loss_train.shape[0] == logits_train.shape[0] and
loss_test.shape[0] == logits_test.shape[0]), \
'Number of examples must match between loss and logits.'
else:
features['is_train'] = np.concatenate(
(np.ones(logits_train.shape[0]), np.zeros(logits_test.shape[0])),
axis=0).astype(int)
attacks.append('thresh_maxlogit')
features['logits'] = np.concatenate((logits_train, logits_test), axis=0)
if attack_classifiers:
attacks.append('logits')
if 'loss' in features:
attacks.append('logits_loss')
# ---------- If labels are provided ----------
if labels_train is not None and labels_test is not None:
if labels_train.ndim != 1 or labels_test.ndim != 1:
      raise ValueError('Labels must be 1D arrays.')
if 'loss' in features:
assert (loss_train.shape[0] == labels_train.shape[0] and
loss_test.shape[0] == labels_test.shape[0]), \
'Number of examples must match between loss and labels.'
else:
assert (logits_train.shape[0] == labels_train.shape[0] and
logits_test.shape[0] == labels_test.shape[0]), \
'Number of examples must match between logits and labels.'
features['label'] = np.concatenate((labels_train, labels_test), axis=0)
# ---------- Data subsampling or filtering ----------
filtertype = None
filtervals = [None]
if only_misclassified:
if (labels_train is None or labels_test is None or logits_train is None or
logits_test is None):
raise ValueError('Must specify labels_train, labels_test, logits_train, '
'and logits_test for the only_misclassified option.')
filtertype = 'misclassified'
elif by_class:
if labels_train is None or labels_test is None:
raise ValueError('Must specify labels_train and labels_test when using '
'the by_class option.')
if isinstance(by_class, bool):
filtervals = list(set(labels_train) | set(labels_test))
elif isinstance(by_class, int):
filtervals = [by_class]
    elif isinstance(by_class, collections.abc.Iterable):
filtervals = list(by_class)
filtertype = 'class'
elif by_percentile:
if logits_train is None or logits_test is None:
raise ValueError('Must specify logits_train and logits_test when using '
'the by_percentile option.')
if isinstance(by_percentile, bool):
filtervals = list(range(10, 101, 10))
elif isinstance(by_percentile, int):
filtervals = [by_percentile]
    elif isinstance(by_percentile, collections.abc.Iterable):
filtervals = [int(percentile) for percentile in by_percentile]
filtertype = 'percentile'
# ---------- Need to create figure directory? ----------
if figure_directory is not None:
    os.makedirs(figure_directory, exist_ok=True)
# ---------- Actually run attacks and plot if required ----------
logging.info('Selecting %s with values %s', filtertype, filtervals)
num = None
result = {}
for filterval in filtervals:
if filtertype is None:
tmp_features = features
elif filtertype == 'misclassified':
idx = features['label'] != np.argmax(features['logits'], axis=-1)
tmp_features = utils.select_indices(features, idx)
num = np.sum(idx)
elif filtertype == 'class':
idx = features['label'] == filterval
tmp_features = utils.select_indices(features, idx)
num = np.sum(idx)
elif filtertype == 'percentile':
certainty = np.max(special.softmax(features['logits'], axis=-1), axis=-1)
idx = certainty <= np.percentile(certainty, filterval)
tmp_features = utils.select_indices(features, idx)
prefix = f'{filtertype}_' if filtertype is not None else ''
prefix += f'{filterval}_' if filterval is not None else ''
tmp_result = _run_attacks_and_plot(tmp_features, attacks,
attack_classifiers, balance, test_size,
random_state, prefix, figure_directory)
if num is not None:
tmp_result['n_examples'] = float(num)
if tmp_result:
result.update(utils.prepend_to_keys(tmp_result, prefix))
# ---------- Store data ----------
if output_directory is not None:
    os.makedirs(output_directory, exist_ok=True)
resultpath = os.path.join(output_directory, 'attack_results.npz')
logging.info('Store aggregate results at %s.', resultpath)
with open(resultpath, 'wb') as fp:
io_buffer = io.BytesIO()
np.savez(io_buffer, **result)
fp.write(io_buffer.getvalue())
return _get_vulnerabilities(result, metric)
def run_all_attacks(loss_train: np.ndarray = None,
loss_test: np.ndarray = None,
logits_train: np.ndarray = None,
logits_test: np.ndarray = None,
labels_train: np.ndarray = None,
labels_test: np.ndarray = None,
attack_classifiers: Iterable[Text] = ('lr', 'mlp', 'rf',
'knn'),
decimals: Union[int, None] = 4) -> FloatDict:
"""Runs all possible membership inference attacks.
Check 'run_attack' for detailed information of how attacks are performed
and evaluated.
This function internally chooses sane default settings for all attacks and
returns all possible output combinations.
For fine grained control and partial attacks, please see `run_attack`.
Args:
loss_train: A 1D array containing the individual scalar losses for examples
used during training.
loss_test: A 1D array containing the individual scalar losses for examples
not used during training.
logits_train: A 2D array (n_train, n_classes) of the individual logits or
output probabilities of examples used during training.
logits_test: A 2D array (n_test, n_classes) of the individual logits or
output probabilities of examples not used during training.
labels_train: The true labels of the training examples. Labels are only
needed when `by_class` is specified (i.e., not False).
labels_test: The true labels of the test examples. Labels are only needed
when `by_class` is specified (i.e., not False).
attack_classifiers: Which binary classifiers to train (in addition to simple
threshold attacks). This can include 'lr' (logistic regression), 'mlp'
(multi-layered perceptron), 'rf' (random forests), 'knn' (k-nearest
neighbors), which will be trained with cross validation to determine good
hyperparameters.
decimals: Round all float results to this number of decimals. If decimals is
None, don't round.
Returns:
result: dictionary with all attack results
"""
metrics = ['auc', 'advantage']
# Entire data
result = run_attack(
loss_train,
loss_test,
logits_train,
logits_test,
attack_classifiers=attack_classifiers,
metric=metrics)
result = utils.prepend_to_keys(result, 'all_')
# Misclassified examples
if (labels_train is not None and labels_test is not None and
logits_train is not None and logits_test is not None):
result.update(
run_attack(
loss_train,
loss_test,
logits_train,
logits_test,
labels_train,
labels_test,
attack_classifiers=attack_classifiers,
only_misclassified=True,
metric=metrics))
# Split per class
if labels_train is not None and labels_test is not None:
result.update(
run_attack(
loss_train,
loss_test,
logits_train,
logits_test,
labels_train,
labels_test,
by_class=True,
attack_classifiers=attack_classifiers,
metric=metrics))
# Different deciles
if logits_train is not None and logits_test is not None:
result.update(
run_attack(
loss_train,
loss_test,
logits_train,
logits_test,
by_percentile=True,
attack_classifiers=attack_classifiers,
metric=metrics))
if decimals is not None:
result = {k: round(v, decimals) for k, v in result.items()}
return result
def run_all_attacks_and_create_summary(
loss_train: np.ndarray = None,
loss_test: np.ndarray = None,
logits_train: np.ndarray = None,
logits_test: np.ndarray = None,
labels_train: np.ndarray = None,
labels_test: np.ndarray = None,
return_dict: bool = True,
decimals: Union[int, None] = 4) -> Union[Text, Tuple[Text, AnyDict]]:
"""Runs all possible membership inference attack(s) and distill results.
Check 'run_attack' for detailed information of how attacks are performed
and evaluated.
This function internally chooses sane default settings for all attacks and
returns all possible output combinations.
For fine grained control and partial attacks, please see `run_attack`.
Args:
loss_train: A 1D array containing the individual scalar losses for examples
used during training.
loss_test: A 1D array containing the individual scalar losses for examples
not used during training.
logits_train: A 2D array (n_train, n_classes) of the individual logits or
output probabilities of examples used during training.
logits_test: A 2D array (n_test, n_classes) of the individual logits or
output probabilities of examples not used during training.
labels_train: The true labels of the training examples. Labels are only
needed when `by_class` is specified (i.e., not False).
labels_test: The true labels of the test examples. Labels are only needed
when `by_class` is specified (i.e., not False).
return_dict: Whether to also return a dictionary with the results summarized
in the summary string.
decimals: Round all float results to this number of decimals. If decimals is
None, don't round.
Returns:
    summarystring: A string with a natural language summary of the attacks. In
      the summary string, printed numbers are rounded to `decimals` decimals if
      provided, otherwise to 4 digits by default for readability.
result: a dictionary with all the distilled attack information summarized
in the summarystring
"""
summary = []
metrics = ['auc', 'advantage']
attack_classifiers = ['lr', 'rf', 'mlp', 'knn']
results = run_all_attacks(
loss_train,
loss_test,
logits_train,
logits_test,
labels_train,
labels_test,
attack_classifiers=attack_classifiers,
decimals=None)
output = _get_maximum_vulnerability(results, metrics, filterby='all')
if decimals is not None:
strdec = decimals
else:
strdec = 4
for metric in metrics:
summary.append(f'========== {metric.upper()} ==========')
best_value = output['all-' + metric]['value']
best_attacker = output['all-' + metric]['attacker']
summary.append(f'The best attack ({best_attacker}) achieved an {metric} of '
f'{best_value:.{strdec}f}.')
summary.append('')
classgap = _get_maximum_class_gap_or_none(results, metrics)
if classgap:
output.update(classgap)
for metric in metrics:
summary.append(f'========== {metric.upper()} per class ==========')
hi_idx = output[f'max_vuln_class_{metric}']
lo_idx = output[f'min_vuln_class_{metric}']
hi = output[f'class_{hi_idx}_{metric}']
lo = output[f'class_{lo_idx}_{metric}']
gap = output[f'max_class_gap_{metric}']
summary.append(f'The most vulnerable class {hi_idx} has {metric} of '
f'{hi:.{strdec}f}.')
summary.append(f'The least vulnerable class {lo_idx} has {metric} of '
f'{lo:.{strdec}f}.')
summary.append(f'=> The maximum gap between class vulnerabilities is '
f'{gap:.{strdec}f}.')
summary.append('')
misclassified = _get_maximum_vulnerability(
results, metrics, filterby='misclassified')
if misclassified:
    for metric in metrics:
      best_value = misclassified['misclassified-' + metric]['value']
      best_attacker = misclassified['misclassified-' + metric]['attacker']
      summary.append(f'========== {metric.upper()} for misclassified '
                     '==========')
      summary.append('Among misclassified examples, the best attack '
                     f'({best_attacker}) achieved an {metric} of '
                     f'{best_value:.{strdec}f}.')
summary.append('')
output.update(misclassified)
n_examples = {k: v for k, v in results.items() if k.endswith('n_examples')}
if n_examples:
output.update(n_examples)
# Flatten remaining dicts in output
fresh_output = {}
for k, v in output.items():
if isinstance(v, dict):
if k.startswith('all'):
fresh_output[k[4:]] = v['value']
fresh_output['best_attacker_' + k[4:]] = v['attacker']
else:
fresh_output[k] = v
output = fresh_output
if decimals is not None:
for k, v in output.items():
if isinstance(v, float):
output[k] = round(v, decimals)
summary = '\n'.join(summary)
if return_dict:
return summary, output
else:
return summary

@@ -0,0 +1,307 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
def get_result_dict():
"""Get an example result dictionary."""
return {
'test_n_examples': np.ones(1),
'test_examples': np.zeros(1),
'test_auc': np.ones(1),
'test_advantage': np.ones(1),
'all_0-metric': np.array([1]),
'all_1-metric': np.array([2]),
'test_2-metric': np.array([3]),
'test_score': np.array([4]),
}
def get_test_inputs():
"""Get example inputs for attacks."""
n_train = n_test = 500
rng = np.random.RandomState(4)
loss_train = rng.randn(n_train) - 0.4
loss_test = rng.randn(n_test) + 0.4
logits_train = rng.randn(n_train, 5) + 0.2
logits_test = rng.randn(n_test, 5) - 0.2
labels_train = np.array([i % 5 for i in range(n_train)])
labels_test = np.array([(3 * i) % 5 for i in range(n_test)])
return (loss_train, loss_test, logits_train, logits_test,
labels_train, labels_test)
class GetVulnerabilityTest(absltest.TestCase):
def test_get_vulnerabilities(self):
"""Test extraction of vulnerability scores."""
testdict = get_result_dict()
for key in ['auc', 'advantage']:
res = mia._get_vulnerabilities(testdict, key)
self.assertLen(res, 2)
self.assertEqual(res[f'test_{key}'], 1)
self.assertEqual(res['test_n_examples'], 1)
res = mia._get_vulnerabilities(testdict, ['auc', 'advantage'])
self.assertLen(res, 3)
self.assertEqual(res['test_auc'], 1)
self.assertEqual(res['test_advantage'], 1)
self.assertEqual(res['test_n_examples'], 1)
class GetMaximumVulnerabilityTest(absltest.TestCase):
def test_get_maximum_vulnerability(self):
"""Test extraction of maximum vulnerability score."""
testdict = get_result_dict()
for i in range(3):
key = f'{i}-metric'
res = mia._get_maximum_vulnerability(testdict, key)
self.assertLen(res, 1)
self.assertEqual(res[key]['value'], i + 1)
if i < 2:
self.assertEqual(res[key]['attacker'], f'all_{i}-metric')
else:
self.assertEqual(res[key]['attacker'], 'test_2-metric')
res = mia._get_maximum_vulnerability(testdict, 'metric')
self.assertLen(res, 1)
self.assertEqual(res['metric']['value'], 3)
res = mia._get_maximum_vulnerability(testdict, ['metric'],
filterby='all')
self.assertLen(res, 1)
self.assertEqual(res['all-metric']['value'], 2)
res = mia._get_maximum_vulnerability(testdict, ['metric', 'score'])
self.assertLen(res, 2)
self.assertEqual(res['metric']['value'], 3)
self.assertEqual(res['score']['value'], 4)
self.assertEqual(res['score']['attacker'], 'test_score')
class ThresholdAttackLossTest(absltest.TestCase):
def test_threshold_attack_loss(self):
"""Test simple threshold attack on loss."""
features = {
'loss': np.zeros(10),
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
}
res = mia._run_threshold_loss_attack(features)
for k in res:
self.assertStartsWith(k, 'thresh_loss')
self.assertEqual(res['thresh_loss_auc'], 0.5)
self.assertEqual(res['thresh_loss_advantage'], 0.0)
rng = np.random.RandomState(4)
n_train = 1000
n_test = 500
loss_train = rng.randn(n_train) - 0.4
loss_test = rng.randn(n_test) + 0.4
features = {
'loss': np.concatenate((loss_train, loss_test)),
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
}
res = mia._run_threshold_loss_attack(features)
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
class ThresholdAttackMaxlogitTest(absltest.TestCase):
def test_threshold_attack_maxlogits(self):
"""Test simple threshold attack on maximum logit."""
features = {
'logits': np.eye(10, 14),
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
}
res = mia._run_threshold_attack_maxlogit(features)
for k in res:
self.assertStartsWith(k, 'thresh_maxlogit')
self.assertEqual(res['thresh_maxlogit_auc'], 0.5)
self.assertEqual(res['thresh_maxlogit_advantage'], 0.0)
rng = np.random.RandomState(4)
n_train = 1000
n_test = 500
logits_train = rng.randn(n_train, 12) + 0.2
logits_test = rng.randn(n_test, 12) - 0.2
features = {
'logits': np.concatenate((logits_train, logits_test), axis=0),
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
}
res = mia._run_threshold_attack_maxlogit(features)
self.assertBetween(res['thresh_maxlogit_auc'], 0.7, 0.75)
self.assertBetween(res['thresh_maxlogit_advantage'], 0.3, 0.35)
class TrainedAttackTrivialTest(absltest.TestCase):
def test_trained_attack(self):
"""Test trained attacks."""
# Trivially easy problem
x_train, x_test = np.ones((500, 3)), np.ones((20, 3))
x_train[:200] *= -1
x_test[:8] *= -1
y_train, y_test = np.ones(500).astype(int), np.ones(20).astype(int)
y_train[:200] = 0
y_test[:8] = 0
data = (x_train, y_train), (x_test, y_test)
for clf in ['lr', 'rf', 'mlp', 'knn']:
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
self.assertEqual(res['a-train_auc'], 1)
self.assertEqual(res['a-test_auc'], 1)
self.assertEqual(res['a-train_advantage'], 1)
self.assertEqual(res['a-test_advantage'], 1)
class TrainedAttackRandomFeaturesTest(absltest.TestCase):
def test_trained_attack(self):
"""Test trained attacks."""
# Random labels and features
rng = np.random.RandomState(4)
x_train, x_test = rng.randn(500, 3), rng.randn(500, 3)
y_train = rng.binomial(1, 0.5, size=(500,))
y_test = rng.binomial(1, 0.5, size=(500,))
data = (x_train, y_train), (x_test, y_test)
for clf in ['lr', 'rf', 'mlp', 'knn']:
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
self.assertBetween(res['a-train_auc'], 0.5, 1.)
self.assertBetween(res['a-test_auc'], 0.4, 0.6)
self.assertBetween(res['a-train_advantage'], 0., 1.0)
self.assertBetween(res['a-test_advantage'], 0., 0.2)
class AttackLossesTest(absltest.TestCase):
def test_attack(self):
"""Test individual attack function."""
# losses only, both metrics
loss_train, loss_test, _, _, _, _ = get_test_inputs()
res = mia.run_attack(loss_train, loss_test, metric=('auc', 'advantage'))
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
class AttackLossesLogitsTest(absltest.TestCase):
def test_attack(self):
"""Test individual attack function."""
# losses and logits, two classifiers, single metric
loss_train, loss_test, logits_train, logits_test, _, _ = get_test_inputs()
res = mia.run_attack(
loss_train,
loss_test,
logits_train,
logits_test,
attack_classifiers=('rf', 'knn'),
metric='auc')
self.assertBetween(res['rf_logits_test_auc'], 0.7, 0.9)
self.assertBetween(res['knn_logits_test_auc'], 0.7, 0.9)
self.assertBetween(res['rf_logits_loss_test_auc'], 0.7, 0.9)
self.assertBetween(res['knn_logits_loss_test_auc'], 0.7, 0.9)
class AttackLossesLabelsByClassTest(absltest.TestCase):
def test_attack(self):
# losses and labels, single metric, split by class
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
n_train = loss_train.shape[0]
n_test = loss_test.shape[0]
res = mia.run_attack(
loss_train,
loss_test,
labels_train=labels_train,
labels_test=labels_test,
by_class=True,
metric='auc')
self.assertLen(res, 10)
for k in res:
self.assertStartsWith(k, 'class_')
if k.endswith('n_examples'):
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
else:
self.assertBetween(res[k], 0.65, 0.75)
class AttackLossesLabelsSingleClassTest(absltest.TestCase):
def test_attack(self):
# losses and labels, both metrics, single class
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
n_train = loss_train.shape[0]
n_test = loss_test.shape[0]
res = mia.run_attack(
loss_train,
loss_test,
labels_train=labels_train,
labels_test=labels_test,
by_class=2,
metric=('auc', 'advantage'))
self.assertLen(res, 3)
for k in res:
self.assertStartsWith(k, 'class_2')
if k.endswith('n_examples'):
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
elif k.endswith('advantage'):
self.assertBetween(res[k], 0.3, 0.5)
elif k.endswith('auc'):
self.assertBetween(res[k], 0.7, 0.75)
class AttackLogitsLabelsMisclassifiedTest(absltest.TestCase):
def test_attack(self):
# logits and labels, single metric, single classifier, misclassified only
(_, _, logits_train, logits_test,
labels_train, labels_test) = get_test_inputs()
res = mia.run_attack(
logits_train=logits_train,
logits_test=logits_test,
labels_train=labels_train,
labels_test=labels_test,
only_misclassified=True,
attack_classifiers=('lr',),
metric='advantage')
self.assertBetween(res['misclassified_lr_logits_test_advantage'], 0.3, 0.8)
self.assertEqual(res['misclassified_n_examples'], 802)
class AttackLogitsByPercentileTest(absltest.TestCase):
def test_attack(self):
# only logits, single metric, no classifiers, split by deciles
_, _, logits_train, logits_test, _, _ = get_test_inputs()
res = mia.run_attack(
logits_train=logits_train,
logits_test=logits_test,
by_percentile=True,
metric='auc')
for k in res:
self.assertStartsWith(k, 'percentile')
self.assertBetween(res[k], 0.60, 0.75)
if __name__ == '__main__':
absltest.main()

@@ -0,0 +1,80 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Plotting functionality for membership inference attack analysis.
Functions to plot ROC curves and histograms as well as functionality to store
figures to disk.
"""
from typing import Text, Iterable
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
def save_plot(figure: plt.Figure, path: Text, outformat='png'):
"""Store a figure to disk."""
if path is not None:
with open(path, 'wb') as f:
figure.savefig(f, bbox_inches='tight', format=outformat)
plt.close(figure)
def plot_curve_with_area(x: Iterable[float],
y: Iterable[float],
xlabel: Text = 'x',
ylabel: Text = 'y') -> plt.Figure:
"""Plot the curve defined by inputs and the area under the curve.
All entries of x and y are required to lie between 0 and 1.
For example, x could be recall and y precision, or x is fpr and y is tpr.
Args:
x: Values on x-axis (1d)
y: Values on y-axis (must be same length as x)
xlabel: Label for x axis
ylabel: Label for y axis
Returns:
The matplotlib figure handle
"""
fig = plt.figure()
plt.plot([0, 1], [0, 1], 'k', lw=1.0)
plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}')
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.legend()
return fig
def plot_histograms(train: Iterable[float],
test: Iterable[float],
xlabel: Text = 'x',
thresh: float = None) -> plt.Figure:
"""Plot histograms of training versus test metrics."""
xmin = min(np.min(train), np.min(test))
xmax = max(np.max(train), np.max(test))
bins = np.linspace(xmin, xmax, 100)
fig = plt.figure()
  plt.hist(test, bins=bins, density=True, alpha=0.5, label='test', log=True)
  plt.hist(train, bins=bins, density=True, alpha=0.5, label='train', log=True)
if thresh is not None:
plt.axvline(thresh, c='r', label=f'threshold = {thresh:.3f}')
plt.xlabel(xlabel)
plt.ylabel('normalized counts (density)')
plt.legend()
return fig
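# Example usage (a sketch; `fpr` and `tpr` would come from an attack result):
#   fig = plot_curve_with_area(fpr, tpr, xlabel='FPR', ylabel='TPR')
#   save_plot(fig, '/tmp/roc.png')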

@@ -0,0 +1,217 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""This module contains code to run attacks on previous model outputs.
Provided a path to a dataset of model outputs (logits, output probabilities,
losses, labels, predictions, membership indicators (is train or not)), we train
supervised binary classifiers using variable sets of features to distinguish
training from testing examples. We also provide threshold attacks, i.e., simply
thresholing losses or the maximum probability/logit to obtain binary
predictions.
The input data is assumed to be a tf.example proto stored with RecordIO (.rio).
For example, outputs in an accepted format are typically produced by the
`extract` script in the `extract` directory.
We run various attacks on the full datasets, split by class, split by
percentile of the most certain prediction, and only on misclassified examples,
and record the area under the receiver operating characteristic curve as well
as the attack advantage (i.e., max |tpr - fpr|) as vulnerability metrics. For
all metrics recorded, see the doc string of
`membership_inference_attack.run_all_attacks`.
In addition, we record the overall training and test accuracy and loss of the
original image classifier. All these results are collected in a single
dictionary with descriptive keys. If there exist multiple model checkpoints (at
different training epochs), the results for each checkpoint are concatenated,
such that the dictionary keys stay the same, but the values contain arrays (the
size being the number of checkpoints). This overall result dictionary is then
stored as a binary (and compressed) numpy file: .npz.
This file is stored in the provided output path. If that is the empty string,
it is stored at the same level as `inputdir` under the name given by the
`result_name` flag (`attack_results.npz` by default).

Example usage:

  python run_attack.py --dataset=cifar10 --inputdir="attack_data"

The results are then stored alongside the `attack_data` directory in
`attack_results.npz`.
"""
import io
import os
import re
from typing import Text, Dict
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.google as tf
import tensorflow_datasets as tfds
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack import utils
from glob import glob
Result = Dict[Text, np.ndarray]
FLAGS = flags.FLAGS
flags.DEFINE_float('test_size', 0.2,
'Fraction of attack data used for the test set.')
flags.DEFINE_string('dataset', 'cifar10', 'The dataset to use.')
flags.DEFINE_string(
'output', '', 'The path where to store the results. '
'If empty string, store on same level as `inputdir` using '
'the name specified in the result_name flag.')
flags.DEFINE_string('result_name', 'attack_results.npz',
'The name of the output npz file with the attack results.')
flags.DEFINE_string(
'inputdir',
'attack_data',
'The input directory containing the attack datasets.')
flags.DEFINE_integer('seed', 43, 'Random seed to ensure same data splits.')
# ------------------------------------------------------------------------------
# Load and select features for attacks
# ------------------------------------------------------------------------------
def load_all_features(data_path: Text) -> Result:
"""Extract the selected features from a given dataset."""
if FLAGS.dataset == 'cifar100':
num_classes = 100
elif FLAGS.dataset in ['cifar10', 'mnist']:
num_classes = 10
else:
raise ValueError(f'Unknown dataset {FLAGS.dataset}')
features = {
'logits': tf.FixedLenFeature((num_classes,), tf.float32),
'prob': tf.FixedLenFeature((num_classes,), tf.float32),
'loss': tf.FixedLenFeature([], tf.float32),
'is_train': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64),
'prediction': tf.FixedLenFeature([], tf.int64),
}
dataset = tf.data.RecordIODataset(data_path)
results = {k: [] for k in features}
ds = dataset.map(lambda x: tf.parse_single_example(x, features))
for example in tfds.as_numpy(ds):
for k in results:
results[k].append(example[k])
return utils.to_numpy(results)
# ------------------------------------------------------------------------------
# Run attacks
# ------------------------------------------------------------------------------
def run_all_attacks(data_path: Text):
"""Train all possible attacks on the data at the given path."""
logging.info('Load all features from %s...', data_path)
features = load_all_features(data_path)
for k, v in features.items():
logging.info('%s: %s', k, v.shape)
logging.info('Compute original train/test accuracy and loss...')
train_idx = features['is_train'] == 1
test_idx = np.logical_not(train_idx)
correct = features['label'] == features['prediction']
result = {
'original_train_loss': np.mean(features['loss'][train_idx]),
'original_test_loss': np.mean(features['loss'][test_idx]),
'original_train_acc': np.mean(correct[train_idx]),
'original_test_acc': np.mean(correct[test_idx]),
}
result.update(
mia.run_all_attacks(
loss_train=features['loss'][train_idx],
loss_test=features['loss'][test_idx],
logits_train=features['logits'][train_idx],
logits_test=features['logits'][test_idx],
labels_train=features['label'][train_idx],
labels_test=features['label'][test_idx],
attack_classifiers=('lr', 'mlp', 'rf', 'knn'),
decimals=None))
result = utils.ensure_1d(result)
logging.info('Finished training and evaluating attacks.')
return result
def attacking():
"""Load data and model and extract relevant outputs."""
# ---------- Set result path ----------
if FLAGS.output:
resultpath = FLAGS.output
else:
resultdir = FLAGS.inputdir
if resultdir[-1] == '/':
resultdir = resultdir[:-1]
resultdir = '/'.join(resultdir.split('/')[:-1])
resultpath = os.path.join(resultdir, FLAGS.result_name)
# ---------- Glob attack training sets ----------
logging.info('Glob attack data paths...')
data_paths = sorted(glob(os.path.join(FLAGS.inputdir, '*')))
logging.info('Found %d data paths', len(data_paths))
# ---------- Iterate over attack dataset and train attacks ----------
epochs = []
results = []
for i, datapath in enumerate(data_paths):
logging.info('=' * 80)
logging.info('Attack model %d / %d', i + 1, len(data_paths))
logging.info('=' * 80)
basename = os.path.basename(datapath)
found_ints = re.findall(r'(\d+)', basename)
if len(found_ints) == 1:
epoch = int(found_ints[0])
logging.info('Found integer %d in pathname, interpret as epoch', epoch)
else:
epoch = np.nan
tmp_res = run_all_attacks(datapath)
if tmp_res is not None:
results.append(tmp_res)
epochs.append(epoch)
# ---------- Aggregate and save results ----------
logging.info('Aggregate and combine all results over epochs...')
results = utils.merge_dictionaries(results)
results['epochs'] = np.array(epochs)
logging.info('Store aggregate results at %s.', resultpath)
with open(resultpath, 'wb') as fp:
io_buffer = io.BytesIO()
np.savez(io_buffer, **results)
fp.write(io_buffer.getvalue())
logging.info('Finished attacks.')
def main(argv):
del argv # Unused
attacking()
if __name__ == '__main__':
app.run(main)

@@ -0,0 +1,106 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""A collection of sklearn models for binary classification.
This module contains some sklearn pipelines for finding models for binary
classification from a variable number of numerical input features.
These models are used to train binary classifiers for membership inference.
"""
from typing import Text
import numpy as np
from sklearn import ensemble
from sklearn import linear_model
from sklearn import model_selection
from sklearn import neighbors
from sklearn import neural_network
def choose_model(attack_classifier: Text):
"""Choose a model based on a string classifier."""
if attack_classifier == 'lr':
return logistic_regression()
elif attack_classifier == 'mlp':
return mlp()
elif attack_classifier == 'rf':
return random_forest()
elif attack_classifier == 'knn':
return knn()
else:
raise ValueError(f'Unknown attack classifier {attack_classifier}.')
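
# Usage sketch: all returned objects follow the sklearn estimator API, e.g.
#
#   model = choose_model('rf')   # RandomizedSearchCV over a random forest
#   model.fit(x, y)              # x: (n, d) features, y: binary labels
#   scores = model.predict_proba(x)[:, 1]
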
def logistic_regression(verbose: int = 0, n_jobs: int = 1):
  """Set up a logistic regression pipeline with cross-validation."""
  lr = linear_model.LogisticRegression(solver='lbfgs')
  param_grid = {
      'C': np.logspace(-4, 2, 10),
  }
  pipe = model_selection.GridSearchCV(
      lr, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=verbose)
  return pipe
def random_forest(verbose: int = 0, n_jobs: int = 1):
  """Set up a random forest pipeline with cross-validation."""
  rf = ensemble.RandomForestClassifier()
  n_estimators = [100]
  # For classifiers, 'auto' is just an alias for 'sqrt', so searching over
  # both is redundant ('auto' was also removed in newer sklearn versions).
  max_features = ['sqrt']
  max_depth = [5, 10, 20, None]
  min_samples_split = [2, 5, 10]
  min_samples_leaf = [1, 2, 4]
  random_grid = {'n_estimators': n_estimators,
                 'max_features': max_features,
                 'max_depth': max_depth,
                 'min_samples_split': min_samples_split,
                 'min_samples_leaf': min_samples_leaf}
  pipe = model_selection.RandomizedSearchCV(
      rf, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs,
      verbose=verbose)
  return pipe
def mlp(verbose: int = 0, n_jobs: int = 1):
  """Set up an MLP pipeline with cross-validation."""
  mlpmodel = neural_network.MLPClassifier()
  param_grid = {
      'hidden_layer_sizes': [(64,), (32, 32)],
      'solver': ['adam'],
      'alpha': [0.0001, 0.001, 0.01],
  }
  pipe = model_selection.GridSearchCV(
      mlpmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=verbose)
  return pipe
def knn(verbose: int = 0, n_jobs: int = 1):
  """Set up a k-nearest neighbors pipeline with cross-validation."""
  knnmodel = neighbors.KNeighborsClassifier()
  param_grid = {
      'n_neighbors': [3, 5, 7],
  }
  pipe = model_selection.GridSearchCV(
      knnmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=verbose)
  return pipe
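

if __name__ == '__main__':
  # Smoke-test sketch (hypothetical data, not part of the library API): fit
  # the logistic regression attack pipeline on random features and print the
  # best cross-validated parameters.
  rng = np.random.RandomState(0)
  x = rng.normal(size=(200, 11))          # e.g., top-10 logits plus the loss
  y = rng.binomial(1, 0.5, size=(200,))   # binary membership labels
  model = choose_model('lr')
  model.fit(x, y)
  print('best params:', model.best_params_)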

View file

@ -0,0 +1,218 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utility functions for membership inference attacks."""
from typing import Any, Dict, List, Optional, Text, Tuple, Union
import numpy as np
from sklearn import metrics
ArrayDict = Dict[Text, np.ndarray]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
# ------------------------------------------------------------------------------
# Utilities for managing result dictionaries
# ------------------------------------------------------------------------------
def to_numpy(in_dict: Dict[Text, Any]) -> ArrayDict:
  """Convert values of dict to numpy arrays.

  Warning: This may fail if the values cannot be converted to numpy arrays.

  Args:
    in_dict: A dictionary mapping Text keys to values that can be converted
      to numpy arrays.

  Returns:
    A dictionary with the same keys as the input and all values converted to
    numpy arrays.
  """
  return {k: np.array(v) for k, v in in_dict.items()}
def ensure_1d(in_dict: Dict[Text, Union[int, float, np.ndarray]]) -> ArrayDict:
  """Ensure all values of a dictionary are at least 1D numpy arrays.

  Args:
    in_dict: The input dictionary mapping Text keys to numpy arrays or numbers.

  Returns:
    A dictionary with the same keys as in_dict and values converted to numpy
    arrays with at least one dimension (i.e., scalars packed into arrays).
  """
  return {k: np.atleast_1d(v) for k, v in in_dict.items()}
def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
  """Prepend a prefix to all keys of a dictionary.

  Args:
    in_dict: The input dictionary mapping Text keys to numpy arrays.
    prefix: Text to prepend to each key in in_dict.

  Returns:
    A dictionary with the same values as in_dict and the prefix prepended to
    every key.
  """
  return {prefix + k: v for k, v in in_dict.items()}
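
# Example of the dictionary helpers above (a sketch):
#
#   d = {'loss': [0.1, 0.2]}
#   to_numpy(d)                   # -> {'loss': array([0.1, 0.2])}
#   ensure_1d({'auc': 0.7})       # -> {'auc': array([0.7])}
#   prepend_to_keys(d, 'train_')  # -> {'train_loss': [0.1, 0.2]}
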
# ------------------------------------------------------------------------------
# Subsampling and data selection functionality
# ------------------------------------------------------------------------------
def select_indices(in_dict: ArrayDict, indices: np.ndarray) -> ArrayDict:
  """Subsample all values in the dictionary by the provided indices.

  Args:
    in_dict: The input dictionary mapping Text keys to numpy array values.
    indices: A numpy array (integer indices or a boolean mask) specifying
      which entries to select from the values of in_dict.

  Returns:
    A dictionary with the same keys as in_dict and subsampled values.
  """
  return {k: v[indices] for k, v in in_dict.items()}
def merge_dictionaries(res: List[ArrayDict]) -> ArrayDict:
  """Convert an iterable of dicts to a dict of concatenated array values."""
  return {
      k: np.concatenate([r[k] for r in res if k in r], axis=0)
      for k in res[0]
  }
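
# For example (a sketch), merging per-epoch result dicts:
#
#   r1 = {'auc': np.array([0.7])}
#   r2 = {'auc': np.array([0.8])}
#   merge_dictionaries([r1, r2])  # -> {'auc': array([0.7, 0.8])}
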
def get_features(features: ArrayDict,
                 feature_name: Text,
                 top_k: int,
                 add_loss: bool = False) -> np.ndarray:
  """Combine the specified features into one array.

  Args:
    features: A dictionary containing all possible features.
    feature_name: Which feature to use ('logits' or 'prob').
    top_k: The number of the top features (of feature_name) to select.
    add_loss: Whether to also add the loss as a feature.

  Returns:
    A combined numpy array with the selected features of shape
    (n_examples, n_features).
  """
  if top_k < 1:
    raise ValueError('Must select at least one feature.')
  # np.sort is ascending, so the top (largest) k values are the last k columns.
  feats = np.sort(features[feature_name], axis=-1)[:, -top_k:]
  if add_loss:
    feats = np.concatenate((feats, features['loss'][:, np.newaxis]), axis=-1)
  return feats
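
# For example (a sketch), with 14-class logits of shape (n, 14):
#
#   feats = get_features(features, 'logits', top_k=10, add_loss=True)
#   feats.shape  # -> (n, 11): the 10 largest logits per example plus the loss
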
def subsample_to_balance(features: ArrayDict, random_state: int) -> ArrayDict:
"""Subsample if necessary to balance labels."""
train_idx = features['is_train'] == 1
test_idx = np.logical_not(train_idx)
n0 = np.sum(test_idx)
n1 = np.sum(train_idx)
if n0 < 20 or n1 < 20:
raise RuntimeError('Need at least 20 examples from training and test set.')
np.random.seed(random_state)
if n0 > n1:
use_idx = np.random.choice(np.where(test_idx)[0], n1, replace=False)
use_idx = np.concatenate((use_idx, np.where(train_idx)[0]))
features = {k: v[use_idx] for k, v in features.items()}
elif n0 < n1:
use_idx = np.random.choice(np.where(train_idx)[0], n0, replace=False)
use_idx = np.concatenate((use_idx, np.where(test_idx)[0]))
features = {k: v[use_idx] for k, v in features.items()}
return features
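
# After balancing, both membership classes are equally represented (a sketch):
#
#   balanced = subsample_to_balance(features, random_state=0)
#   assert np.sum(balanced['is_train']) == np.sum(1 - balanced['is_train'])
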
def get_train_test_split(features: ArrayDict, add_loss: bool,
test_size: float) -> Dataset:
"""Get training and test data split."""
y = features['is_train']
n_total = len(y)
n_test = int(test_size * n_total)
perm = np.random.permutation(len(y))
test_idx = perm[:n_test]
train_idx = perm[n_test:]
y_train = y[train_idx]
y_test = y[test_idx]
  # Use the 10 largest logits as a default when there are more than 10
  # classes; typically, little probability mass lies outside the top 10.
n_logits = min(features['logits'].shape[1], 10)
x = get_features(features, 'logits', n_logits, add_loss)
x_train, x_test = x[train_idx], x[test_idx]
return (x_train, y_train), (x_test, y_test)
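
# For example (a sketch), an 80/20 split for attacker training:
#
#   (x_train, y_train), (x_test, y_test) = get_train_test_split(
#       features, add_loss=True, test_size=0.2)
#   # x_* hold up to the top-10 logits (plus the loss) per example; y_* are
#   # the in-training membership labels.
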
# ------------------------------------------------------------------------------
# Computation of the attack metrics
# ------------------------------------------------------------------------------
def compute_performance_metrics(true_labels: np.ndarray,
                                predictions: np.ndarray,
                                threshold: Optional[float] = None) -> ArrayDict:
  """Compute relevant classification performance metrics.

  The output metrics are:
    1. Arrays of thresholds and the corresponding false and true positive
       rates (fpr, tpr).
    2. auc: the area under the fpr-tpr curve.
    3. advantage: the maximum difference between tpr and fpr.
    4. precision/recall/accuracy/f1_score, if the threshold argument is given.

  Args:
    true_labels: True labels.
    predictions: Predicted probabilities/scores.
    threshold: The threshold to apply to `predictions` for binary
      classification.

  Returns:
    A dictionary with the relevant metrics, fully described by their keys.
  """
results = {}
if threshold is not None:
results.update({
'precision':
metrics.precision_score(true_labels, predictions > threshold),
'recall':
metrics.recall_score(true_labels, predictions > threshold),
'accuracy':
metrics.accuracy_score(true_labels, predictions > threshold),
'f1_score':
metrics.f1_score(true_labels, predictions > threshold),
})
fpr, tpr, thresholds = metrics.roc_curve(true_labels, predictions)
auc = metrics.auc(fpr, tpr)
advantage = np.max(np.abs(tpr - fpr))
results.update({
'fpr': fpr,
'tpr': tpr,
'thresholds': thresholds,
'auc': auc,
'advantage': advantage,
})
return ensure_1d(results)
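
# For example (a sketch), on a toy attack:
#
#   res = compute_performance_metrics(
#       true_labels=np.array([0, 0, 1, 1]),
#       predictions=np.array([0.1, 0.6, 0.4, 0.9]),
#       threshold=0.5)
#   res['auc']        # -> array([0.75])
#   res['advantage']  # -> array([0.5])
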

View file

@ -0,0 +1,105 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import utils
class UtilsTest(absltest.TestCase):
def __init__(self, methodname):
"""Initialize the test class."""
super().__init__(methodname)
rng = np.random.RandomState(33)
logits = rng.uniform(low=0, high=1, size=(1000, 14))
loss = rng.uniform(low=0, high=1, size=(1000,))
is_train = rng.binomial(1, 0.7, size=(1000,))
self.mydict = {'logits': logits, 'loss': loss, 'is_train': is_train}
def test_compute_metrics(self):
"""Test computation of attack metrics."""
true = np.array([0, 0, 0, 1, 1, 1])
pred = np.array([0.6, 0.9, 0.4, 0.8, 0.7, 0.2])
results = utils.compute_performance_metrics(true, pred, threshold=0.5)
for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
'thresholds', 'auc', 'advantage']:
self.assertIn(k, results)
np.testing.assert_almost_equal(results['accuracy'], 1. / 2.)
np.testing.assert_almost_equal(results['precision'], 2. / (2. + 2.))
np.testing.assert_almost_equal(results['recall'], 2. / (2. + 1.))
def test_prepend_to_keys(self):
"""Test prepending of text to keys of a dictionary."""
mydict = utils.prepend_to_keys(self.mydict, 'test')
for k in mydict:
self.assertTrue(k.startswith('test'))
def test_select_indices(self):
"""Test selecting indices from dictionary with array values."""
mydict = {'a': np.arange(10), 'b': np.linspace(0, 1, 10)}
idx = np.arange(5)
mydictidx = utils.select_indices(mydict, idx)
np.testing.assert_allclose(mydictidx['a'], np.arange(5))
np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[:5])
idx = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) > 0.5
mydictidx = utils.select_indices(mydict, idx)
np.testing.assert_allclose(mydictidx['a'], np.arange(0, 10, 2))
np.testing.assert_allclose(mydictidx['b'], np.linspace(0, 1, 10)[0:10:2])
def test_get_features(self):
"""Test extraction of features."""
for k in [1, 5, 10, 15]:
for add_loss in [True, False]:
feats = utils.get_features(
self.mydict, 'logits', top_k=k, add_loss=add_loss)
k_selected = min(k, 14)
self.assertEqual(feats.shape, (1000, k_selected + int(add_loss)))
  def test_subsample_to_balance(self):
    """Test subsampling the feature dict to balance train/test labels."""
feats = utils.subsample_to_balance(self.mydict, random_state=23)
train = np.sum(self.mydict['is_train'])
test = 1000 - train
n_chosen = min(train, test)
self.assertEqual(feats['logits'].shape, (2 * n_chosen, 14))
self.assertEqual(feats['loss'].shape, (2 * n_chosen,))
self.assertEqual(np.sum(feats['is_train']), n_chosen)
self.assertEqual(np.sum(1 - feats['is_train']), n_chosen)
def test_get_data(self):
"""Test train test split data generation."""
for test_size in [0.2, 0.5, 0.8, 0.55555]:
(x_train, y_train), (x_test, y_test) = utils.get_train_test_split(
self.mydict, add_loss=True, test_size=test_size)
n_test = int(test_size * 1000)
n_train = 1000 - n_test
self.assertEqual(x_train.shape, (n_train, 11))
self.assertEqual(y_train.shape, (n_train,))
self.assertEqual(x_test.shape, (n_test, 11))
self.assertEqual(y_test.shape, (n_test,))
if __name__ == '__main__':
absltest.main()