From 43a0e4be8a23610a82f0fb40d0c859e66b25e4db Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Aug 2020 09:43:58 -0700 Subject: [PATCH] The new API for the membership inference attack. 1. Colab and Keras/TF estimator integration still use the old API and will be updated in the subsequent CLs. 2. After dropping the old API in membership_inference_attack.py, membership_inference_attack_new.py will be renamed in membership_inference_attack.py. PiperOrigin-RevId: 325823046 --- .../data_structures.py | 322 ++++++++++++++++++ .../dataset_slicing.py | 143 ++++++++ .../dataset_slicing_test.py | 180 ++++++++++ .../membership_inference_attack/example.py | 149 ++++++++ .../membership_inference_attack.py | 22 +- .../membership_inference_attack_new.py | 121 +++++++ .../membership_inference_attack_new_test.py | 77 +++++ .../membership_inference_attack/models.py | 207 +++++++++++ .../models_test.py | 59 ++++ 9 files changed, 1279 insertions(+), 1 deletion(-) create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/data_structures.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/example.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/models.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/models_test.py diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py new file mode 100644 index 0000000..5697096 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -0,0 +1,322 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Data structures representing attack inputs, configuration, outputs.""" +import enum +import pickle +from typing import Any, Iterable, Union + +from dataclasses import dataclass +import numpy as np +from sklearn import metrics + +ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' + + +class SlicingFeature(enum.Enum): + """Enum with features by which slicing is available.""" + CLASS = 'class' + PERCENTILE = 'percentile' + CORRECTLY_CLASSIFIED = 'correctly_classfied' + + +@dataclass +class SingleSliceSpec: + """Specifies a slice. + + The slice is defined by values in one feature - it might be a single value + (eg. slice of examples of the specific classification class) or some set of + values (eg. range of percentiles of the attacked model loss). + + When feature is None, it means that the slice is the entire dataset. + """ + feature: SlicingFeature = None + value: Any = None + + @property + def entire_dataset(self): + return self.feature is None + + def __str__(self): + if self.entire_dataset: + return 'Entire dataset' + + if self.feature == SlicingFeature.PERCENTILE: + return 'Loss percentiles: %d-%d' % self.value + + return f'{self.feature.name}={self.value}' + + +@dataclass +class SlicingSpec: + """Specification of a slicing procedure. + + Each variable which is set specifies a slicing by different dimension. + """ + + # When is set to true, one of the slices is the whole dataset. + entire_dataset: bool = True + + # Used in classification tasks for slicing by classes. It is assumed that + # classes are integers 0, 1, ... number of classes. When true one slice per + # each class is generated. + by_class: Union[bool, Iterable[int], int] = False + + # if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%, + # ... 90-100%. + by_percentiles: bool = False + + # When true, a slice for correctly classifed and a slice for misclassifed + # examples will be generated. + by_classification_correctness: bool = False + + +class AttackType(enum.Enum): + """An enum define attack types.""" + LOGISTIC_REGRESSION = 'lr' + MULTI_LAYERED_PERCEPTRON = 'mlp' + RANDOM_FOREST = 'rf' + K_NEAREST_NEIGHBORS = 'knn' + THRESHOLD_ATTACK = 'threshold' + + @property + def is_trained_attack(self): + """Returns whether this type of attack requires training a model.""" + return self != AttackType.THRESHOLD_ATTACK + + # Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION + def __str__(self): + return f'{self.name}' + + +@dataclass +class AttackInputData: + """Input data for running an attack. + + This includes only the data, and not configuration. + """ + + logits_train: np.ndarray = None + logits_test: np.ndarray = None + + # Contains ground-truth classes. Classes are assumed to be integers starting + # from 0. + labels_train: np.ndarray = None + labels_test: np.ndarray = None + + # Explicitly specified loss. If provided, this is used instead of deriving + # loss from logits and labels + loss_train: np.ndarray = None + loss_test: np.ndarray = None + + @property + def num_classes(self): + if self.labels_train is None or self.labels_test is None: + raise ValueError( + "Can't identify the number of classes as no labels were provided. " + 'Please set labels_train and labels_test') + return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1 + + @staticmethod + def _get_loss(logits: np.ndarray, true_labels: np.ndarray): + return logits[range(logits.shape[0]), true_labels] + + def get_loss_train(self): + """Calculates cross-entropy losses for the training set.""" + if self.loss_train is not None: + return self.loss_train + return self._get_loss(self.logits_train, self.labels_train) + + def get_loss_test(self): + """Calculates cross-entropy losses for the test set.""" + if self.loss_test is not None: + return self.loss_test + return self._get_loss(self.logits_test, self.labels_test) + + def get_train_size(self): + """Returns size of the training set.""" + if self.loss_train is not None: + return self.loss_train.size + return self.logits_train.shape[0] + + def get_test_size(self): + """Returns size of the test set.""" + if self.loss_test is not None: + return self.loss_test.size + return self.logits_test.shape[0] + + def validate(self): + """Validates the inputs.""" + if (self.loss_train is None) != (self.loss_test is None): + raise ValueError( + 'loss_test and loss_train should both be either set or unset') + + if (self.logits_train is None) != (self.logits_test is None): + raise ValueError( + 'logits_train and logits_test should both be either set or unset') + + if (self.labels_train is None) != (self.labels_test is None): + raise ValueError( + 'labels_train and labels_test should both be either set or unset') + + if (self.labels_train is None and self.loss_train is None and + self.logits_train is None): + raise ValueError('At least one of labels, logits or losses should be set') + + # TODO(b/161366709): Add checks for equal sizes + + +@dataclass +class RocCurve: + """Represents ROC curve of a membership inference classifier.""" + # Thresholds used to define points on ROC curve. + # Thresholds are not explicitly part of the curve, and are stored for + # debugging purposes. + thresholds: np.ndarray + + # True positive rates based on thresholds + tpr: np.ndarray + + # False positive rates based on thresholds + fpr: np.ndarray + + def get_auc(self): + """Calculates area under curve (aka AUC).""" + return metrics.auc(self.fpr, self.tpr) + + def get_attacker_advantage(self): + """Calculates membership attacker's (or adversary's) advantage. + + This metric is inspired by https://arxiv.org/abs/1709.01604, specifically + by Definition 4. The difference here is that we calculate maximum advantage + over all available classifier thresholds. + + Returns: + a single float number with membership attaker's advantage. + """ + return max(np.abs(self.tpr - self.fpr)) + + +@dataclass +class SingleAttackResult: + """Results from running a single attack.""" + + # Data slice this result was calculated for. + slice_spec: SingleSliceSpec + + attack_type: AttackType + roc_curve: RocCurve # for drawing and metrics calculation + + # TODO(b/162693190): Add more metrics. Think which info we should store + # to derive metrics like f1_score or accuracy. Should we store labels and + # predictions, or rather some aggregate data? + + def get_attacker_advantage(self): + return self.roc_curve.get_attacker_advantage() + + def get_auc(self): + return self.roc_curve.get_auc() + + +@dataclass +class AttackResults: + """Results from running multiple attacks.""" + # add metadata, such as parameters of attack evaluation, input data etc + single_attack_results: Iterable[SingleAttackResult] + + def calculate_pd_dataframe(self): + # returns all metrics as a Pandas DataFrame + return + + def summary(self, by_slices=False) -> str: + """Provides a summary of the metrics. + + The summary provides the best-performing attacks for each requested data + slice. + Args: + by_slices : whether to prepare a per-slice summary. + + Returns: + A string with a summary of all the metrics. + """ + summary = [] + + # Summary over all slices + max_auc_result_all = self.get_result_with_max_attacker_advantage() + summary.append('Best-performing attacks over all slices') + summary.append( + ' %s achieved an AUC of %.2f on slice %s' % + (max_auc_result_all.attack_type, max_auc_result_all.get_auc(), + max_auc_result_all.slice_spec)) + + max_advantage_result_all = self.get_result_with_max_attacker_advantage() + summary.append(' %s achieved an advantage of %.2f on slice %s' % + (max_advantage_result_all.attack_type, + max_advantage_result_all.get_attacker_advantage(), + max_advantage_result_all.slice_spec)) + + slice_dict = self._group_results_by_slice() + + if len(slice_dict.keys()) > 1 and by_slices: + for slice_str in slice_dict: + results = slice_dict[slice_str] + summary.append('\nBest-performing attacks over slice: \"%s\"' % + slice_str) + max_auc_result = results.get_result_with_max_auc() + summary.append(' %s achieved an AUC of %.2f' % + (max_auc_result.attack_type, max_auc_result.get_auc())) + max_advantage_result = results.get_result_with_max_attacker_advantage() + summary.append(' %s achieved an advantage of %.2f' % + (max_advantage_result.attack_type, + max_advantage_result.get_attacker_advantage())) + + return '\n'.join(summary) + + def _group_results_by_slice(self): + """Groups AttackResults into a dictionary keyed by the slice.""" + slice_dict = {} + for attack_result in self.single_attack_results: + slice_str = str(attack_result.slice_spec) + if slice_str not in slice_dict: + slice_dict[slice_str] = AttackResults([]) + slice_dict[slice_str].single_attack_results.append(attack_result) + return slice_dict + + def get_result_with_max_auc(self) -> SingleAttackResult: + """Get the result with maximum AUC for all attacks and slices.""" + aucs = [result.get_auc() for result in self.single_attack_results] + + if min(aucs) < 0.4: + print('Suspiciously low AUC detected: %.2f. ' + + 'There might be a bug in the classifier' % min(aucs)) + + return self.single_attack_results[np.argmax(aucs)] + + def get_result_with_max_attacker_advantage(self) -> SingleAttackResult: + """Get the result with maximum advantage for all attacks and slices.""" + return self.single_attack_results[np.argmax([ + result.get_attacker_advantage() for result in self.single_attack_results + ])] + + def save(self, filepath): + """Saves self to a pickle file.""" + with open(filepath, 'wb') as out: + pickle.dump(self, out) + + @classmethod + def load(cls, filepath): + """Loads AttackResults from a pickle file.""" + with open(filepath, 'rb') as inp: + return pickle.load(inp) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py new file mode 100644 index 0000000..99b8dec --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py @@ -0,0 +1,143 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Specifying and creating AttackInputData slices.""" + +import collections +import copy +from typing import List + +import numpy as np +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec + + +def _slice_if_not_none(a, idx): + return None if a is None else a[idx] + + +def _slice_data_by_indices(data: AttackInputData, idx_train, + idx_test) -> AttackInputData: + """Slices train fields with with idx_train and test fields with and idx_test.""" + + result = AttackInputData() + + # Slice train data. + result.logits_train = _slice_if_not_none(data.logits_train, idx_train) + result.labels_train = _slice_if_not_none(data.labels_train, idx_train) + result.loss_train = _slice_if_not_none(data.loss_train, idx_train) + + # Slice test data. + result.logits_test = _slice_if_not_none(data.logits_test, idx_test) + result.labels_test = _slice_if_not_none(data.labels_test, idx_test) + result.loss_test = _slice_if_not_none(data.loss_test, idx_test) + + return result + + +def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData: + idx_train = data.labels_train == class_value + idx_test = data.labels_test == class_value + return _slice_data_by_indices(data, idx_train, idx_test) + + +def _slice_by_percentiles(data: AttackInputData, from_percentile: float, + to_percentile: float): + """Slices samples by loss percentiles.""" + + # Find from_percentile and to_percentile percentiles in losses. + loss_train = data.get_loss_train() + loss_test = data.get_loss_test() + losses = np.concatenate((loss_train, loss_test)) + from_loss = np.percentile(losses, from_percentile) + to_loss = np.percentile(losses, to_percentile) + + idx_train = (from_loss <= loss_train) & (loss_train <= to_loss) + idx_test = (from_loss <= loss_test) & (loss_test <= to_loss) + + return _slice_data_by_indices(data, idx_train, idx_test) + + +def _indices_by_classification(logits, labels, correctly_classified): + idx_correct = labels == np.argmax(logits, axis=1) + return idx_correct if correctly_classified else np.invert(idx_correct) + + +def _slice_by_classification_correctness(data: AttackInputData, + correctly_classified: bool): + idx_train = _indices_by_classification(data.logits_train, data.labels_train, + correctly_classified) + idx_test = _indices_by_classification(data.logits_test, data.labels_test, + correctly_classified) + return _slice_data_by_indices(data, idx_train, idx_test) + + +def get_single_slice_specs(slicing_spec: SlicingSpec, + num_classes: int = None) -> List[SingleSliceSpec]: + """Returns slices of data according to slicing_spec.""" + result = [] + + if slicing_spec.entire_dataset: + result.append(SingleSliceSpec()) + + # Create slices by class. + by_class = slicing_spec.by_class + if isinstance(by_class, bool): + if by_class: + assert num_classes, "When by_class == True, num_classes should be given." + assert 0 <= num_classes <= 1000, ( + f"Too much classes for slicing by classes. " + f"Found {num_classes}.") + for c in range(num_classes): + result.append(SingleSliceSpec(SlicingFeature.CLASS, c)) + elif isinstance(by_class, int): + result.append(SingleSliceSpec(SlicingFeature.CLASS, by_class)) + elif isinstance(by_class, collections.Iterable): + for c in by_class: + result.append(SingleSliceSpec(SlicingFeature.CLASS, c)) + + # Create slices by percentiles + if slicing_spec.by_percentiles: + for percent in range(0, 100, 10): + result.append( + SingleSliceSpec(SlicingFeature.PERCENTILE, (percent, percent + 10))) + + # Create slices by correctness of the classifications. + if slicing_spec.by_classification_correctness: + result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)) + result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False)) + + return result + + +def get_slice(data: AttackInputData, + slice_spec: SingleSliceSpec) -> AttackInputData: + """Returns a single slice of data according to slice_spec.""" + if slice_spec.entire_dataset: + data_slice = copy.copy(data) + elif slice_spec.feature == SlicingFeature.CLASS: + data_slice = _slice_by_class(data, slice_spec.value) + elif slice_spec.feature == SlicingFeature.PERCENTILE: + from_percentile, to_percentile = slice_spec.value + data_slice = _slice_by_percentiles(data, from_percentile, to_percentile) + elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED: + data_slice = _slice_by_classification_correctness(data, slice_spec.value) + else: + raise ValueError(f'Unknown slice spec feature "{slice_spec.feature}"') + + data_slice.slice_spec = slice_spec + return data_slice diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py new file mode 100644 index 0000000..e6570a3 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py @@ -0,0 +1,180 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing.""" + +from absl.testing import absltest +import numpy as np + +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs +from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice + + +def _are_all_fields_equal(lhs, rhs) -> bool: + return vars(lhs) == vars(rhs) + + +def _are_lists_equal(lhs, rhs) -> bool: + if len(lhs) != len(rhs): + return False + for l, r in zip(lhs, rhs): + if not _are_all_fields_equal(l, r): + return False + return True + + +class SingleSliceSpecsTest(absltest.TestCase): + """Tests for get_single_slice_specs.""" + + ENTIRE_DATASET_SLICE = SingleSliceSpec() + + def test_no_slices(self): + input_data = SlicingSpec(entire_dataset=False) + expected = [] + output = get_single_slice_specs(input_data) + self.assertTrue(_are_lists_equal(output, expected)) + + def test_entire_dataset(self): + input_data = SlicingSpec() + expected = [self.ENTIRE_DATASET_SLICE] + output = get_single_slice_specs(input_data) + self.assertTrue(_are_lists_equal(output, expected)) + + def test_slice_by_classes(self): + input_data = SlicingSpec(by_class=True) + n_classes = 5 + expected = [self.ENTIRE_DATASET_SLICE] + [ + SingleSliceSpec(SlicingFeature.CLASS, c) for c in range(n_classes) + ] + output = get_single_slice_specs(input_data, n_classes) + self.assertTrue(_are_lists_equal(output, expected)) + + def test_slice_by_percentiles(self): + input_data = SlicingSpec(entire_dataset=False, by_percentiles=True) + expected0 = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 10)) + expected5 = SingleSliceSpec(SlicingFeature.PERCENTILE, (50, 60)) + output = get_single_slice_specs(input_data) + self.assertLen(output, 10) + self.assertTrue(_are_all_fields_equal(output[0], expected0)) + self.assertTrue(_are_all_fields_equal(output[5], expected5)) + + def test_slice_by_correcness(self): + input_data = SlicingSpec( + entire_dataset=False, by_classification_correctness=True) + expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True) + output = get_single_slice_specs(input_data) + self.assertLen(output, 2) + self.assertTrue(_are_all_fields_equal(output[0], expected)) + + def test_slicing_by_multiple_features(self): + input_data = SlicingSpec( + entire_dataset=True, + by_class=True, + by_percentiles=True, + by_classification_correctness=True) + n_classes = 10 + expected_slices = n_classes + expected_slices += 1 # entire dataset slice + expected_slices += 10 # percentiles slices + expected_slices += 2 # correcness classification slices + output = get_single_slice_specs(input_data, n_classes) + self.assertLen(output, expected_slices) + + +class GetSliceTest(absltest.TestCase): + + def __init__(self, methodname): + """Initialize the test class.""" + super().__init__(methodname) + + # Create test data for 3 class classification task. + logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]]) + logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]]) + labels_train = np.array([1, 0, 1, 2]) + labels_test = np.array([1, 2, 0, 2]) + loss_train = np.array([2, 0.25, 4, 3]) + loss_test = np.array([0.5, 3.5, 7, 4.5]) + + self.input_data = AttackInputData(logits_train, logits_test, labels_train, + labels_test, loss_train, loss_test) + + def test_slice_entire_dataset(self): + entire_dataset_slice = SingleSliceSpec() + output = get_slice(self.input_data, entire_dataset_slice) + expected = self.input_data + expected.slice_spec = entire_dataset_slice + self.assertTrue(_are_all_fields_equal(output, self.input_data)) + + def test_slice_by_class(self): + class_index = 1 + class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index) + output = get_slice(self.input_data, class_slice) + + # Check logits. + self.assertLen(output.logits_train, 2) + self.assertLen(output.logits_test, 1) + self.assertTrue((output.logits_train[1] == [4, 5, 0]).all()) + + # Check labels. + self.assertLen(output.labels_train, 2) + self.assertLen(output.labels_test, 1) + self.assertTrue((output.labels_train == class_index).all()) + self.assertTrue((output.labels_test == class_index).all()) + + # Check losses + self.assertLen(output.loss_train, 2) + self.assertLen(output.loss_test, 1) + self.assertTrue((output.loss_train == [2, 4]).all()) + self.assertTrue((output.loss_test == [0.5]).all()) + + def test_slice_by_percentile(self): + percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) + output = get_slice(self.input_data, percentile_slice) + + # Check logits. + self.assertLen(output.logits_train, 3) + self.assertLen(output.logits_test, 1) + self.assertTrue((output.logits_test[0] == [10, 0, 11]).all()) + + # Check labels. + self.assertLen(output.labels_train, 3) + self.assertLen(output.labels_test, 1) + self.assertTrue((output.labels_train == [1, 0, 2]).all()) + self.assertTrue((output.labels_test == [1]).all()) + + def test_slice_by_correctness(self): + percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, + False) + output = get_slice(self.input_data, percentile_slice) + + # Check logits. + self.assertLen(output.logits_train, 2) + self.assertLen(output.logits_test, 3) + self.assertTrue((output.logits_train[1] == [6, 7, 0]).all()) + self.assertTrue((output.logits_test[1] == [12, 13, 0]).all()) + + # Check labels. + self.assertLen(output.labels_train, 2) + self.assertLen(output.labels_test, 3) + self.assertTrue((output.labels_train == [0, 2]).all()) + self.assertTrue((output.labels_test == [1, 2, 0]).all()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py new file mode 100644 index 0000000..804a997 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -0,0 +1,149 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""An example for the membership inference attacks. + +This is using a toy model based on classifying four spacial clusters of data. +""" +import os +import tempfile +import numpy as np +from tensorflow import keras +from tensorflow.keras import layers +from tensorflow.keras.utils import to_categorical +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec + + +def generate_random_cluster(center, scale, num_points): + return np.random.normal(size=(num_points, len(center))) * scale + center + + +def generate_features_and_labels(samples_per_cluster=250, scale=0.1): + """Generates noised 3D clusters.""" + cluster_centers = [[0, 0, 0], [2, 0, 0], [0, 2, 0], [0, 0, 2]] + + features = np.concatenate(( + generate_random_cluster( + center=cluster_centers[0], + scale=scale, + num_points=samples_per_cluster), + generate_random_cluster( + center=cluster_centers[1], + scale=scale, + num_points=samples_per_cluster), + generate_random_cluster( + center=cluster_centers[2], + scale=scale, + num_points=samples_per_cluster), + generate_random_cluster( + center=cluster_centers[3], + scale=scale, + num_points=samples_per_cluster), + )) + + # Cluster labels: 0, 1, 2 and 3 + labels = np.concatenate(( + np.zeros(samples_per_cluster), + np.ones(samples_per_cluster), + np.ones(samples_per_cluster) * 2, + np.ones(samples_per_cluster) * 3, + )) + + return (features, labels) + + +# Hint: Play with "noise_scale" for different levels of overlap between +# the generated clusters. More noise makes the classification harder. +noise_scale = 2 +training_features, training_labels = generate_features_and_labels( + samples_per_cluster=250, scale=noise_scale) +test_features, test_labels = generate_features_and_labels( + samples_per_cluster=250, scale=noise_scale) + +num_clusters = int(round(np.max(training_labels))) + 1 + +# Hint: play with the number of layers to achieve different level of +# over-fitting and observe its effects on membership inference performance. +model = keras.models.Sequential([ + layers.Dense(300, activation="relu"), + layers.Dense(300, activation="relu"), + layers.Dense(300, activation="relu"), + layers.Dense(num_clusters, activation="relu"), + layers.Softmax() +]) +model.compile( + optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]) +model.fit( + training_features, + to_categorical(training_labels, num_clusters), + validation_data=(test_features, to_categorical(test_labels, num_clusters)), + batch_size=64, + epochs=10, + shuffle=True) + +training_pred = model.predict(training_features) +test_pred = model.predict(test_features) + + +def crossentropy(true_labels, predictions): + return keras.backend.eval( + keras.losses.binary_crossentropy( + keras.backend.variable(to_categorical(true_labels, num_clusters)), + keras.backend.variable(predictions))) + + +attack_results = mia.run_attacks( + AttackInputData( + labels_train=training_labels, + labels_test=test_labels, + loss_train=crossentropy(training_labels, training_pred), + loss_test=crossentropy(test_labels, test_pred)), + SlicingSpec(entire_dataset=True, by_class=True), + attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION)) + +# Example of saving the results to the file and loading them back. +with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "results.pickle") + attack_results.save(filepath) + loaded_results = AttackResults.load(filepath) + +# Print attack metrics +for attack_result in attack_results.single_attack_results: + print("Slice: %s" % attack_result.slice_spec) + print("Attack type: %s" % attack_result.attack_type) + print("AUC: %.2f" % attack_result.roc_curve.get_auc()) + + print("Attacker advantage: %.2f\n" % + attack_result.roc_curve.get_attacker_advantage()) + +max_auc_attacker = attack_results.get_result_with_max_attacker_advantage() +print("Attack type with max AUC: %s, AUC of %.2f" % + (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) + +max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage() +print("Attack type with max advantage: %s, Attacker advantage of %.2f" % + (max_advantage_attacker.attack_type, + max_advantage_attacker.roc_curve.get_attacker_advantage())) + +# Print summary +print("Summary without slices: \n") +print(attack_results.summary(by_slices=False)) + +print("Summary by slices: \n") +print(attack_results.summary(by_slices=True)) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index c06fa36..5c14b36 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -13,7 +13,12 @@ # limitations under the License. # Lint as: python3 -"""Code that runs membership inference attacks based on the model outputs.""" +"""Code that runs membership inference attacks based on the model outputs. + +Warning: This file belongs to the old API for membership inference attacks. This +file will be removed soon. membership_inference_attack_new.py contains the new +API. +""" import collections import io @@ -354,6 +359,11 @@ def run_attack(loss_train: np.ndarray = None, results: Dictionary with the chosen vulnerability metric(s) for all ran attacks. """ + print( + 'Deprecation warning: function run_attack is ' + 'deprecated and will be removed soon. ' + 'Please use membership_inference_attack_new.run_attacks' + ) attacks = [] features = {} # ---------- Check available data ---------- @@ -529,6 +539,11 @@ def run_all_attacks(loss_train: np.ndarray = None, Returns: result: dictionary with all attack results """ + print( + 'Deprecation warning: function run_all_attacks is ' + 'deprecated and will be removed soon. ' + 'Please use membership_inference_attack_new.run_attacks' + ) metrics = ['auc', 'advantage'] # Entire data @@ -631,6 +646,11 @@ def run_all_attacks_and_create_summary( result: a dictionary with all the distilled attack information summarized in the summarystring """ + print( + 'Deprecation warning: function run_all_attacks_and_create_summary is ' + 'deprecated and will be removed soon. ' + 'Please use membership_inference_attack_new.run_attacks' + ) summary = [] metrics = ['auc', 'advantage'] attack_classifiers = ['lr', 'knn'] diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py new file mode 100644 index 0000000..6eb8d49 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py @@ -0,0 +1,121 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Code that runs membership inference attacks based on the model outputs. + +This file belongs to the new API for membership inference attacks. This file +will be renamed to membership_inference_attack.py after the old API is removed. +""" + +from typing import Iterable +import numpy as np +from sklearn import metrics + +from tensorflow_privacy.privacy.membership_inference_attack import models +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs +from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice + + +def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: + if hasattr(data, 'slice_spec'): + return data.slice_spec + return SingleSliceSpec() + + +def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType): + """Classification attack done by ML models.""" + attacker = None + + if attack_type == AttackType.LOGISTIC_REGRESSION: + attacker = models.LogisticRegressionAttacker() + elif attack_type == AttackType.MULTI_LAYERED_PERCEPTRON: + attacker = models.MultilayerPerceptronAttacker() + elif attack_type == AttackType.RANDOM_FOREST: + attacker = models.RandomForestAttacker() + elif attack_type == AttackType.K_NEAREST_NEIGHBORS: + attacker = models.KNearestNeighborsAttacker() + else: + raise NotImplementedError( + 'Attack type {} not implemented yet.'.format(attack_type)) + + prepared_attacker_data = models.create_attacker_data(attack_input) + + attacker.train_model(prepared_attacker_data.features_train, + prepared_attacker_data.is_training_labels_train) + + # Run the attacker on (permuted) test examples. + predictions_test = attacker.predict(prepared_attacker_data.features_test) + + # Generate ROC curves with predictions. + fpr, tpr, thresholds = metrics.roc_curve( + prepared_attacker_data.is_training_labels_test, predictions_test) + + roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) + + return SingleAttackResult( + slice_spec=_get_slice_spec(attack_input), + attack_type=attack_type, + roc_curve=roc_curve) + + +def run_threshold_attack(attack_input: AttackInputData): + fpr, tpr, thresholds = metrics.roc_curve( + np.concatenate((np.zeros(attack_input.get_train_size()), + np.ones(attack_input.get_test_size()))), + np.concatenate( + (attack_input.get_loss_train(), attack_input.get_loss_test()))) + + roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) + + return SingleAttackResult( + slice_spec=_get_slice_spec(attack_input), + attack_type=AttackType.THRESHOLD_ATTACK, + roc_curve=roc_curve) + + +def run_attack(attack_input: AttackInputData, attack_type: AttackType): + attack_input.validate() + if attack_type.is_trained_attack: + return run_trained_attack(attack_input, attack_type) + + return run_threshold_attack(attack_input) + + +def run_attacks( + attack_input: AttackInputData, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,) +) -> AttackResults: + """Run all attacks.""" + attack_input.validate() + attack_results = [] + + if slicing_spec is None: + slicing_spec = SlicingSpec(entire_dataset=True) + input_slice_specs = get_single_slice_specs(slicing_spec, + attack_input.num_classes) + for single_slice_spec in input_slice_specs: + attack_input_slice = get_slice(attack_input, single_slice_spec) + for attack_type in attack_types: + attack_results.append(run_attack(attack_input_slice, attack_type)) + + return AttackResults(single_attack_results=attack_results) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py new file mode 100644 index 0000000..ae291fb --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py @@ -0,0 +1,77 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.utils.""" +from absl.testing import absltest +import numpy as np +from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec + + +def get_test_input(n_train, n_test): + """Get example inputs for attacks.""" + rng = np.random.RandomState(4) + return AttackInputData( + rng.randn(n_train, 5) + 0.2, + rng.randn(n_test, 5) + 0.2, np.array([i % 5 for i in range(n_train)]), + np.array([i % 5 for i in range(n_test)])) + + +class RunAttacksTest(absltest.TestCase): + + def test_run_attacks_size(self): + result = mia.run_attacks( + get_test_input(100, 100), SlicingSpec(), + (AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION)) + + self.assertLen(result.single_attack_results, 2) + + def test_run_attack_trained_sets_attack_type(self): + result = mia.run_attack( + get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION) + + self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION) + + def test_run_attack_threshold_sets_attack_type(self): + result = mia.run_attack( + get_test_input(100, 100), AttackType.THRESHOLD_ATTACK) + + self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK) + + def test_run_attack_threshold_calculates_correct_auc(self): + result = mia.run_attack( + AttackInputData( + loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]), + loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])), + AttackType.THRESHOLD_ATTACK) + + np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + + def test_run_attack_by_slice(self): + result = mia.run_attacks( + get_test_input(100, 100), SlicingSpec(by_class=True), + (AttackType.THRESHOLD_ATTACK,)) + + self.assertLen(result.single_attack_results, 6) + expected_slice = SingleSliceSpec(SlicingFeature.CLASS, 2) + self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py new file mode 100644 index 0000000..b4e7056 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -0,0 +1,207 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Trained models for membership inference attacks.""" + +from dataclasses import dataclass +import numpy as np +from sklearn import ensemble +from sklearn import linear_model +from sklearn import model_selection +from sklearn import neighbors +from sklearn import neural_network + +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData + + +@dataclass +class AttackerData: + """Input data for an ML classifier attack. + + This includes only the data, and not configuration. + """ + + features_train: np.ndarray = None + # element-wise boolean array denoting if the example was part of training. + is_training_labels_train: np.ndarray = None + + features_test: np.ndarray = None + # element-wise boolean array denoting if the example was part of training. + is_training_labels_test: np.ndarray = None + + +def create_attacker_data(attack_input_data: AttackInputData, + test_fraction: float = 0.25) -> AttackerData: + """Prepare AttackInputData to train ML attackers. + + Combines logits and losses and performs a random train-test split. + + Args: + attack_input_data: Original AttackInputData + test_fraction: Fraction of the dataset to include in the test split. + + Returns: + AttackerData. + """ + attack_input_train = _column_stack(attack_input_data.logits_train, + attack_input_data.get_loss_train()) + attack_input_test = _column_stack(attack_input_data.logits_test, + attack_input_data.get_loss_test()) + + features_all = np.concatenate((attack_input_train, attack_input_test)) + + labels_all = np.concatenate(((np.zeros(attack_input_data.get_train_size())), + (np.ones(attack_input_data.get_test_size())))) + + # Perform a train-test split + features_train, features_test, \ + is_training_labels_train, is_training_labels_test = \ + model_selection.train_test_split( + features_all, labels_all, test_size=test_fraction) + return AttackerData(features_train, is_training_labels_train, features_test, + is_training_labels_test) + + +def _column_stack(logits, loss): + """Stacks logits and losses. + + In case that only one exists, returns that one. + Args: + logits: logits array + loss: loss array + + Returns: + stacked logits and losses (or only one if both do not exist). + """ + if logits is None: + return np.expand_dims(loss, axis=-1) + if loss is None: + return logits + return np.column_stack((logits, loss)) + + +class TrainedAttacker: + """Base class for training attack models.""" + model = None + + def train_model(self, input_features, is_training_labels): + """Train an attacker model. + + This is trained on examples from train and test datasets. + Args: + input_features : array-like of shape (n_samples, n_features) Training + vector, where n_samples is the number of samples and n_features is the + number of features. + is_training_labels : a vector of booleans of shape (n_samples, ) + representing whether the sample is in the training set or not. + """ + raise NotImplementedError() + + def predict(self, input_features): + """Predicts whether input_features belongs to train or test. + + Args: + input_features : A vector of features with the same semantics as x_train + passed to train_model. + """ + raise NotImplementedError() + + +class LogisticRegressionAttacker(TrainedAttacker): + """Logistic regression attacker.""" + + def train_model(self, input_features, is_training_labels): + lr = linear_model.LogisticRegression(solver='lbfgs') + param_grid = { + 'C': np.logspace(-4, 2, 10), + } + model = model_selection.GridSearchCV( + lr, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + model.fit(input_features, is_training_labels) + self.model = model + + def predict(self, input_features): + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict(input_features) + + +class MultilayerPerceptronAttacker(TrainedAttacker): + """Multilayer perceptron attacker.""" + + def train_model(self, input_features, is_training_labels): + mlp_model = neural_network.MLPClassifier() + param_grid = { + 'hidden_layer_sizes': [(64,), (32, 32)], + 'solver': ['adam'], + 'alpha': [0.0001, 0.001, 0.01], + } + model = model_selection.GridSearchCV( + mlp_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + model.fit(input_features, is_training_labels) + self.model = model + + def predict(self, input_features): + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict(input_features) + + +class RandomForestAttacker(TrainedAttacker): + """Random forest attacker.""" + + def train_model(self, input_features, is_training_labels): + """Setup a random forest pipeline with cross-validation.""" + rf_model = ensemble.RandomForestClassifier() + + param_grid = { + 'n_estimators': [100], + 'max_features': ['auto', 'sqrt'], + 'max_depth': [5, 10, 20, None], + 'min_samples_split': [2, 5, 10], + 'min_samples_leaf': [1, 2, 4] + } + model = model_selection.GridSearchCV( + rf_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + model.fit(input_features, is_training_labels) + self.model = model + + def predict(self, input_features): + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict(input_features) + + +class KNearestNeighborsAttacker(TrainedAttacker): + """K nearest neighbor attacker.""" + + def train_model(self, input_features, is_training_labels): + knn_model = neighbors.KNeighborsClassifier() + param_grid = { + 'n_neighbors': [3, 5, 7], + } + model = model_selection.GridSearchCV( + knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + model.fit(input_features, is_training_labels) + self.model = model + + def predict(self, input_features): + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict(input_features) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py new file mode 100644 index 0000000..e6b9fb6 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py @@ -0,0 +1,59 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures.""" +from absl.testing import absltest +import numpy as np + +from tensorflow_privacy.privacy.membership_inference_attack import models +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData + + +class TrainedAttackerTest(absltest.TestCase): + + def test_base_attacker_train_and_predict(self): + base_attacker = models.TrainedAttacker() + self.assertRaises(NotImplementedError, base_attacker.train_model, [], []) + self.assertRaises(NotImplementedError, base_attacker.predict, []) + + def test_predict_before_training(self): + lr_attacker = models.LogisticRegressionAttacker() + self.assertRaises(AssertionError, lr_attacker.predict, []) + + def test_create_attacker_data_loss_only(self): + attack_input = AttackInputData( + loss_train=np.array([1]), loss_test=np.array([2])) + attacker_data = models.create_attacker_data(attack_input, 0.5) + self.assertLen(attacker_data.features_test, 1) + self.assertLen(attacker_data.features_train, 1) + + def test_create_attacker_data_loss_and_logits(self): + attack_input = AttackInputData( + logits_train=np.array([[1, 2], [5, 6]]), + logits_test=np.array([[10, 11], [14, 15]]), + loss_train=np.array([3, 7]), + loss_test=np.array([12, 16])) + attacker_data = models.create_attacker_data(attack_input, 0.25) + self.assertLen(attacker_data.features_test, 1) + self.assertLen(attacker_data.features_train, 3) + + for i, feature in enumerate(attacker_data.features_train): + self.assertLen(feature, 3) # each feature has two logits and one loss + expected = feature[:2] not in attack_input.logits_train + self.assertEqual(attacker_data.is_training_labels_train[i], expected) + + +if __name__ == '__main__': + absltest.main()