forked from 626_privacy/tensorflow_privacy
The new API for the membership inference attack.
1. Colab and Keras/TF estimator integration still use the old API and will be updated in the subsequent CLs. 2. After dropping the old API in membership_inference_attack.py, membership_inference_attack_new.py will be renamed in membership_inference_attack.py. PiperOrigin-RevId: 325823046
This commit is contained in:
parent
68651eeddc
commit
43a0e4be8a
9 changed files with 1279 additions and 1 deletions
|
@ -0,0 +1,322 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Data structures representing attack inputs, configuration, outputs."""
|
||||
import enum
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from sklearn import metrics
|
||||
|
||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
||||
|
||||
|
||||
class SlicingFeature(enum.Enum):
|
||||
"""Enum with features by which slicing is available."""
|
||||
CLASS = 'class'
|
||||
PERCENTILE = 'percentile'
|
||||
CORRECTLY_CLASSIFIED = 'correctly_classfied'
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleSliceSpec:
|
||||
"""Specifies a slice.
|
||||
|
||||
The slice is defined by values in one feature - it might be a single value
|
||||
(eg. slice of examples of the specific classification class) or some set of
|
||||
values (eg. range of percentiles of the attacked model loss).
|
||||
|
||||
When feature is None, it means that the slice is the entire dataset.
|
||||
"""
|
||||
feature: SlicingFeature = None
|
||||
value: Any = None
|
||||
|
||||
@property
|
||||
def entire_dataset(self):
|
||||
return self.feature is None
|
||||
|
||||
def __str__(self):
|
||||
if self.entire_dataset:
|
||||
return 'Entire dataset'
|
||||
|
||||
if self.feature == SlicingFeature.PERCENTILE:
|
||||
return 'Loss percentiles: %d-%d' % self.value
|
||||
|
||||
return f'{self.feature.name}={self.value}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlicingSpec:
|
||||
"""Specification of a slicing procedure.
|
||||
|
||||
Each variable which is set specifies a slicing by different dimension.
|
||||
"""
|
||||
|
||||
# When is set to true, one of the slices is the whole dataset.
|
||||
entire_dataset: bool = True
|
||||
|
||||
# Used in classification tasks for slicing by classes. It is assumed that
|
||||
# classes are integers 0, 1, ... number of classes. When true one slice per
|
||||
# each class is generated.
|
||||
by_class: Union[bool, Iterable[int], int] = False
|
||||
|
||||
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
|
||||
# ... 90-100%.
|
||||
by_percentiles: bool = False
|
||||
|
||||
# When true, a slice for correctly classifed and a slice for misclassifed
|
||||
# examples will be generated.
|
||||
by_classification_correctness: bool = False
|
||||
|
||||
|
||||
class AttackType(enum.Enum):
|
||||
"""An enum define attack types."""
|
||||
LOGISTIC_REGRESSION = 'lr'
|
||||
MULTI_LAYERED_PERCEPTRON = 'mlp'
|
||||
RANDOM_FOREST = 'rf'
|
||||
K_NEAREST_NEIGHBORS = 'knn'
|
||||
THRESHOLD_ATTACK = 'threshold'
|
||||
|
||||
@property
|
||||
def is_trained_attack(self):
|
||||
"""Returns whether this type of attack requires training a model."""
|
||||
return self != AttackType.THRESHOLD_ATTACK
|
||||
|
||||
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
|
||||
def __str__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackInputData:
|
||||
"""Input data for running an attack.
|
||||
|
||||
This includes only the data, and not configuration.
|
||||
"""
|
||||
|
||||
logits_train: np.ndarray = None
|
||||
logits_test: np.ndarray = None
|
||||
|
||||
# Contains ground-truth classes. Classes are assumed to be integers starting
|
||||
# from 0.
|
||||
labels_train: np.ndarray = None
|
||||
labels_test: np.ndarray = None
|
||||
|
||||
# Explicitly specified loss. If provided, this is used instead of deriving
|
||||
# loss from logits and labels
|
||||
loss_train: np.ndarray = None
|
||||
loss_test: np.ndarray = None
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
if self.labels_train is None or self.labels_test is None:
|
||||
raise ValueError(
|
||||
"Can't identify the number of classes as no labels were provided. "
|
||||
'Please set labels_train and labels_test')
|
||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
|
||||
return logits[range(logits.shape[0]), true_labels]
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates cross-entropy losses for the training set."""
|
||||
if self.loss_train is not None:
|
||||
return self.loss_train
|
||||
return self._get_loss(self.logits_train, self.labels_train)
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates cross-entropy losses for the test set."""
|
||||
if self.loss_test is not None:
|
||||
return self.loss_test
|
||||
return self._get_loss(self.logits_test, self.labels_test)
|
||||
|
||||
def get_train_size(self):
|
||||
"""Returns size of the training set."""
|
||||
if self.loss_train is not None:
|
||||
return self.loss_train.size
|
||||
return self.logits_train.shape[0]
|
||||
|
||||
def get_test_size(self):
|
||||
"""Returns size of the test set."""
|
||||
if self.loss_test is not None:
|
||||
return self.loss_test.size
|
||||
return self.logits_test.shape[0]
|
||||
|
||||
def validate(self):
|
||||
"""Validates the inputs."""
|
||||
if (self.loss_train is None) != (self.loss_test is None):
|
||||
raise ValueError(
|
||||
'loss_test and loss_train should both be either set or unset')
|
||||
|
||||
if (self.logits_train is None) != (self.logits_test is None):
|
||||
raise ValueError(
|
||||
'logits_train and logits_test should both be either set or unset')
|
||||
|
||||
if (self.labels_train is None) != (self.labels_test is None):
|
||||
raise ValueError(
|
||||
'labels_train and labels_test should both be either set or unset')
|
||||
|
||||
if (self.labels_train is None and self.loss_train is None and
|
||||
self.logits_train is None):
|
||||
raise ValueError('At least one of labels, logits or losses should be set')
|
||||
|
||||
# TODO(b/161366709): Add checks for equal sizes
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocCurve:
|
||||
"""Represents ROC curve of a membership inference classifier."""
|
||||
# Thresholds used to define points on ROC curve.
|
||||
# Thresholds are not explicitly part of the curve, and are stored for
|
||||
# debugging purposes.
|
||||
thresholds: np.ndarray
|
||||
|
||||
# True positive rates based on thresholds
|
||||
tpr: np.ndarray
|
||||
|
||||
# False positive rates based on thresholds
|
||||
fpr: np.ndarray
|
||||
|
||||
def get_auc(self):
|
||||
"""Calculates area under curve (aka AUC)."""
|
||||
return metrics.auc(self.fpr, self.tpr)
|
||||
|
||||
def get_attacker_advantage(self):
|
||||
"""Calculates membership attacker's (or adversary's) advantage.
|
||||
|
||||
This metric is inspired by https://arxiv.org/abs/1709.01604, specifically
|
||||
by Definition 4. The difference here is that we calculate maximum advantage
|
||||
over all available classifier thresholds.
|
||||
|
||||
Returns:
|
||||
a single float number with membership attaker's advantage.
|
||||
"""
|
||||
return max(np.abs(self.tpr - self.fpr))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleAttackResult:
|
||||
"""Results from running a single attack."""
|
||||
|
||||
# Data slice this result was calculated for.
|
||||
slice_spec: SingleSliceSpec
|
||||
|
||||
attack_type: AttackType
|
||||
roc_curve: RocCurve # for drawing and metrics calculation
|
||||
|
||||
# TODO(b/162693190): Add more metrics. Think which info we should store
|
||||
# to derive metrics like f1_score or accuracy. Should we store labels and
|
||||
# predictions, or rather some aggregate data?
|
||||
|
||||
def get_attacker_advantage(self):
|
||||
return self.roc_curve.get_attacker_advantage()
|
||||
|
||||
def get_auc(self):
|
||||
return self.roc_curve.get_auc()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackResults:
|
||||
"""Results from running multiple attacks."""
|
||||
# add metadata, such as parameters of attack evaluation, input data etc
|
||||
single_attack_results: Iterable[SingleAttackResult]
|
||||
|
||||
def calculate_pd_dataframe(self):
|
||||
# returns all metrics as a Pandas DataFrame
|
||||
return
|
||||
|
||||
def summary(self, by_slices=False) -> str:
|
||||
"""Provides a summary of the metrics.
|
||||
|
||||
The summary provides the best-performing attacks for each requested data
|
||||
slice.
|
||||
Args:
|
||||
by_slices : whether to prepare a per-slice summary.
|
||||
|
||||
Returns:
|
||||
A string with a summary of all the metrics.
|
||||
"""
|
||||
summary = []
|
||||
|
||||
# Summary over all slices
|
||||
max_auc_result_all = self.get_result_with_max_attacker_advantage()
|
||||
summary.append('Best-performing attacks over all slices')
|
||||
summary.append(
|
||||
' %s achieved an AUC of %.2f on slice %s' %
|
||||
(max_auc_result_all.attack_type, max_auc_result_all.get_auc(),
|
||||
max_auc_result_all.slice_spec))
|
||||
|
||||
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
|
||||
summary.append(' %s achieved an advantage of %.2f on slice %s' %
|
||||
(max_advantage_result_all.attack_type,
|
||||
max_advantage_result_all.get_attacker_advantage(),
|
||||
max_advantage_result_all.slice_spec))
|
||||
|
||||
slice_dict = self._group_results_by_slice()
|
||||
|
||||
if len(slice_dict.keys()) > 1 and by_slices:
|
||||
for slice_str in slice_dict:
|
||||
results = slice_dict[slice_str]
|
||||
summary.append('\nBest-performing attacks over slice: \"%s\"' %
|
||||
slice_str)
|
||||
max_auc_result = results.get_result_with_max_auc()
|
||||
summary.append(' %s achieved an AUC of %.2f' %
|
||||
(max_auc_result.attack_type, max_auc_result.get_auc()))
|
||||
max_advantage_result = results.get_result_with_max_attacker_advantage()
|
||||
summary.append(' %s achieved an advantage of %.2f' %
|
||||
(max_advantage_result.attack_type,
|
||||
max_advantage_result.get_attacker_advantage()))
|
||||
|
||||
return '\n'.join(summary)
|
||||
|
||||
def _group_results_by_slice(self):
|
||||
"""Groups AttackResults into a dictionary keyed by the slice."""
|
||||
slice_dict = {}
|
||||
for attack_result in self.single_attack_results:
|
||||
slice_str = str(attack_result.slice_spec)
|
||||
if slice_str not in slice_dict:
|
||||
slice_dict[slice_str] = AttackResults([])
|
||||
slice_dict[slice_str].single_attack_results.append(attack_result)
|
||||
return slice_dict
|
||||
|
||||
def get_result_with_max_auc(self) -> SingleAttackResult:
|
||||
"""Get the result with maximum AUC for all attacks and slices."""
|
||||
aucs = [result.get_auc() for result in self.single_attack_results]
|
||||
|
||||
if min(aucs) < 0.4:
|
||||
print('Suspiciously low AUC detected: %.2f. ' +
|
||||
'There might be a bug in the classifier' % min(aucs))
|
||||
|
||||
return self.single_attack_results[np.argmax(aucs)]
|
||||
|
||||
def get_result_with_max_attacker_advantage(self) -> SingleAttackResult:
|
||||
"""Get the result with maximum advantage for all attacks and slices."""
|
||||
return self.single_attack_results[np.argmax([
|
||||
result.get_attacker_advantage() for result in self.single_attack_results
|
||||
])]
|
||||
|
||||
def save(self, filepath):
|
||||
"""Saves self to a pickle file."""
|
||||
with open(filepath, 'wb') as out:
|
||||
pickle.dump(self, out)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
"""Loads AttackResults from a pickle file."""
|
||||
with open(filepath, 'rb') as inp:
|
||||
return pickle.load(inp)
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Specifying and creating AttackInputData slices."""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
||||
|
||||
def _slice_if_not_none(a, idx):
|
||||
return None if a is None else a[idx]
|
||||
|
||||
|
||||
def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||
idx_test) -> AttackInputData:
|
||||
"""Slices train fields with with idx_train and test fields with and idx_test."""
|
||||
|
||||
result = AttackInputData()
|
||||
|
||||
# Slice train data.
|
||||
result.logits_train = _slice_if_not_none(data.logits_train, idx_train)
|
||||
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
||||
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
||||
|
||||
# Slice test data.
|
||||
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
||||
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||
idx_train = data.labels_train == class_value
|
||||
idx_test = data.labels_test == class_value
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
|
||||
|
||||
def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
||||
to_percentile: float):
|
||||
"""Slices samples by loss percentiles."""
|
||||
|
||||
# Find from_percentile and to_percentile percentiles in losses.
|
||||
loss_train = data.get_loss_train()
|
||||
loss_test = data.get_loss_test()
|
||||
losses = np.concatenate((loss_train, loss_test))
|
||||
from_loss = np.percentile(losses, from_percentile)
|
||||
to_loss = np.percentile(losses, to_percentile)
|
||||
|
||||
idx_train = (from_loss <= loss_train) & (loss_train <= to_loss)
|
||||
idx_test = (from_loss <= loss_test) & (loss_test <= to_loss)
|
||||
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
|
||||
|
||||
def _indices_by_classification(logits, labels, correctly_classified):
|
||||
idx_correct = labels == np.argmax(logits, axis=1)
|
||||
return idx_correct if correctly_classified else np.invert(idx_correct)
|
||||
|
||||
|
||||
def _slice_by_classification_correctness(data: AttackInputData,
|
||||
correctly_classified: bool):
|
||||
idx_train = _indices_by_classification(data.logits_train, data.labels_train,
|
||||
correctly_classified)
|
||||
idx_test = _indices_by_classification(data.logits_test, data.labels_test,
|
||||
correctly_classified)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
|
||||
|
||||
def get_single_slice_specs(slicing_spec: SlicingSpec,
|
||||
num_classes: int = None) -> List[SingleSliceSpec]:
|
||||
"""Returns slices of data according to slicing_spec."""
|
||||
result = []
|
||||
|
||||
if slicing_spec.entire_dataset:
|
||||
result.append(SingleSliceSpec())
|
||||
|
||||
# Create slices by class.
|
||||
by_class = slicing_spec.by_class
|
||||
if isinstance(by_class, bool):
|
||||
if by_class:
|
||||
assert num_classes, "When by_class == True, num_classes should be given."
|
||||
assert 0 <= num_classes <= 1000, (
|
||||
f"Too much classes for slicing by classes. "
|
||||
f"Found {num_classes}.")
|
||||
for c in range(num_classes):
|
||||
result.append(SingleSliceSpec(SlicingFeature.CLASS, c))
|
||||
elif isinstance(by_class, int):
|
||||
result.append(SingleSliceSpec(SlicingFeature.CLASS, by_class))
|
||||
elif isinstance(by_class, collections.Iterable):
|
||||
for c in by_class:
|
||||
result.append(SingleSliceSpec(SlicingFeature.CLASS, c))
|
||||
|
||||
# Create slices by percentiles
|
||||
if slicing_spec.by_percentiles:
|
||||
for percent in range(0, 100, 10):
|
||||
result.append(
|
||||
SingleSliceSpec(SlicingFeature.PERCENTILE, (percent, percent + 10)))
|
||||
|
||||
# Create slices by correctness of the classifications.
|
||||
if slicing_spec.by_classification_correctness:
|
||||
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True))
|
||||
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_slice(data: AttackInputData,
|
||||
slice_spec: SingleSliceSpec) -> AttackInputData:
|
||||
"""Returns a single slice of data according to slice_spec."""
|
||||
if slice_spec.entire_dataset:
|
||||
data_slice = copy.copy(data)
|
||||
elif slice_spec.feature == SlicingFeature.CLASS:
|
||||
data_slice = _slice_by_class(data, slice_spec.value)
|
||||
elif slice_spec.feature == SlicingFeature.PERCENTILE:
|
||||
from_percentile, to_percentile = slice_spec.value
|
||||
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
|
||||
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
||||
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
||||
else:
|
||||
raise ValueError(f'Unknown slice spec feature "{slice_spec.feature}"')
|
||||
|
||||
data_slice.slice_spec = slice_spec
|
||||
return data_slice
|
|
@ -0,0 +1,180 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing."""
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
||||
|
||||
|
||||
def _are_all_fields_equal(lhs, rhs) -> bool:
|
||||
return vars(lhs) == vars(rhs)
|
||||
|
||||
|
||||
def _are_lists_equal(lhs, rhs) -> bool:
|
||||
if len(lhs) != len(rhs):
|
||||
return False
|
||||
for l, r in zip(lhs, rhs):
|
||||
if not _are_all_fields_equal(l, r):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SingleSliceSpecsTest(absltest.TestCase):
|
||||
"""Tests for get_single_slice_specs."""
|
||||
|
||||
ENTIRE_DATASET_SLICE = SingleSliceSpec()
|
||||
|
||||
def test_no_slices(self):
|
||||
input_data = SlicingSpec(entire_dataset=False)
|
||||
expected = []
|
||||
output = get_single_slice_specs(input_data)
|
||||
self.assertTrue(_are_lists_equal(output, expected))
|
||||
|
||||
def test_entire_dataset(self):
|
||||
input_data = SlicingSpec()
|
||||
expected = [self.ENTIRE_DATASET_SLICE]
|
||||
output = get_single_slice_specs(input_data)
|
||||
self.assertTrue(_are_lists_equal(output, expected))
|
||||
|
||||
def test_slice_by_classes(self):
|
||||
input_data = SlicingSpec(by_class=True)
|
||||
n_classes = 5
|
||||
expected = [self.ENTIRE_DATASET_SLICE] + [
|
||||
SingleSliceSpec(SlicingFeature.CLASS, c) for c in range(n_classes)
|
||||
]
|
||||
output = get_single_slice_specs(input_data, n_classes)
|
||||
self.assertTrue(_are_lists_equal(output, expected))
|
||||
|
||||
def test_slice_by_percentiles(self):
|
||||
input_data = SlicingSpec(entire_dataset=False, by_percentiles=True)
|
||||
expected0 = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 10))
|
||||
expected5 = SingleSliceSpec(SlicingFeature.PERCENTILE, (50, 60))
|
||||
output = get_single_slice_specs(input_data)
|
||||
self.assertLen(output, 10)
|
||||
self.assertTrue(_are_all_fields_equal(output[0], expected0))
|
||||
self.assertTrue(_are_all_fields_equal(output[5], expected5))
|
||||
|
||||
def test_slice_by_correcness(self):
|
||||
input_data = SlicingSpec(
|
||||
entire_dataset=False, by_classification_correctness=True)
|
||||
expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)
|
||||
output = get_single_slice_specs(input_data)
|
||||
self.assertLen(output, 2)
|
||||
self.assertTrue(_are_all_fields_equal(output[0], expected))
|
||||
|
||||
def test_slicing_by_multiple_features(self):
|
||||
input_data = SlicingSpec(
|
||||
entire_dataset=True,
|
||||
by_class=True,
|
||||
by_percentiles=True,
|
||||
by_classification_correctness=True)
|
||||
n_classes = 10
|
||||
expected_slices = n_classes
|
||||
expected_slices += 1 # entire dataset slice
|
||||
expected_slices += 10 # percentiles slices
|
||||
expected_slices += 2 # correcness classification slices
|
||||
output = get_single_slice_specs(input_data, n_classes)
|
||||
self.assertLen(output, expected_slices)
|
||||
|
||||
|
||||
class GetSliceTest(absltest.TestCase):
|
||||
|
||||
def __init__(self, methodname):
|
||||
"""Initialize the test class."""
|
||||
super().__init__(methodname)
|
||||
|
||||
# Create test data for 3 class classification task.
|
||||
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
|
||||
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
|
||||
labels_train = np.array([1, 0, 1, 2])
|
||||
labels_test = np.array([1, 2, 0, 2])
|
||||
loss_train = np.array([2, 0.25, 4, 3])
|
||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||
|
||||
self.input_data = AttackInputData(logits_train, logits_test, labels_train,
|
||||
labels_test, loss_train, loss_test)
|
||||
|
||||
def test_slice_entire_dataset(self):
|
||||
entire_dataset_slice = SingleSliceSpec()
|
||||
output = get_slice(self.input_data, entire_dataset_slice)
|
||||
expected = self.input_data
|
||||
expected.slice_spec = entire_dataset_slice
|
||||
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
||||
|
||||
def test_slice_by_class(self):
|
||||
class_index = 1
|
||||
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
||||
output = get_slice(self.input_data, class_slice)
|
||||
|
||||
# Check logits.
|
||||
self.assertLen(output.logits_train, 2)
|
||||
self.assertLen(output.logits_test, 1)
|
||||
self.assertTrue((output.logits_train[1] == [4, 5, 0]).all())
|
||||
|
||||
# Check labels.
|
||||
self.assertLen(output.labels_train, 2)
|
||||
self.assertLen(output.labels_test, 1)
|
||||
self.assertTrue((output.labels_train == class_index).all())
|
||||
self.assertTrue((output.labels_test == class_index).all())
|
||||
|
||||
# Check losses
|
||||
self.assertLen(output.loss_train, 2)
|
||||
self.assertLen(output.loss_test, 1)
|
||||
self.assertTrue((output.loss_train == [2, 4]).all())
|
||||
self.assertTrue((output.loss_test == [0.5]).all())
|
||||
|
||||
def test_slice_by_percentile(self):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||
output = get_slice(self.input_data, percentile_slice)
|
||||
|
||||
# Check logits.
|
||||
self.assertLen(output.logits_train, 3)
|
||||
self.assertLen(output.logits_test, 1)
|
||||
self.assertTrue((output.logits_test[0] == [10, 0, 11]).all())
|
||||
|
||||
# Check labels.
|
||||
self.assertLen(output.labels_train, 3)
|
||||
self.assertLen(output.labels_test, 1)
|
||||
self.assertTrue((output.labels_train == [1, 0, 2]).all())
|
||||
self.assertTrue((output.labels_test == [1]).all())
|
||||
|
||||
def test_slice_by_correctness(self):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
|
||||
False)
|
||||
output = get_slice(self.input_data, percentile_slice)
|
||||
|
||||
# Check logits.
|
||||
self.assertLen(output.logits_train, 2)
|
||||
self.assertLen(output.logits_test, 3)
|
||||
self.assertTrue((output.logits_train[1] == [6, 7, 0]).all())
|
||||
self.assertTrue((output.logits_test[1] == [12, 13, 0]).all())
|
||||
|
||||
# Check labels.
|
||||
self.assertLen(output.labels_train, 2)
|
||||
self.assertLen(output.labels_test, 3)
|
||||
self.assertTrue((output.labels_train == [0, 2]).all())
|
||||
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""An example for the membership inference attacks.
|
||||
|
||||
This is using a toy model based on classifying four spacial clusters of data.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.utils import to_categorical
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
||||
|
||||
def generate_random_cluster(center, scale, num_points):
|
||||
return np.random.normal(size=(num_points, len(center))) * scale + center
|
||||
|
||||
|
||||
def generate_features_and_labels(samples_per_cluster=250, scale=0.1):
|
||||
"""Generates noised 3D clusters."""
|
||||
cluster_centers = [[0, 0, 0], [2, 0, 0], [0, 2, 0], [0, 0, 2]]
|
||||
|
||||
features = np.concatenate((
|
||||
generate_random_cluster(
|
||||
center=cluster_centers[0],
|
||||
scale=scale,
|
||||
num_points=samples_per_cluster),
|
||||
generate_random_cluster(
|
||||
center=cluster_centers[1],
|
||||
scale=scale,
|
||||
num_points=samples_per_cluster),
|
||||
generate_random_cluster(
|
||||
center=cluster_centers[2],
|
||||
scale=scale,
|
||||
num_points=samples_per_cluster),
|
||||
generate_random_cluster(
|
||||
center=cluster_centers[3],
|
||||
scale=scale,
|
||||
num_points=samples_per_cluster),
|
||||
))
|
||||
|
||||
# Cluster labels: 0, 1, 2 and 3
|
||||
labels = np.concatenate((
|
||||
np.zeros(samples_per_cluster),
|
||||
np.ones(samples_per_cluster),
|
||||
np.ones(samples_per_cluster) * 2,
|
||||
np.ones(samples_per_cluster) * 3,
|
||||
))
|
||||
|
||||
return (features, labels)
|
||||
|
||||
|
||||
# Hint: Play with "noise_scale" for different levels of overlap between
|
||||
# the generated clusters. More noise makes the classification harder.
|
||||
noise_scale = 2
|
||||
training_features, training_labels = generate_features_and_labels(
|
||||
samples_per_cluster=250, scale=noise_scale)
|
||||
test_features, test_labels = generate_features_and_labels(
|
||||
samples_per_cluster=250, scale=noise_scale)
|
||||
|
||||
num_clusters = int(round(np.max(training_labels))) + 1
|
||||
|
||||
# Hint: play with the number of layers to achieve different level of
|
||||
# over-fitting and observe its effects on membership inference performance.
|
||||
model = keras.models.Sequential([
|
||||
layers.Dense(300, activation="relu"),
|
||||
layers.Dense(300, activation="relu"),
|
||||
layers.Dense(300, activation="relu"),
|
||||
layers.Dense(num_clusters, activation="relu"),
|
||||
layers.Softmax()
|
||||
])
|
||||
model.compile(
|
||||
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
|
||||
model.fit(
|
||||
training_features,
|
||||
to_categorical(training_labels, num_clusters),
|
||||
validation_data=(test_features, to_categorical(test_labels, num_clusters)),
|
||||
batch_size=64,
|
||||
epochs=10,
|
||||
shuffle=True)
|
||||
|
||||
training_pred = model.predict(training_features)
|
||||
test_pred = model.predict(test_features)
|
||||
|
||||
|
||||
def crossentropy(true_labels, predictions):
|
||||
return keras.backend.eval(
|
||||
keras.losses.binary_crossentropy(
|
||||
keras.backend.variable(to_categorical(true_labels, num_clusters)),
|
||||
keras.backend.variable(predictions)))
|
||||
|
||||
|
||||
attack_results = mia.run_attacks(
|
||||
AttackInputData(
|
||||
labels_train=training_labels,
|
||||
labels_test=test_labels,
|
||||
loss_train=crossentropy(training_labels, training_pred),
|
||||
loss_test=crossentropy(test_labels, test_pred)),
|
||||
SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION))
|
||||
|
||||
# Example of saving the results to the file and loading them back.
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
filepath = os.path.join(tmpdirname, "results.pickle")
|
||||
attack_results.save(filepath)
|
||||
loaded_results = AttackResults.load(filepath)
|
||||
|
||||
# Print attack metrics
|
||||
for attack_result in attack_results.single_attack_results:
|
||||
print("Slice: %s" % attack_result.slice_spec)
|
||||
print("Attack type: %s" % attack_result.attack_type)
|
||||
print("AUC: %.2f" % attack_result.roc_curve.get_auc())
|
||||
|
||||
print("Attacker advantage: %.2f\n" %
|
||||
attack_result.roc_curve.get_attacker_advantage())
|
||||
|
||||
max_auc_attacker = attack_results.get_result_with_max_attacker_advantage()
|
||||
print("Attack type with max AUC: %s, AUC of %.2f" %
|
||||
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
|
||||
|
||||
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage()
|
||||
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
|
||||
(max_advantage_attacker.attack_type,
|
||||
max_advantage_attacker.roc_curve.get_attacker_advantage()))
|
||||
|
||||
# Print summary
|
||||
print("Summary without slices: \n")
|
||||
print(attack_results.summary(by_slices=False))
|
||||
|
||||
print("Summary by slices: \n")
|
||||
print(attack_results.summary(by_slices=True))
|
|
@ -13,7 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Code that runs membership inference attacks based on the model outputs."""
|
||||
"""Code that runs membership inference attacks based on the model outputs.
|
||||
|
||||
Warning: This file belongs to the old API for membership inference attacks. This
|
||||
file will be removed soon. membership_inference_attack_new.py contains the new
|
||||
API.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import io
|
||||
|
@ -354,6 +359,11 @@ def run_attack(loss_train: np.ndarray = None,
|
|||
results: Dictionary with the chosen vulnerability metric(s) for all ran
|
||||
attacks.
|
||||
"""
|
||||
print(
|
||||
'Deprecation warning: function run_attack is '
|
||||
'deprecated and will be removed soon. '
|
||||
'Please use membership_inference_attack_new.run_attacks'
|
||||
)
|
||||
attacks = []
|
||||
features = {}
|
||||
# ---------- Check available data ----------
|
||||
|
@ -529,6 +539,11 @@ def run_all_attacks(loss_train: np.ndarray = None,
|
|||
Returns:
|
||||
result: dictionary with all attack results
|
||||
"""
|
||||
print(
|
||||
'Deprecation warning: function run_all_attacks is '
|
||||
'deprecated and will be removed soon. '
|
||||
'Please use membership_inference_attack_new.run_attacks'
|
||||
)
|
||||
metrics = ['auc', 'advantage']
|
||||
|
||||
# Entire data
|
||||
|
@ -631,6 +646,11 @@ def run_all_attacks_and_create_summary(
|
|||
result: a dictionary with all the distilled attack information summarized
|
||||
in the summarystring
|
||||
"""
|
||||
print(
|
||||
'Deprecation warning: function run_all_attacks_and_create_summary is '
|
||||
'deprecated and will be removed soon. '
|
||||
'Please use membership_inference_attack_new.run_attacks'
|
||||
)
|
||||
summary = []
|
||||
metrics = ['auc', 'advantage']
|
||||
attack_classifiers = ['lr', 'knn']
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Code that runs membership inference attacks based on the model outputs.
|
||||
|
||||
This file belongs to the new API for membership inference attacks. This file
|
||||
will be renamed to membership_inference_attack.py after the old API is removed.
|
||||
"""
|
||||
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
from sklearn import metrics
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
||||
|
||||
|
||||
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
||||
if hasattr(data, 'slice_spec'):
|
||||
return data.slice_spec
|
||||
return SingleSliceSpec()
|
||||
|
||||
|
||||
def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
|
||||
"""Classification attack done by ML models."""
|
||||
attacker = None
|
||||
|
||||
if attack_type == AttackType.LOGISTIC_REGRESSION:
|
||||
attacker = models.LogisticRegressionAttacker()
|
||||
elif attack_type == AttackType.MULTI_LAYERED_PERCEPTRON:
|
||||
attacker = models.MultilayerPerceptronAttacker()
|
||||
elif attack_type == AttackType.RANDOM_FOREST:
|
||||
attacker = models.RandomForestAttacker()
|
||||
elif attack_type == AttackType.K_NEAREST_NEIGHBORS:
|
||||
attacker = models.KNearestNeighborsAttacker()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Attack type {} not implemented yet.'.format(attack_type))
|
||||
|
||||
prepared_attacker_data = models.create_attacker_data(attack_input)
|
||||
|
||||
attacker.train_model(prepared_attacker_data.features_train,
|
||||
prepared_attacker_data.is_training_labels_train)
|
||||
|
||||
# Run the attacker on (permuted) test examples.
|
||||
predictions_test = attacker.predict(prepared_attacker_data.features_test)
|
||||
|
||||
# Generate ROC curves with predictions.
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
prepared_attacker_data.is_training_labels_test, predictions_test)
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=attack_type,
|
||||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
def run_threshold_attack(attack_input: AttackInputData):
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(attack_input.get_train_size()),
|
||||
np.ones(attack_input.get_test_size()))),
|
||||
np.concatenate(
|
||||
(attack_input.get_loss_train(), attack_input.get_loss_test())))
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
def run_attack(attack_input: AttackInputData, attack_type: AttackType):
|
||||
attack_input.validate()
|
||||
if attack_type.is_trained_attack:
|
||||
return run_trained_attack(attack_input, attack_type)
|
||||
|
||||
return run_threshold_attack(attack_input)
|
||||
|
||||
|
||||
def run_attacks(
|
||||
attack_input: AttackInputData,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)
|
||||
) -> AttackResults:
|
||||
"""Run all attacks."""
|
||||
attack_input.validate()
|
||||
attack_results = []
|
||||
|
||||
if slicing_spec is None:
|
||||
slicing_spec = SlicingSpec(entire_dataset=True)
|
||||
input_slice_specs = get_single_slice_specs(slicing_spec,
|
||||
attack_input.num_classes)
|
||||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
for attack_type in attack_types:
|
||||
attack_results.append(run_attack(attack_input_slice, attack_type))
|
||||
|
||||
return AttackResults(single_attack_results=attack_results)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
||||
|
||||
def get_test_input(n_train, n_test):
|
||||
"""Get example inputs for attacks."""
|
||||
rng = np.random.RandomState(4)
|
||||
return AttackInputData(
|
||||
rng.randn(n_train, 5) + 0.2,
|
||||
rng.randn(n_test, 5) + 0.2, np.array([i % 5 for i in range(n_train)]),
|
||||
np.array([i % 5 for i in range(n_test)]))
|
||||
|
||||
|
||||
class RunAttacksTest(absltest.TestCase):
|
||||
|
||||
def test_run_attacks_size(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input(100, 100), SlicingSpec(),
|
||||
(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION))
|
||||
|
||||
self.assertLen(result.single_attack_results, 2)
|
||||
|
||||
def test_run_attack_trained_sets_attack_type(self):
|
||||
result = mia.run_attack(
|
||||
get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
||||
|
||||
self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||
|
||||
def test_run_attack_threshold_sets_attack_type(self):
|
||||
result = mia.run_attack(
|
||||
get_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
def test_run_attack_threshold_calculates_correct_auc(self):
|
||||
result = mia.run_attack(
|
||||
AttackInputData(
|
||||
loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
|
||||
loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
|
||||
AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
||||
|
||||
def test_run_attack_by_slice(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input(100, 100), SlicingSpec(by_class=True),
|
||||
(AttackType.THRESHOLD_ATTACK,))
|
||||
|
||||
self.assertLen(result.single_attack_results, 6)
|
||||
expected_slice = SingleSliceSpec(SlicingFeature.CLASS, 2)
|
||||
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
207
tensorflow_privacy/privacy/membership_inference_attack/models.py
Normal file
207
tensorflow_privacy/privacy/membership_inference_attack/models.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Trained models for membership inference attacks."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from sklearn import ensemble
|
||||
from sklearn import linear_model
|
||||
from sklearn import model_selection
|
||||
from sklearn import neighbors
|
||||
from sklearn import neural_network
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackerData:
|
||||
"""Input data for an ML classifier attack.
|
||||
|
||||
This includes only the data, and not configuration.
|
||||
"""
|
||||
|
||||
features_train: np.ndarray = None
|
||||
# element-wise boolean array denoting if the example was part of training.
|
||||
is_training_labels_train: np.ndarray = None
|
||||
|
||||
features_test: np.ndarray = None
|
||||
# element-wise boolean array denoting if the example was part of training.
|
||||
is_training_labels_test: np.ndarray = None
|
||||
|
||||
|
||||
def create_attacker_data(attack_input_data: AttackInputData,
|
||||
test_fraction: float = 0.25) -> AttackerData:
|
||||
"""Prepare AttackInputData to train ML attackers.
|
||||
|
||||
Combines logits and losses and performs a random train-test split.
|
||||
|
||||
Args:
|
||||
attack_input_data: Original AttackInputData
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
"""
|
||||
attack_input_train = _column_stack(attack_input_data.logits_train,
|
||||
attack_input_data.get_loss_train())
|
||||
attack_input_test = _column_stack(attack_input_data.logits_test,
|
||||
attack_input_data.get_loss_test())
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
|
||||
labels_all = np.concatenate(((np.zeros(attack_input_data.get_train_size())),
|
||||
(np.ones(attack_input_data.get_test_size()))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
is_training_labels_train, is_training_labels_test = \
|
||||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction)
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
|
||||
|
||||
def _column_stack(logits, loss):
|
||||
"""Stacks logits and losses.
|
||||
|
||||
In case that only one exists, returns that one.
|
||||
Args:
|
||||
logits: logits array
|
||||
loss: loss array
|
||||
|
||||
Returns:
|
||||
stacked logits and losses (or only one if both do not exist).
|
||||
"""
|
||||
if logits is None:
|
||||
return np.expand_dims(loss, axis=-1)
|
||||
if loss is None:
|
||||
return logits
|
||||
return np.column_stack((logits, loss))
|
||||
|
||||
|
||||
class TrainedAttacker:
|
||||
"""Base class for training attack models."""
|
||||
model = None
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
"""Train an attacker model.
|
||||
|
||||
This is trained on examples from train and test datasets.
|
||||
Args:
|
||||
input_features : array-like of shape (n_samples, n_features) Training
|
||||
vector, where n_samples is the number of samples and n_features is the
|
||||
number of features.
|
||||
is_training_labels : a vector of booleans of shape (n_samples, )
|
||||
representing whether the sample is in the training set or not.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self, input_features):
|
||||
"""Predicts whether input_features belongs to train or test.
|
||||
|
||||
Args:
|
||||
input_features : A vector of features with the same semantics as x_train
|
||||
passed to train_model.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LogisticRegressionAttacker(TrainedAttacker):
|
||||
"""Logistic regression attacker."""
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
lr = linear_model.LogisticRegression(solver='lbfgs')
|
||||
param_grid = {
|
||||
'C': np.logspace(-4, 2, 10),
|
||||
}
|
||||
model = model_selection.GridSearchCV(
|
||||
lr, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||
"""Multilayer perceptron attacker."""
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
mlp_model = neural_network.MLPClassifier()
|
||||
param_grid = {
|
||||
'hidden_layer_sizes': [(64,), (32, 32)],
|
||||
'solver': ['adam'],
|
||||
'alpha': [0.0001, 0.001, 0.01],
|
||||
}
|
||||
model = model_selection.GridSearchCV(
|
||||
mlp_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class RandomForestAttacker(TrainedAttacker):
|
||||
"""Random forest attacker."""
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
"""Setup a random forest pipeline with cross-validation."""
|
||||
rf_model = ensemble.RandomForestClassifier()
|
||||
|
||||
param_grid = {
|
||||
'n_estimators': [100],
|
||||
'max_features': ['auto', 'sqrt'],
|
||||
'max_depth': [5, 10, 20, None],
|
||||
'min_samples_split': [2, 5, 10],
|
||||
'min_samples_leaf': [1, 2, 4]
|
||||
}
|
||||
model = model_selection.GridSearchCV(
|
||||
rf_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class KNearestNeighborsAttacker(TrainedAttacker):
|
||||
"""K nearest neighbor attacker."""
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
knn_model = neighbors.KNeighborsClassifier()
|
||||
param_grid = {
|
||||
'n_neighbors': [3, 5, 7],
|
||||
}
|
||||
model = model_selection.GridSearchCV(
|
||||
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures."""
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
|
||||
|
||||
class TrainedAttackerTest(absltest.TestCase):
|
||||
|
||||
def test_base_attacker_train_and_predict(self):
|
||||
base_attacker = models.TrainedAttacker()
|
||||
self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
|
||||
self.assertRaises(NotImplementedError, base_attacker.predict, [])
|
||||
|
||||
def test_predict_before_training(self):
|
||||
lr_attacker = models.LogisticRegressionAttacker()
|
||||
self.assertRaises(AssertionError, lr_attacker.predict, [])
|
||||
|
||||
def test_create_attacker_data_loss_only(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1]), loss_test=np.array([2]))
|
||||
attacker_data = models.create_attacker_data(attack_input, 0.5)
|
||||
self.assertLen(attacker_data.features_test, 1)
|
||||
self.assertLen(attacker_data.features_train, 1)
|
||||
|
||||
def test_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6]]),
|
||||
logits_test=np.array([[10, 11], [14, 15]]),
|
||||
loss_train=np.array([3, 7]),
|
||||
loss_test=np.array([12, 16]))
|
||||
attacker_data = models.create_attacker_data(attack_input, 0.25)
|
||||
self.assertLen(attacker_data.features_test, 1)
|
||||
self.assertLen(attacker_data.features_train, 3)
|
||||
|
||||
for i, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 3) # each feature has two logits and one loss
|
||||
expected = feature[:2] not in attack_input.logits_train
|
||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
Loading…
Reference in a new issue