Remove old API.
PiperOrigin-RevId: 334406920
This commit is contained in:
parent
78d30a0424
commit
bca2baae8d
3 changed files with 0 additions and 1150 deletions
|
@ -1,737 +0,0 @@
|
||||||
# Copyright 2020, The TensorFlow Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Lint as: python3
|
|
||||||
"""Code that runs membership inference attacks based on the model outputs.
|
|
||||||
|
|
||||||
Warning: This file belongs to the old API for membership inference attacks. This
|
|
||||||
file will be removed soon. membership_inference_attack_new.py contains the new
|
|
||||||
API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import collections
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
from typing import Text, Dict, Iterable, Tuple, Union, Any
|
|
||||||
|
|
||||||
from absl import logging
|
|
||||||
import numpy as np
|
|
||||||
from scipy import special
|
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import plotting
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import trained_attack_models
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import utils
|
|
||||||
|
|
||||||
from os import mkdir
|
|
||||||
|
|
||||||
ArrayDict = Dict[Text, np.ndarray]
|
|
||||||
FloatDict = Dict[Text, float]
|
|
||||||
AnyDict = Dict[Text, Any]
|
|
||||||
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
|
||||||
MetricNames = Union[Text, Iterable[Text]]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vulnerabilities(result: ArrayDict, metrics: MetricNames) -> FloatDict:
|
|
||||||
"""Gets the vulnerabilities according to the chosen metrics for all attacks."""
|
|
||||||
vulns = {}
|
|
||||||
if isinstance(metrics, str):
|
|
||||||
metrics = [metrics]
|
|
||||||
for k in result:
|
|
||||||
for metric in metrics:
|
|
||||||
if k.endswith(metric.lower()) or k.endswith('n_examples'):
|
|
||||||
vulns[k] = float(result[k])
|
|
||||||
return vulns
|
|
||||||
|
|
||||||
|
|
||||||
def _get_maximum_vulnerability(
|
|
||||||
attack_result: FloatDict,
|
|
||||||
metrics: MetricNames,
|
|
||||||
filterby: Text = '') -> Dict[Text, Dict[Text, Union[Text, float]]]:
|
|
||||||
"""Returns the worst vulnerability according to the chosen metrics of all attacks."""
|
|
||||||
vulns = {}
|
|
||||||
if isinstance(metrics, str):
|
|
||||||
metrics = [metrics]
|
|
||||||
for metric in metrics:
|
|
||||||
best_attack_value = -np.inf
|
|
||||||
for k in attack_result:
|
|
||||||
if (k.startswith(filterby.lower()) and k.endswith(metric.lower()) and
|
|
||||||
'train' not in k):
|
|
||||||
if float(attack_result[k]) > best_attack_value:
|
|
||||||
best_attack_value = attack_result[k]
|
|
||||||
best_attacker = k
|
|
||||||
if best_attack_value > -np.inf:
|
|
||||||
newkey = filterby + '-' + metric if filterby else metric
|
|
||||||
vulns[newkey] = {'value': best_attack_value, 'attacker': best_attacker}
|
|
||||||
return vulns
|
|
||||||
|
|
||||||
|
|
||||||
def _get_maximum_class_gap_or_none(result: FloatDict,
|
|
||||||
metrics: MetricNames) -> FloatDict:
|
|
||||||
"""Returns the biggest and smallest vulnerability and the gap across classes."""
|
|
||||||
gaps = {}
|
|
||||||
if isinstance(metrics, str):
|
|
||||||
metrics = [metrics]
|
|
||||||
for metric in metrics:
|
|
||||||
hi = -np.inf
|
|
||||||
lo = np.inf
|
|
||||||
hi_idx, lo_idx = -1, -1
|
|
||||||
for k in result:
|
|
||||||
if (k.startswith('class') and k.endswith(metric.lower()) and
|
|
||||||
'train' not in k):
|
|
||||||
if float(result[k]) > hi:
|
|
||||||
hi = float(result[k])
|
|
||||||
hi_idx = int(re.findall(r'class_(\d+)_', k)[0])
|
|
||||||
if float(result[k]) < lo:
|
|
||||||
lo = float(result[k])
|
|
||||||
lo_idx = int(re.findall(r'class_(\d+)_', k)[0])
|
|
||||||
if lo - hi < np.inf:
|
|
||||||
gaps['max_class_gap_' + metric] = hi - lo
|
|
||||||
gaps[f'class_{hi_idx}_' + metric] = hi
|
|
||||||
gaps[f'class_{lo_idx}_' + metric] = lo
|
|
||||||
gaps['max_vuln_class_' + metric] = hi_idx
|
|
||||||
gaps['min_vuln_class_' + metric] = lo_idx
|
|
||||||
return gaps
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
# Attacks
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _run_threshold_loss_attack(features: ArrayDict,
|
|
||||||
figure_file_prefix: Text = '',
|
|
||||||
figure_directory: Text = None) -> ArrayDict:
|
|
||||||
"""Runs the threshold attack on the loss."""
|
|
||||||
logging.info('Run threshold attack on loss...')
|
|
||||||
is_train = features['is_train']
|
|
||||||
attack_prefix = 'thresh_loss'
|
|
||||||
tmp_results = utils.compute_performance_metrics(is_train, -features['loss'])
|
|
||||||
if figure_directory is not None:
|
|
||||||
figpath = os.path.join(figure_directory,
|
|
||||||
figure_file_prefix + attack_prefix + '.png')
|
|
||||||
plotting.save_plot(
|
|
||||||
plotting.plot_curve_with_area(
|
|
||||||
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
|
|
||||||
figpath)
|
|
||||||
figpath = os.path.join(figure_directory,
|
|
||||||
figure_file_prefix + attack_prefix + '_hist.png')
|
|
||||||
plotting.save_plot(
|
|
||||||
plotting.plot_histograms(
|
|
||||||
features['loss'][is_train == 1],
|
|
||||||
features['loss'][is_train == 0],
|
|
||||||
xlabel='loss'), figpath)
|
|
||||||
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
|
|
||||||
|
|
||||||
|
|
||||||
def _run_threshold_attack_maxlogit(features: ArrayDict,
|
|
||||||
figure_file_prefix: Text = '',
|
|
||||||
figure_directory: Text = None) -> ArrayDict:
|
|
||||||
"""Runs the threshold attack on the maximum logit."""
|
|
||||||
is_train = features['is_train']
|
|
||||||
preds = np.max(features['logits'], axis=-1)
|
|
||||||
tmp_results = utils.compute_performance_metrics(is_train, preds)
|
|
||||||
attack_prefix = 'thresh_maxlogit'
|
|
||||||
if figure_directory is not None:
|
|
||||||
figpath = os.path.join(figure_directory,
|
|
||||||
figure_file_prefix + attack_prefix + '.png')
|
|
||||||
plotting.save_plot(
|
|
||||||
plotting.plot_curve_with_area(
|
|
||||||
tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR', ylabel='TPR'),
|
|
||||||
figpath)
|
|
||||||
figpath = os.path.join(figure_directory,
|
|
||||||
figure_file_prefix + attack_prefix + '_hist.png')
|
|
||||||
plotting.save_plot(
|
|
||||||
plotting.plot_histograms(
|
|
||||||
preds[is_train == 1], preds[is_train == 0], xlabel='loss'), figpath)
|
|
||||||
return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
|
|
||||||
|
|
||||||
|
|
||||||
def _run_trained_attack(attack_classifier: Text,
|
|
||||||
data: Dataset,
|
|
||||||
attack_prefix: Text,
|
|
||||||
figure_file_prefix: Text = '',
|
|
||||||
figure_directory: Text = None) -> ArrayDict:
|
|
||||||
"""Train a classifier for attack and evaluate it."""
|
|
||||||
# Train the attack classifier
|
|
||||||
(x_train, y_train), (x_test, y_test) = data
|
|
||||||
clf_model = trained_attack_models.choose_model(attack_classifier)
|
|
||||||
clf_model.fit(x_train, y_train)
|
|
||||||
|
|
||||||
# Calculate training set metrics
|
|
||||||
pred_train = clf_model.predict_proba(x_train)[:, clf_model.classes_ == 1]
|
|
||||||
results = utils.prepend_to_keys(
|
|
||||||
utils.compute_performance_metrics(y_train, pred_train),
|
|
||||||
attack_prefix + 'train_')
|
|
||||||
|
|
||||||
# Calculate test set metrics
|
|
||||||
pred_test = clf_model.predict_proba(x_test)[:, clf_model.classes_ == 1]
|
|
||||||
results.update(
|
|
||||||
utils.prepend_to_keys(
|
|
||||||
utils.compute_performance_metrics(y_test, pred_test),
|
|
||||||
attack_prefix + 'test_'))
|
|
||||||
|
|
||||||
if figure_directory is not None:
|
|
||||||
figpath = os.path.join(figure_directory,
|
|
||||||
figure_file_prefix + attack_prefix[:-1] + '.png')
|
|
||||||
plotting.save_plot(
|
|
||||||
plotting.plot_curve_with_area(
|
|
||||||
results[attack_prefix + 'test_fpr'],
|
|
||||||
results[attack_prefix + 'test_tpr'],
|
|
||||||
xlabel='FPR',
|
|
||||||
ylabel='TPR'), figpath)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _run_attacks_and_plot(features: ArrayDict,
|
|
||||||
attacks: Iterable[Text],
|
|
||||||
attack_classifiers: Iterable[Text],
|
|
||||||
balance: bool,
|
|
||||||
test_size: float,
|
|
||||||
random_state: int,
|
|
||||||
figure_file_prefix: Text = '',
|
|
||||||
figure_directory: Text = None) -> ArrayDict:
|
|
||||||
"""Runs the specified attacks on the provided data."""
|
|
||||||
if balance:
|
|
||||||
try:
|
|
||||||
features = utils.subsample_to_balance(features, random_state)
|
|
||||||
except RuntimeError:
|
|
||||||
logging.info('Not enough remaining data for attack: Empty results.')
|
|
||||||
return {}
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
# -------------------- Simple threshold attacks
|
|
||||||
if 'thresh_loss' in attacks:
|
|
||||||
result.update(
|
|
||||||
_run_threshold_loss_attack(features, figure_file_prefix,
|
|
||||||
figure_directory))
|
|
||||||
|
|
||||||
if 'thresh_maxlogit' in attacks:
|
|
||||||
result.update(
|
|
||||||
_run_threshold_attack_maxlogit(features, figure_file_prefix,
|
|
||||||
figure_directory))
|
|
||||||
|
|
||||||
# -------------------- Run learned attacks
|
|
||||||
# TODO(b/157632603): Add a prefix (for example 'trained_') for attacks which
|
|
||||||
# use classifiers to distinguish from threshould attacks.
|
|
||||||
if 'logits' in attacks:
|
|
||||||
data = utils.get_train_test_split(
|
|
||||||
features, add_loss=False, test_size=test_size)
|
|
||||||
for clf in attack_classifiers:
|
|
||||||
logging.info('Train %s on %d logits', clf, data[0][0].shape[1])
|
|
||||||
attack_prefix = f'{clf}_logits_'
|
|
||||||
result.update(
|
|
||||||
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
|
|
||||||
figure_directory))
|
|
||||||
|
|
||||||
if 'logits_loss' in attacks:
|
|
||||||
data = utils.get_train_test_split(
|
|
||||||
features, add_loss=True, test_size=test_size)
|
|
||||||
for clf in attack_classifiers:
|
|
||||||
logging.info('Train %s on %d logits + loss', clf, data[0][0].shape[1])
|
|
||||||
attack_prefix = f'{clf}_logits_loss_'
|
|
||||||
result.update(
|
|
||||||
_run_trained_attack(clf, data, attack_prefix, figure_file_prefix,
|
|
||||||
figure_directory))
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def run_attack(loss_train: np.ndarray = None,
|
|
||||||
loss_test: np.ndarray = None,
|
|
||||||
logits_train: np.ndarray = None,
|
|
||||||
logits_test: np.ndarray = None,
|
|
||||||
labels_train: np.ndarray = None,
|
|
||||||
labels_test: np.ndarray = None,
|
|
||||||
attack_classifiers: Iterable[Text] = None,
|
|
||||||
only_misclassified: bool = False,
|
|
||||||
by_class: Union[bool, Iterable[int], int] = False,
|
|
||||||
by_percentile: Union[bool, Iterable[int], int] = False,
|
|
||||||
figure_directory: Text = None,
|
|
||||||
output_directory: Text = None,
|
|
||||||
metric: MetricNames = 'auc',
|
|
||||||
balance: bool = True,
|
|
||||||
test_size: float = 0.2,
|
|
||||||
random_state: int = 0) -> FloatDict:
|
|
||||||
"""Run membership inference attack(s).
|
|
||||||
|
|
||||||
Based only on specific outputs of a machine learning model on some examples
|
|
||||||
used for training (train) and some examples not used for training (test), run
|
|
||||||
membership inference attacks that try to discriminate training from test
|
|
||||||
inputs based only on the model outputs.
|
|
||||||
While all inputs are optional, at least one train/test pair is required to run
|
|
||||||
any attacks (either losses or logits/probabilities).
|
|
||||||
Note that one can equally provide output probabilities instead of logits in
|
|
||||||
the logits_train / logits_test arguments.
|
|
||||||
|
|
||||||
We measure the vulnerability of the model via the area under the ROC-curve
|
|
||||||
(auc) or via max |fpr - tpr| (advantage) of the attack classifier. These
|
|
||||||
measures are very closely related and may look almost indistinguishable.
|
|
||||||
|
|
||||||
This function provides relatively fine grained control and outputs detailed
|
|
||||||
results. For a higher-level wrapper with sane internal default settings and
|
|
||||||
distilled output results, see `run_all_attacks`.
|
|
||||||
|
|
||||||
Via the `figure_directory` argument and the `output_directory` argument more
|
|
||||||
detailed information as well as roc-curve plots can optionally be stored to
|
|
||||||
disk.
|
|
||||||
|
|
||||||
If `loss_train` and `loss_test` are provided we run:
|
|
||||||
- simple threshold attack on the loss
|
|
||||||
|
|
||||||
If `logits_train` and `logits_test` are provided we run:
|
|
||||||
- simple threshold attack on the top logit
|
|
||||||
- if `attack_classifiers` is not None and no losses are provided: train the
|
|
||||||
specified classifiers on the top 10 logits (or all logits if there are
|
|
||||||
less than 10)
|
|
||||||
- if `attack_classifiers` is not None and losses are provided: train the
|
|
||||||
specified classifiers on the top 10 logits (or all logits if there are
|
|
||||||
less than 10) and the loss
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss_train: A 1D array containing the individual scalar losses for examples
|
|
||||||
used during training.
|
|
||||||
loss_test: A 1D array containing the individual scalar losses for examples
|
|
||||||
not used during training.
|
|
||||||
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples used during training.
|
|
||||||
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples not used during training.
|
|
||||||
labels_train: The true labels of the training examples. Labels are only
|
|
||||||
needed when `by_class` is specified (i.e., not False).
|
|
||||||
labels_test: The true labels of the test examples. Labels are only needed
|
|
||||||
when `by_class` is specified (i.e., not False).
|
|
||||||
attack_classifiers: Attack classifiers to train beyond simple thresholding
|
|
||||||
that require training a simple binary ML classifier. This argument is
|
|
||||||
ignored if logits are not provided. Classifiers can be 'lr' for logistic
|
|
||||||
regression, 'mlp' for multi-layered perceptron, 'rf' for random forests,
|
|
||||||
or 'knn' for k-nearest-neighbors. If 'None', don't train classifiers
|
|
||||||
beyond simple thresholding.
|
|
||||||
only_misclassified: Run and evaluate attacks only on misclassified examples.
|
|
||||||
Must specify `labels_train`, `labels_test`, `logits_train` and
|
|
||||||
`logits_test` to use this. If this is True, `by_class` and `by_percentile`
|
|
||||||
are ignored.
|
|
||||||
by_class: This argument determines whether attacks are run on the entire
|
|
||||||
data, or on examples grouped by their class label. If `True`, all attacks
|
|
||||||
are run separately for each class. If `by_class` is a single integer, run
|
|
||||||
attacks for this class only. If `by_class` is an iterable of integers, run
|
|
||||||
all attacks for each of the specified class labels separately. Only used
|
|
||||||
if `labels_train` and `labels_test` are specified. If `by_class` is
|
|
||||||
specified (not False), `by_percentile` is ignored. Ignored if
|
|
||||||
`only_misclassified` is True.
|
|
||||||
by_percentile: This argument determines whether attacks are run on the
|
|
||||||
entire data, or separately for examples where the most likely class
|
|
||||||
prediction is within a given percentile of all maximum predicitons. If
|
|
||||||
`True`, all attacks are run separately for the examples with max
|
|
||||||
probabilities within the ten deciles. If `by_precentile` is a single int
|
|
||||||
between 0 and 100, run attacks only for examples with confidence within
|
|
||||||
this percentile. If `by_percentile` is an iterable of ints between 0 and
|
|
||||||
100, run all attacks for each of the specified percentiles separately.
|
|
||||||
Ignored if `by_class` is specified. Ignored if `logits_train` and
|
|
||||||
`logits_test` are not specified. Ignored if `only_misclassified` is True.
|
|
||||||
figure_directory: Where to store ROC-curve plots and histograms. If `None`,
|
|
||||||
don't create plots.
|
|
||||||
output_directory: Where to store detailed result data for all run attacks.
|
|
||||||
If `None`, don't store detailed result data.
|
|
||||||
metric: Available vulnerability metrics are 'auc' or 'advantage' for the
|
|
||||||
area under the ROC curve or the advantage (max |tpr - fpr|). Specify
|
|
||||||
either one of them or both.
|
|
||||||
balance: Whether to use the same number of train and test samples (by
|
|
||||||
randomly subsampling whichever happens to be larger).
|
|
||||||
test_size: The fraction of the input data to use for the evaluation of
|
|
||||||
trained ML attacks. This argument is ignored, if either attack_classifiers
|
|
||||||
is None, or no logits are provided.
|
|
||||||
random_state: Random seed for reproducibility. Only used if attack models
|
|
||||||
are trained.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
results: Dictionary with the chosen vulnerability metric(s) for all ran
|
|
||||||
attacks.
|
|
||||||
"""
|
|
||||||
print(
|
|
||||||
'Deprecation warning: function run_attack is '
|
|
||||||
'deprecated and will be removed soon. '
|
|
||||||
'Please use membership_inference_attack_new.run_attacks'
|
|
||||||
)
|
|
||||||
attacks = []
|
|
||||||
features = {}
|
|
||||||
# ---------- Check available data ----------
|
|
||||||
if ((loss_train is None or loss_test is None) and
|
|
||||||
(logits_train is None or logits_test is None)):
|
|
||||||
raise ValueError(
|
|
||||||
'Need at least train and test for loss or train and test for logits.')
|
|
||||||
|
|
||||||
# ---------- If losses are provided ----------
|
|
||||||
if loss_train is not None and loss_test is not None:
|
|
||||||
if loss_train.ndim != 1 or loss_test.ndim != 1:
|
|
||||||
raise ValueError('Losses must be 1D arrays.')
|
|
||||||
features['is_train'] = np.concatenate(
|
|
||||||
(np.ones(len(loss_train)), np.zeros(len(loss_test))),
|
|
||||||
axis=0).astype(int)
|
|
||||||
features['loss'] = np.concatenate((loss_train.ravel(), loss_test.ravel()),
|
|
||||||
axis=0)
|
|
||||||
attacks.append('thresh_loss')
|
|
||||||
|
|
||||||
# ---------- If logits are provided ----------
|
|
||||||
if logits_train is not None and logits_test is not None:
|
|
||||||
assert logits_train.ndim == 2 and logits_test.ndim == 2, \
|
|
||||||
'Logits must be 2D arrays.'
|
|
||||||
assert logits_train.shape[1] == logits_test.shape[1], \
|
|
||||||
'Train and test logits must agree along axis 1 (number of classes).'
|
|
||||||
if 'is_train' in features:
|
|
||||||
assert (loss_train.shape[0] == logits_train.shape[0] and
|
|
||||||
loss_test.shape[0] == logits_test.shape[0]), \
|
|
||||||
'Number of examples must match between loss and logits.'
|
|
||||||
else:
|
|
||||||
features['is_train'] = np.concatenate(
|
|
||||||
(np.ones(logits_train.shape[0]), np.zeros(logits_test.shape[0])),
|
|
||||||
axis=0).astype(int)
|
|
||||||
attacks.append('thresh_maxlogit')
|
|
||||||
features['logits'] = np.concatenate((logits_train, logits_test), axis=0)
|
|
||||||
if attack_classifiers:
|
|
||||||
attacks.append('logits')
|
|
||||||
if 'loss' in features:
|
|
||||||
attacks.append('logits_loss')
|
|
||||||
|
|
||||||
# ---------- If labels are provided ----------
|
|
||||||
if labels_train is not None and labels_test is not None:
|
|
||||||
if labels_train.ndim != 1 or labels_test.ndim != 1:
|
|
||||||
raise ValueError('Labels must be 1D arrays.')
|
|
||||||
if 'loss' in features:
|
|
||||||
assert (loss_train.shape[0] == labels_train.shape[0] and
|
|
||||||
loss_test.shape[0] == labels_test.shape[0]), \
|
|
||||||
'Number of examples must match between loss and labels.'
|
|
||||||
else:
|
|
||||||
assert (logits_train.shape[0] == labels_train.shape[0] and
|
|
||||||
logits_test.shape[0] == labels_test.shape[0]), \
|
|
||||||
'Number of examples must match between logits and labels.'
|
|
||||||
features['label'] = np.concatenate((labels_train, labels_test), axis=0)
|
|
||||||
|
|
||||||
# ---------- Data subsampling or filtering ----------
|
|
||||||
filtertype = None
|
|
||||||
filtervals = [None]
|
|
||||||
if only_misclassified:
|
|
||||||
if (labels_train is None or labels_test is None or logits_train is None or
|
|
||||||
logits_test is None):
|
|
||||||
raise ValueError('Must specify labels_train, labels_test, logits_train, '
|
|
||||||
'and logits_test for the only_misclassified option.')
|
|
||||||
filtertype = 'misclassified'
|
|
||||||
elif by_class:
|
|
||||||
if labels_train is None or labels_test is None:
|
|
||||||
raise ValueError('Must specify labels_train and labels_test when using '
|
|
||||||
'the by_class option.')
|
|
||||||
if isinstance(by_class, bool):
|
|
||||||
filtervals = list(set(labels_train) | set(labels_test))
|
|
||||||
elif isinstance(by_class, int):
|
|
||||||
filtervals = [by_class]
|
|
||||||
elif isinstance(by_class, collections.Iterable):
|
|
||||||
filtervals = list(by_class)
|
|
||||||
filtertype = 'class'
|
|
||||||
elif by_percentile:
|
|
||||||
if logits_train is None or logits_test is None:
|
|
||||||
raise ValueError('Must specify logits_train and logits_test when using '
|
|
||||||
'the by_percentile option.')
|
|
||||||
if isinstance(by_percentile, bool):
|
|
||||||
filtervals = list(range(10, 101, 10))
|
|
||||||
elif isinstance(by_percentile, int):
|
|
||||||
filtervals = [by_percentile]
|
|
||||||
elif isinstance(by_percentile, collections.Iterable):
|
|
||||||
filtervals = [int(percentile) for percentile in by_percentile]
|
|
||||||
filtertype = 'percentile'
|
|
||||||
|
|
||||||
# ---------- Need to create figure directory? ----------
|
|
||||||
if figure_directory is not None:
|
|
||||||
mkdir(figure_directory)
|
|
||||||
|
|
||||||
# ---------- Actually run attacks and plot if required ----------
|
|
||||||
logging.info('Selecting %s with values %s', filtertype, filtervals)
|
|
||||||
num = None
|
|
||||||
result = {}
|
|
||||||
for filterval in filtervals:
|
|
||||||
if filtertype is None:
|
|
||||||
tmp_features = features
|
|
||||||
elif filtertype == 'misclassified':
|
|
||||||
idx = features['label'] != np.argmax(features['logits'], axis=-1)
|
|
||||||
tmp_features = utils.select_indices(features, idx)
|
|
||||||
num = np.sum(idx)
|
|
||||||
elif filtertype == 'class':
|
|
||||||
idx = features['label'] == filterval
|
|
||||||
tmp_features = utils.select_indices(features, idx)
|
|
||||||
num = np.sum(idx)
|
|
||||||
elif filtertype == 'percentile':
|
|
||||||
certainty = np.max(special.softmax(features['logits'], axis=-1), axis=-1)
|
|
||||||
idx = certainty <= np.percentile(certainty, filterval)
|
|
||||||
tmp_features = utils.select_indices(features, idx)
|
|
||||||
|
|
||||||
prefix = f'{filtertype}_' if filtertype is not None else ''
|
|
||||||
prefix += f'{filterval}_' if filterval is not None else ''
|
|
||||||
tmp_result = _run_attacks_and_plot(tmp_features, attacks,
|
|
||||||
attack_classifiers, balance, test_size,
|
|
||||||
random_state, prefix, figure_directory)
|
|
||||||
if num is not None:
|
|
||||||
tmp_result['n_examples'] = float(num)
|
|
||||||
if tmp_result:
|
|
||||||
result.update(utils.prepend_to_keys(tmp_result, prefix))
|
|
||||||
|
|
||||||
# ---------- Store data ----------
|
|
||||||
if output_directory is not None:
|
|
||||||
mkdir(output_directory)
|
|
||||||
resultpath = os.path.join(output_directory, 'attack_results.npz')
|
|
||||||
logging.info('Store aggregate results at %s.', resultpath)
|
|
||||||
with open(resultpath, 'wb') as fp:
|
|
||||||
io_buffer = io.BytesIO()
|
|
||||||
np.savez(io_buffer, **result)
|
|
||||||
fp.write(io_buffer.getvalue())
|
|
||||||
|
|
||||||
return _get_vulnerabilities(result, metric)
|
|
||||||
|
|
||||||
|
|
||||||
def run_all_attacks(loss_train: np.ndarray = None,
|
|
||||||
loss_test: np.ndarray = None,
|
|
||||||
logits_train: np.ndarray = None,
|
|
||||||
logits_test: np.ndarray = None,
|
|
||||||
labels_train: np.ndarray = None,
|
|
||||||
labels_test: np.ndarray = None,
|
|
||||||
attack_classifiers: Iterable[Text] = ('lr', 'mlp', 'rf',
|
|
||||||
'knn'),
|
|
||||||
decimals: Union[int, None] = 4) -> FloatDict:
|
|
||||||
"""Runs all possible membership inference attacks.
|
|
||||||
|
|
||||||
Check 'run_attack' for detailed information of how attacks are performed
|
|
||||||
and evaluated.
|
|
||||||
|
|
||||||
This function internally chooses sane default settings for all attacks and
|
|
||||||
returns all possible output combinations.
|
|
||||||
For fine grained control and partial attacks, please see `run_attack`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss_train: A 1D array containing the individual scalar losses for examples
|
|
||||||
used during training.
|
|
||||||
loss_test: A 1D array containing the individual scalar losses for examples
|
|
||||||
not used during training.
|
|
||||||
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples used during training.
|
|
||||||
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples not used during training.
|
|
||||||
labels_train: The true labels of the training examples. Labels are only
|
|
||||||
needed when `by_class` is specified (i.e., not False).
|
|
||||||
labels_test: The true labels of the test examples. Labels are only needed
|
|
||||||
when `by_class` is specified (i.e., not False).
|
|
||||||
attack_classifiers: Which binary classifiers to train (in addition to simple
|
|
||||||
threshold attacks). This can include 'lr' (logistic regression), 'mlp'
|
|
||||||
(multi-layered perceptron), 'rf' (random forests), 'knn' (k-nearest
|
|
||||||
neighbors), which will be trained with cross validation to determine good
|
|
||||||
hyperparameters.
|
|
||||||
decimals: Round all float results to this number of decimals. If decimals is
|
|
||||||
None, don't round.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
result: dictionary with all attack results
|
|
||||||
"""
|
|
||||||
print(
|
|
||||||
'Deprecation warning: function run_all_attacks is '
|
|
||||||
'deprecated and will be removed soon. '
|
|
||||||
'Please use membership_inference_attack_new.run_attacks'
|
|
||||||
)
|
|
||||||
metrics = ['auc', 'advantage']
|
|
||||||
|
|
||||||
# Entire data
|
|
||||||
result = run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
attack_classifiers=attack_classifiers,
|
|
||||||
metric=metrics)
|
|
||||||
result = utils.prepend_to_keys(result, 'all_')
|
|
||||||
|
|
||||||
# Misclassified examples
|
|
||||||
if (labels_train is not None and labels_test is not None and
|
|
||||||
logits_train is not None and logits_test is not None):
|
|
||||||
result.update(
|
|
||||||
run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
labels_train,
|
|
||||||
labels_test,
|
|
||||||
attack_classifiers=attack_classifiers,
|
|
||||||
only_misclassified=True,
|
|
||||||
metric=metrics))
|
|
||||||
|
|
||||||
# Split per class
|
|
||||||
if labels_train is not None and labels_test is not None:
|
|
||||||
result.update(
|
|
||||||
run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
labels_train,
|
|
||||||
labels_test,
|
|
||||||
by_class=True,
|
|
||||||
attack_classifiers=attack_classifiers,
|
|
||||||
metric=metrics))
|
|
||||||
|
|
||||||
# Different deciles
|
|
||||||
if logits_train is not None and logits_test is not None:
|
|
||||||
result.update(
|
|
||||||
run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
by_percentile=True,
|
|
||||||
attack_classifiers=attack_classifiers,
|
|
||||||
metric=metrics))
|
|
||||||
|
|
||||||
if decimals is not None:
|
|
||||||
result = {k: round(v, decimals) for k, v in result.items()}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def run_all_attacks_and_create_summary(
|
|
||||||
loss_train: np.ndarray = None,
|
|
||||||
loss_test: np.ndarray = None,
|
|
||||||
logits_train: np.ndarray = None,
|
|
||||||
logits_test: np.ndarray = None,
|
|
||||||
labels_train: np.ndarray = None,
|
|
||||||
labels_test: np.ndarray = None,
|
|
||||||
return_dict: bool = True,
|
|
||||||
decimals: Union[int, None] = 4) -> Union[Text, Tuple[Text, AnyDict]]:
|
|
||||||
"""Runs all possible membership inference attack(s) and distill results.
|
|
||||||
|
|
||||||
Check 'run_attack' for detailed information of how attacks are performed
|
|
||||||
and evaluated.
|
|
||||||
|
|
||||||
This function internally chooses sane default settings for all attacks and
|
|
||||||
returns all possible output combinations.
|
|
||||||
For fine grained control and partial attacks, please see `run_attack`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss_train: A 1D array containing the individual scalar losses for examples
|
|
||||||
used during training.
|
|
||||||
loss_test: A 1D array containing the individual scalar losses for examples
|
|
||||||
not used during training.
|
|
||||||
logits_train: A 2D array (n_train, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples used during training.
|
|
||||||
logits_test: A 2D array (n_test, n_classes) of the individual logits or
|
|
||||||
output probabilities of examples not used during training.
|
|
||||||
labels_train: The true labels of the training examples. Labels are only
|
|
||||||
needed when `by_class` is specified (i.e., not False).
|
|
||||||
labels_test: The true labels of the test examples. Labels are only needed
|
|
||||||
when `by_class` is specified (i.e., not False).
|
|
||||||
return_dict: Whether to also return a dictionary with the results summarized
|
|
||||||
in the summary string.
|
|
||||||
decimals: Round all float results to this number of decimals. If decimals is
|
|
||||||
None, don't round.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
summarystring: A string with natural language summary of the attacks. In the
|
|
||||||
summary string printed numbers will be rounded to `decimal` decimals if
|
|
||||||
provided, otherwise will round to 3 diits by default for readability.
|
|
||||||
result: a dictionary with all the distilled attack information summarized
|
|
||||||
in the summarystring
|
|
||||||
"""
|
|
||||||
print(
|
|
||||||
'Deprecation warning: function run_all_attacks_and_create_summary is '
|
|
||||||
'deprecated and will be removed soon. '
|
|
||||||
'Please use membership_inference_attack_new.run_attacks'
|
|
||||||
)
|
|
||||||
summary = []
|
|
||||||
metrics = ['auc', 'advantage']
|
|
||||||
attack_classifiers = ['lr', 'knn']
|
|
||||||
results = run_all_attacks(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
labels_train,
|
|
||||||
labels_test,
|
|
||||||
attack_classifiers=attack_classifiers,
|
|
||||||
decimals=None)
|
|
||||||
output = _get_maximum_vulnerability(results, metrics, filterby='all')
|
|
||||||
|
|
||||||
if decimals is not None:
|
|
||||||
strdec = decimals
|
|
||||||
else:
|
|
||||||
strdec = 4
|
|
||||||
|
|
||||||
for metric in metrics:
|
|
||||||
summary.append(f'========== {metric.upper()} ==========')
|
|
||||||
best_value = output['all-' + metric]['value']
|
|
||||||
best_attacker = output['all-' + metric]['attacker']
|
|
||||||
summary.append(f'The best attack ({best_attacker}) achieved an {metric} of '
|
|
||||||
f'{best_value:.{strdec}f}.')
|
|
||||||
summary.append('')
|
|
||||||
|
|
||||||
classgap = _get_maximum_class_gap_or_none(results, metrics)
|
|
||||||
if classgap:
|
|
||||||
output.update(classgap)
|
|
||||||
for metric in metrics:
|
|
||||||
summary.append(f'========== {metric.upper()} per class ==========')
|
|
||||||
hi_idx = output[f'max_vuln_class_{metric}']
|
|
||||||
lo_idx = output[f'min_vuln_class_{metric}']
|
|
||||||
hi = output[f'class_{hi_idx}_{metric}']
|
|
||||||
lo = output[f'class_{lo_idx}_{metric}']
|
|
||||||
gap = output[f'max_class_gap_{metric}']
|
|
||||||
summary.append(f'The most vulnerable class {hi_idx} has {metric} of '
|
|
||||||
f'{hi:.{strdec}f}.')
|
|
||||||
summary.append(f'The least vulnerable class {lo_idx} has {metric} of '
|
|
||||||
f'{lo:.{strdec}f}.')
|
|
||||||
summary.append(f'=> The maximum gap between class vulnerabilities is '
|
|
||||||
f'{gap:.{strdec}f}.')
|
|
||||||
summary.append('')
|
|
||||||
|
|
||||||
misclassified = _get_maximum_vulnerability(
|
|
||||||
results, metrics, filterby='misclassified')
|
|
||||||
if misclassified:
|
|
||||||
for metric in metrics:
|
|
||||||
best_value = misclassified['misclassified-' + metric]['value']
|
|
||||||
best_attacker = misclassified['misclassified-' + metric]['attacker']
|
|
||||||
summary.append(f'========== {metric.upper()} for misclassified '
|
|
||||||
'==========')
|
|
||||||
summary.append('Among misclassified examples, the best attack '
|
|
||||||
f'({best_attacker}) achieved an {metric} of '
|
|
||||||
f'{best_value:.{strdec}f}.')
|
|
||||||
summary.append('')
|
|
||||||
output.update(misclassified)
|
|
||||||
|
|
||||||
n_examples = {k: v for k, v in results.items() if k.endswith('n_examples')}
|
|
||||||
if n_examples:
|
|
||||||
output.update(n_examples)
|
|
||||||
|
|
||||||
# Flatten remaining dicts in output
|
|
||||||
fresh_output = {}
|
|
||||||
for k, v in output.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
if k.startswith('all'):
|
|
||||||
fresh_output[k[4:]] = v['value']
|
|
||||||
fresh_output['best_attacker_' + k[4:]] = v['attacker']
|
|
||||||
else:
|
|
||||||
fresh_output[k] = v
|
|
||||||
output = fresh_output
|
|
||||||
|
|
||||||
if decimals is not None:
|
|
||||||
for k, v in output.items():
|
|
||||||
if isinstance(v, float):
|
|
||||||
output[k] = round(v, decimals)
|
|
||||||
|
|
||||||
summary = '\n'.join(summary)
|
|
||||||
if return_dict:
|
|
||||||
return summary, output
|
|
||||||
else:
|
|
||||||
return summary
|
|
|
@ -1,307 +0,0 @@
|
||||||
# Copyright 2020, The TensorFlow Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Lint as: python3
|
|
||||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
|
|
||||||
from absl.testing import absltest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
|
||||||
|
|
||||||
|
|
||||||
def get_result_dict():
|
|
||||||
"""Get an example result dictionary."""
|
|
||||||
return {
|
|
||||||
'test_n_examples': np.ones(1),
|
|
||||||
'test_examples': np.zeros(1),
|
|
||||||
'test_auc': np.ones(1),
|
|
||||||
'test_advantage': np.ones(1),
|
|
||||||
'all_0-metric': np.array([1]),
|
|
||||||
'all_1-metric': np.array([2]),
|
|
||||||
'test_2-metric': np.array([3]),
|
|
||||||
'test_score': np.array([4]),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_test_inputs():
|
|
||||||
"""Get example inputs for attacks."""
|
|
||||||
n_train = n_test = 500
|
|
||||||
rng = np.random.RandomState(4)
|
|
||||||
loss_train = rng.randn(n_train) - 0.4
|
|
||||||
loss_test = rng.randn(n_test) + 0.4
|
|
||||||
logits_train = rng.randn(n_train, 5) + 0.2
|
|
||||||
logits_test = rng.randn(n_test, 5) - 0.2
|
|
||||||
labels_train = np.array([i % 5 for i in range(n_train)])
|
|
||||||
labels_test = np.array([(3 * i) % 5 for i in range(n_test)])
|
|
||||||
return (loss_train, loss_test, logits_train, logits_test,
|
|
||||||
labels_train, labels_test)
|
|
||||||
|
|
||||||
|
|
||||||
class GetVulnerabilityTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_get_vulnerabilities(self):
|
|
||||||
"""Test extraction of vulnerability scores."""
|
|
||||||
testdict = get_result_dict()
|
|
||||||
for key in ['auc', 'advantage']:
|
|
||||||
res = mia._get_vulnerabilities(testdict, key)
|
|
||||||
self.assertLen(res, 2)
|
|
||||||
self.assertEqual(res[f'test_{key}'], 1)
|
|
||||||
self.assertEqual(res['test_n_examples'], 1)
|
|
||||||
|
|
||||||
res = mia._get_vulnerabilities(testdict, ['auc', 'advantage'])
|
|
||||||
self.assertLen(res, 3)
|
|
||||||
self.assertEqual(res['test_auc'], 1)
|
|
||||||
self.assertEqual(res['test_advantage'], 1)
|
|
||||||
self.assertEqual(res['test_n_examples'], 1)
|
|
||||||
|
|
||||||
|
|
||||||
class GetMaximumVulnerabilityTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_get_maximum_vulnerability(self):
|
|
||||||
"""Test extraction of maximum vulnerability score."""
|
|
||||||
testdict = get_result_dict()
|
|
||||||
for i in range(3):
|
|
||||||
key = f'{i}-metric'
|
|
||||||
res = mia._get_maximum_vulnerability(testdict, key)
|
|
||||||
self.assertLen(res, 1)
|
|
||||||
self.assertEqual(res[key]['value'], i + 1)
|
|
||||||
if i < 2:
|
|
||||||
self.assertEqual(res[key]['attacker'], f'all_{i}-metric')
|
|
||||||
else:
|
|
||||||
self.assertEqual(res[key]['attacker'], 'test_2-metric')
|
|
||||||
|
|
||||||
res = mia._get_maximum_vulnerability(testdict, 'metric')
|
|
||||||
self.assertLen(res, 1)
|
|
||||||
self.assertEqual(res['metric']['value'], 3)
|
|
||||||
|
|
||||||
res = mia._get_maximum_vulnerability(testdict, ['metric'],
|
|
||||||
filterby='all')
|
|
||||||
self.assertLen(res, 1)
|
|
||||||
self.assertEqual(res['all-metric']['value'], 2)
|
|
||||||
|
|
||||||
res = mia._get_maximum_vulnerability(testdict, ['metric', 'score'])
|
|
||||||
self.assertLen(res, 2)
|
|
||||||
self.assertEqual(res['metric']['value'], 3)
|
|
||||||
self.assertEqual(res['score']['value'], 4)
|
|
||||||
self.assertEqual(res['score']['attacker'], 'test_score')
|
|
||||||
|
|
||||||
|
|
||||||
class ThresholdAttackLossTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_threshold_attack_loss(self):
|
|
||||||
"""Test simple threshold attack on loss."""
|
|
||||||
features = {
|
|
||||||
'loss': np.zeros(10),
|
|
||||||
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
|
|
||||||
}
|
|
||||||
res = mia._run_threshold_loss_attack(features)
|
|
||||||
for k in res:
|
|
||||||
self.assertStartsWith(k, 'thresh_loss')
|
|
||||||
self.assertEqual(res['thresh_loss_auc'], 0.5)
|
|
||||||
self.assertEqual(res['thresh_loss_advantage'], 0.0)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(4)
|
|
||||||
n_train = 1000
|
|
||||||
n_test = 500
|
|
||||||
loss_train = rng.randn(n_train) - 0.4
|
|
||||||
loss_test = rng.randn(n_test) + 0.4
|
|
||||||
features = {
|
|
||||||
'loss': np.concatenate((loss_train, loss_test)),
|
|
||||||
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
|
|
||||||
}
|
|
||||||
res = mia._run_threshold_loss_attack(features)
|
|
||||||
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
|
|
||||||
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
|
|
||||||
|
|
||||||
|
|
||||||
class ThresholdAttackMaxlogitTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_threshold_attack_maxlogits(self):
|
|
||||||
"""Test simple threshold attack on maximum logit."""
|
|
||||||
features = {
|
|
||||||
'logits': np.eye(10, 14),
|
|
||||||
'is_train': np.concatenate((np.zeros(5), np.ones(5))),
|
|
||||||
}
|
|
||||||
res = mia._run_threshold_attack_maxlogit(features)
|
|
||||||
for k in res:
|
|
||||||
self.assertStartsWith(k, 'thresh_maxlogit')
|
|
||||||
self.assertEqual(res['thresh_maxlogit_auc'], 0.5)
|
|
||||||
self.assertEqual(res['thresh_maxlogit_advantage'], 0.0)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(4)
|
|
||||||
n_train = 1000
|
|
||||||
n_test = 500
|
|
||||||
logits_train = rng.randn(n_train, 12) + 0.2
|
|
||||||
logits_test = rng.randn(n_test, 12) - 0.2
|
|
||||||
features = {
|
|
||||||
'logits': np.concatenate((logits_train, logits_test), axis=0),
|
|
||||||
'is_train': np.concatenate((np.ones(n_train), np.zeros(n_test))),
|
|
||||||
}
|
|
||||||
res = mia._run_threshold_attack_maxlogit(features)
|
|
||||||
self.assertBetween(res['thresh_maxlogit_auc'], 0.7, 0.75)
|
|
||||||
self.assertBetween(res['thresh_maxlogit_advantage'], 0.3, 0.35)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttackTrivialTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_trained_attack(self):
|
|
||||||
"""Test trained attacks."""
|
|
||||||
# Trivially easy problem
|
|
||||||
x_train, x_test = np.ones((500, 3)), np.ones((20, 3))
|
|
||||||
x_train[:200] *= -1
|
|
||||||
x_test[:8] *= -1
|
|
||||||
y_train, y_test = np.ones(500).astype(int), np.ones(20).astype(int)
|
|
||||||
y_train[:200] = 0
|
|
||||||
y_test[:8] = 0
|
|
||||||
data = (x_train, y_train), (x_test, y_test)
|
|
||||||
for clf in ['lr', 'rf', 'mlp', 'knn']:
|
|
||||||
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
|
|
||||||
self.assertEqual(res['a-train_auc'], 1)
|
|
||||||
self.assertEqual(res['a-test_auc'], 1)
|
|
||||||
self.assertEqual(res['a-train_advantage'], 1)
|
|
||||||
self.assertEqual(res['a-test_advantage'], 1)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttackRandomFeaturesTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_trained_attack(self):
|
|
||||||
"""Test trained attacks."""
|
|
||||||
# Random labels and features
|
|
||||||
rng = np.random.RandomState(4)
|
|
||||||
x_train, x_test = rng.randn(500, 3), rng.randn(500, 3)
|
|
||||||
y_train = rng.binomial(1, 0.5, size=(500,))
|
|
||||||
y_test = rng.binomial(1, 0.5, size=(500,))
|
|
||||||
data = (x_train, y_train), (x_test, y_test)
|
|
||||||
for clf in ['lr', 'rf', 'mlp', 'knn']:
|
|
||||||
res = mia._run_trained_attack(clf, data, attack_prefix='a-')
|
|
||||||
self.assertBetween(res['a-train_auc'], 0.5, 1.)
|
|
||||||
self.assertBetween(res['a-test_auc'], 0.4, 0.6)
|
|
||||||
self.assertBetween(res['a-train_advantage'], 0., 1.0)
|
|
||||||
self.assertBetween(res['a-test_advantage'], 0., 0.2)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLossesTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
"""Test individual attack function."""
|
|
||||||
# losses only, both metrics
|
|
||||||
loss_train, loss_test, _, _, _, _ = get_test_inputs()
|
|
||||||
res = mia.run_attack(loss_train, loss_test, metric=('auc', 'advantage'))
|
|
||||||
self.assertBetween(res['thresh_loss_auc'], 0.7, 0.75)
|
|
||||||
self.assertBetween(res['thresh_loss_advantage'], 0.3, 0.35)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLossesLogitsTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
"""Test individual attack function."""
|
|
||||||
# losses and logits, two classifiers, single metric
|
|
||||||
loss_train, loss_test, logits_train, logits_test, _, _ = get_test_inputs()
|
|
||||||
res = mia.run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
logits_train,
|
|
||||||
logits_test,
|
|
||||||
attack_classifiers=('rf', 'knn'),
|
|
||||||
metric='auc')
|
|
||||||
self.assertBetween(res['rf_logits_test_auc'], 0.7, 0.9)
|
|
||||||
self.assertBetween(res['knn_logits_test_auc'], 0.7, 0.9)
|
|
||||||
self.assertBetween(res['rf_logits_loss_test_auc'], 0.7, 0.9)
|
|
||||||
self.assertBetween(res['knn_logits_loss_test_auc'], 0.7, 0.9)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLossesLabelsByClassTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
# losses and labels, single metric, split by class
|
|
||||||
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
|
|
||||||
n_train = loss_train.shape[0]
|
|
||||||
n_test = loss_test.shape[0]
|
|
||||||
res = mia.run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
labels_train=labels_train,
|
|
||||||
labels_test=labels_test,
|
|
||||||
by_class=True,
|
|
||||||
metric='auc')
|
|
||||||
self.assertLen(res, 10)
|
|
||||||
for k in res:
|
|
||||||
self.assertStartsWith(k, 'class_')
|
|
||||||
if k.endswith('n_examples'):
|
|
||||||
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
|
|
||||||
else:
|
|
||||||
self.assertBetween(res[k], 0.65, 0.75)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLossesLabelsSingleClassTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
# losses and labels, both metrics, single class
|
|
||||||
loss_train, loss_test, _, _, labels_train, labels_test = get_test_inputs()
|
|
||||||
n_train = loss_train.shape[0]
|
|
||||||
n_test = loss_test.shape[0]
|
|
||||||
res = mia.run_attack(
|
|
||||||
loss_train,
|
|
||||||
loss_test,
|
|
||||||
labels_train=labels_train,
|
|
||||||
labels_test=labels_test,
|
|
||||||
by_class=2,
|
|
||||||
metric=('auc', 'advantage'))
|
|
||||||
self.assertLen(res, 3)
|
|
||||||
for k in res:
|
|
||||||
self.assertStartsWith(k, 'class_2')
|
|
||||||
if k.endswith('n_examples'):
|
|
||||||
self.assertEqual(int(res[k]), (n_train + n_test) // 5)
|
|
||||||
elif k.endswith('advantage'):
|
|
||||||
self.assertBetween(res[k], 0.3, 0.5)
|
|
||||||
elif k.endswith('auc'):
|
|
||||||
self.assertBetween(res[k], 0.7, 0.75)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLogitsLabelsMisclassifiedTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
# logits and labels, single metric, single classifier, misclassified only
|
|
||||||
(_, _, logits_train, logits_test,
|
|
||||||
labels_train, labels_test) = get_test_inputs()
|
|
||||||
res = mia.run_attack(
|
|
||||||
logits_train=logits_train,
|
|
||||||
logits_test=logits_test,
|
|
||||||
labels_train=labels_train,
|
|
||||||
labels_test=labels_test,
|
|
||||||
only_misclassified=True,
|
|
||||||
attack_classifiers=('lr',),
|
|
||||||
metric='advantage')
|
|
||||||
self.assertBetween(res['misclassified_lr_logits_test_advantage'], 0.3, 0.8)
|
|
||||||
self.assertEqual(res['misclassified_n_examples'], 802)
|
|
||||||
|
|
||||||
|
|
||||||
class AttackLogitsByPrecentileTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_attack(self):
|
|
||||||
# only logits, single metric, no classifiers, split by deciles
|
|
||||||
_, _, logits_train, logits_test, _, _ = get_test_inputs()
|
|
||||||
res = mia.run_attack(
|
|
||||||
logits_train=logits_train,
|
|
||||||
logits_test=logits_test,
|
|
||||||
by_percentile=True,
|
|
||||||
metric='auc')
|
|
||||||
for k in res:
|
|
||||||
self.assertStartsWith(k, 'percentile')
|
|
||||||
self.assertBetween(res[k], 0.60, 0.75)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
absltest.main()
|
|
|
@ -1,106 +0,0 @@
|
||||||
# Copyright 2020, The TensorFlow Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Lint as: python3
|
|
||||||
r"""A collection of sklearn models for binary classification.
|
|
||||||
|
|
||||||
This module contains some sklearn pipelines for finding models for binary
|
|
||||||
classification from a variable number of numerical input features.
|
|
||||||
These models are used to train binary classifiers for membership inference.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Text
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from sklearn import ensemble
|
|
||||||
from sklearn import linear_model
|
|
||||||
from sklearn import model_selection
|
|
||||||
from sklearn import neighbors
|
|
||||||
from sklearn import neural_network
|
|
||||||
|
|
||||||
|
|
||||||
def choose_model(attack_classifier: Text):
|
|
||||||
"""Choose a model based on a string classifier."""
|
|
||||||
if attack_classifier == 'lr':
|
|
||||||
return logistic_regression()
|
|
||||||
elif attack_classifier == 'mlp':
|
|
||||||
return mlp()
|
|
||||||
elif attack_classifier == 'rf':
|
|
||||||
return random_forest()
|
|
||||||
elif attack_classifier == 'knn':
|
|
||||||
return knn()
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unknown attack classifier {attack_classifier}.')
|
|
||||||
|
|
||||||
|
|
||||||
def logistic_regression(verbose: int = 0, n_jobs: int = 1):
|
|
||||||
"""Setup a logistic regression pipeline with cross-validation."""
|
|
||||||
lr = linear_model.LogisticRegression(solver='lbfgs')
|
|
||||||
param_grid = {
|
|
||||||
'C': np.logspace(-4, 2, 10),
|
|
||||||
}
|
|
||||||
pipe = model_selection.GridSearchCV(
|
|
||||||
lr, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
|
||||||
verbose=verbose)
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def random_forest(verbose: int = 0, n_jobs: int = 1):
|
|
||||||
"""Setup a random forest pipeline with cross-validation."""
|
|
||||||
rf = ensemble.RandomForestClassifier()
|
|
||||||
|
|
||||||
n_estimators = [100]
|
|
||||||
max_features = ['auto', 'sqrt']
|
|
||||||
max_depth = [5, 10, 20]
|
|
||||||
max_depth.append(None)
|
|
||||||
min_samples_split = [2, 5, 10]
|
|
||||||
min_samples_leaf = [1, 2, 4]
|
|
||||||
random_grid = {'n_estimators': n_estimators,
|
|
||||||
'max_features': max_features,
|
|
||||||
'max_depth': max_depth,
|
|
||||||
'min_samples_split': min_samples_split,
|
|
||||||
'min_samples_leaf': min_samples_leaf}
|
|
||||||
|
|
||||||
pipe = model_selection.RandomizedSearchCV(
|
|
||||||
rf, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs,
|
|
||||||
iid=False, verbose=verbose)
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def mlp(verbose: int = 0, n_jobs: int = 1):
|
|
||||||
"""Setup a MLP pipeline with cross-validation."""
|
|
||||||
mlpmodel = neural_network.MLPClassifier()
|
|
||||||
|
|
||||||
param_grid = {
|
|
||||||
'hidden_layer_sizes': [(64,), (32, 32)],
|
|
||||||
'solver': ['adam'],
|
|
||||||
'alpha': [0.0001, 0.001, 0.01],
|
|
||||||
}
|
|
||||||
pipe = model_selection.GridSearchCV(
|
|
||||||
mlpmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
|
||||||
verbose=verbose)
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def knn(verbose: int = 0, n_jobs: int = 1):
|
|
||||||
"""Setup a k-nearest neighbors pipeline with cross-validation."""
|
|
||||||
knnmodel = neighbors.KNeighborsClassifier()
|
|
||||||
|
|
||||||
param_grid = {
|
|
||||||
'n_neighbors': [3, 5, 7],
|
|
||||||
}
|
|
||||||
pipe = model_selection.GridSearchCV(
|
|
||||||
knnmodel, param_grid=param_grid, cv=3, n_jobs=n_jobs, iid=False,
|
|
||||||
verbose=verbose)
|
|
||||||
return pipe
|
|
Loading…
Reference in a new issue