From 45da453410ffa078b2d05dc4883d006d578e1b6d Mon Sep 17 00:00:00 2001 From: Vadym Doroshenko Date: Fri, 16 Jun 2023 08:22:07 -0700 Subject: [PATCH] Implement possibility to return slice indices. PiperOrigin-RevId: 540885025 --- .../codelabs/example.py | 13 ++- .../data_structures.py | 5 ++ .../dataset_slicing.py | 90 +++++++++++++------ .../dataset_slicing_test.py | 18 +++- .../membership_inference_attack.py | 28 ++++-- .../membership_inference_attack_test.py | 17 ++++ 6 files changed, 129 insertions(+), 42 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/example.py index 0bf1004..dbdb55d 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/example.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/example.py @@ -148,11 +148,16 @@ def main(unused_argv): labels_train=training_labels, labels_test=test_labels, probs_train=training_pred, - probs_test=test_pred), + probs_test=test_pred, + ), data_structures.SlicingSpec(entire_dataset=True, by_class=True), - attack_types=(data_structures.AttackType.THRESHOLD_ATTACK, - data_structures.AttackType.LOGISTIC_REGRESSION), - privacy_report_metadata=privacy_report_metadata) + attack_types=( + data_structures.AttackType.THRESHOLD_ATTACK, + data_structures.AttackType.LOGISTIC_REGRESSION, + ), + privacy_report_metadata=privacy_report_metadata, + return_slice_indices=True, + ) epoch_results.append(attack_results) # Generate privacy reports diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index d15fdb6..3e4e8d3 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -766,6 +766,11 @@ class SingleAttackResult: # test set samples will have lower scores than the training set samples. membership_scores_test: Optional[np.ndarray] = None + # Indices of train and test examples from the input data that were used in + # this attack. + train_indices: Optional[np.ndarray] = None + test_indices: Optional[np.ndarray] = None + def get_attacker_advantage(self): return self.roc_curve.get_attacker_advantage() diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index f9e6134..95b609a 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ -33,8 +33,12 @@ def _slice_if_not_none(a, idx): return None if a is None else a[idx] -def _slice_data_by_indices(data: AttackInputData, idx_train, - idx_test) -> AttackInputData: +def _slice_data_by_indices( + data: AttackInputData, + idx_train: np.ndarray, + idx_test: np.ndarray, + return_slice_indices: bool, +) -> AttackInputData: """Slices train fields with idx_train and test fields with idx_test.""" result = AttackInputData() @@ -68,10 +72,16 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, # of the original dataset. result.multilabel_data = data.is_multilabel_data() + if return_slice_indices: + result.train_indices = np.where(idx_train)[0] + result.test_indices = np.where(idx_test)[0] + return result -def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData: +def _slice_by_class( + data: AttackInputData, class_value: int, return_slice_indices: bool = False +) -> AttackInputData: """Gets the indices (boolean) for examples belonging to the given class.""" if not data.is_multilabel_data(): idx_train = data.labels_train == class_value @@ -84,11 +94,15 @@ def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData: ) idx_train = data.labels_train[:, class_value].astype(bool) idx_test = data.labels_test[:, class_value].astype(bool) - return _slice_data_by_indices(data, idx_train, idx_test) + return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices) -def _slice_by_percentiles(data: AttackInputData, from_percentile: float, - to_percentile: float): +def _slice_by_percentiles( + data: AttackInputData, + from_percentile: float, + to_percentile: float, + return_slice_indices: bool = False, +) -> AttackInputData: """Slices samples by loss percentiles.""" # Find from_percentile and to_percentile percentiles in losses. @@ -107,22 +121,31 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float, idx_train = (from_loss <= loss_train) & (loss_train <= to_loss) idx_test = (from_loss <= loss_test) & (loss_test <= to_loss) - return _slice_data_by_indices(data, idx_train, idx_test) + return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices) -def _indices_by_classification(logits_or_probs, labels, correctly_classified): +def _indices_by_classification( + logits_or_probs, + labels, + correctly_classified, +): idx_correct = labels == np.argmax(logits_or_probs, axis=1) return idx_correct if correctly_classified else np.invert(idx_correct) -def _slice_by_classification_correctness(data: AttackInputData, - correctly_classified: bool): +def _slice_by_classification_correctness( + data: AttackInputData, + correctly_classified: bool, + return_slice_indices: bool = False, +) -> AttackInputData: """Slices attack inputs by whether they were classified correctly. Args: data: Data to be used as input to the attack models. correctly_classified: Whether to use the indices corresponding to the correctly classified samples. + return_slice_indices: if true, the returned AttackInputData will include + indices of the train and test data samples that were used for this slice. Returns: AttackInputData object containing the sliced data. @@ -136,13 +159,16 @@ def _slice_by_classification_correctness(data: AttackInputData, correctly_classified) idx_test = _indices_by_classification(data.logits_or_probs_test, data.labels_test, correctly_classified) - return _slice_data_by_indices(data, idx_train, idx_test) + return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices) -def _slice_by_custom_indices(data: AttackInputData, - custom_train_indices: np.ndarray, - custom_test_indices: np.ndarray, - group_value: int) -> AttackInputData: +def _slice_by_custom_indices( + data: AttackInputData, + custom_train_indices: np.ndarray, + custom_test_indices: np.ndarray, + group_value: int, + return_slice_indices: bool = False, +) -> AttackInputData: """Slices attack inputs by custom indices. Args: @@ -150,6 +176,8 @@ def _slice_by_custom_indices(data: AttackInputData, custom_train_indices: The group indices of each training example. custom_test_indices: The group indices of each test example. group_value: The group value to pick. + return_slice_indices: if true, the returned AttackInputData will include + indices of the train and test data samples that were used for this slice. Returns: AttackInputData object containing the sliced data. @@ -167,12 +195,12 @@ def _slice_by_custom_indices(data: AttackInputData, f"{test_size}") idx_train = custom_train_indices == group_value idx_test = custom_test_indices == group_value - return _slice_data_by_indices(data, idx_train, idx_test) + return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices) def get_single_slice_specs( - slicing_spec: SlicingSpec, - num_classes: Optional[int] = None) -> List[SingleSliceSpec]: + slicing_spec: SlicingSpec, num_classes: Optional[int] = None +) -> List[SingleSliceSpec]: """Returns slices of data according to slicing_spec. Args: @@ -242,22 +270,34 @@ def get_single_slice_specs( return result -def get_slice(data: AttackInputData, - slice_spec: SingleSliceSpec) -> AttackInputData: +def get_slice( + data: AttackInputData, + slice_spec: SingleSliceSpec, + return_slice_indices: bool = False, +) -> AttackInputData: """Returns a single slice of data according to slice_spec.""" if slice_spec.entire_dataset: data_slice = copy.copy(data) elif slice_spec.feature == SlicingFeature.CLASS: - data_slice = _slice_by_class(data, slice_spec.value) + data_slice = _slice_by_class(data, slice_spec.value, return_slice_indices) elif slice_spec.feature == SlicingFeature.PERCENTILE: from_percentile, to_percentile = slice_spec.value - data_slice = _slice_by_percentiles(data, from_percentile, to_percentile) + data_slice = _slice_by_percentiles( + data, from_percentile, to_percentile, return_slice_indices + ) elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED: - data_slice = _slice_by_classification_correctness(data, slice_spec.value) + data_slice = _slice_by_classification_correctness( + data, slice_spec.value, return_slice_indices + ) elif slice_spec.feature == SlicingFeature.CUSTOM: custom_train_indices, custom_test_indices, group_value = slice_spec.value - data_slice = _slice_by_custom_indices(data, custom_train_indices, - custom_test_indices, group_value) + data_slice = _slice_by_custom_indices( + data, + custom_train_indices, + custom_test_indices, + group_value, + return_slice_indices, + ) else: raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py index 9ad1172..f7214e0 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py @@ -267,10 +267,12 @@ class GetSliceTest(parameterized.TestCase): self.assertTrue((output.labels_train == [1, 0, 2]).all()) self.assertTrue((output.labels_test == [1]).all()) - def test_slice_by_correctness(self): - percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, - False) - output = get_slice(self.input_data, percentile_slice) + @parameterized.parameters(False, True) + def test_slice_by_correctness(self, return_slice_indices): + slice_spec = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False) + output = get_slice( + self.input_data, slice_spec, return_slice_indices=return_slice_indices + ) # Check logits. self.assertLen(output.logits_train, 2) @@ -284,6 +286,14 @@ class GetSliceTest(parameterized.TestCase): self.assertTrue((output.labels_train == [0, 2]).all()) self.assertTrue((output.labels_test == [1, 2, 0]).all()) + # Check return indices + if return_slice_indices: + self.assertTrue((output.train_indices == [1, 3]).all()) + self.assertTrue((output.test_indices == [0, 1, 2]).all()) + else: + self.assertFalse(hasattr(output, 'train_indices')) + self.assertFalse(hasattr(output, 'test_indices')) + def test_slice_by_custom_indices(self): custom_train_indices = np.array([2, 2, 100, 4]) custom_test_indices = np.array([100, 2, 2, 2]) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index 1552b5e..ea6bc98 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -254,14 +254,16 @@ def _run_attack(attack_input: AttackInputData, return _run_threshold_attack(attack_input) -def run_attacks(attack_input: AttackInputData, - slicing_spec: SlicingSpec = None, - attack_types: Iterable[AttackType] = ( - AttackType.THRESHOLD_ATTACK,), - privacy_report_metadata: PrivacyReportMetadata = None, - balance_attacker_training: bool = True, - min_num_samples: int = 1, - backend: Optional[str] = None) -> AttackResults: +def run_attacks( + attack_input: AttackInputData, + slicing_spec: SlicingSpec = None, + attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), + privacy_report_metadata: PrivacyReportMetadata = None, + balance_attacker_training: bool = True, + min_num_samples: int = 1, + backend: Optional[str] = None, + return_slice_indices=False, +) -> AttackResults: """Runs membership inference attacks on a classification model. It runs attacks specified by attack_types on each attack_input slice which is @@ -282,6 +284,9 @@ def run_attacks(attack_input: AttackInputData, may not support multiprocessing and in those cases the `threading` backend should be used. See https://joblib.readthedocs.io/en/latest/parallel.html for more details. + return_slice_indices: if true, the result for each slice will include the + indices of train and test data examples that were used for this slice and + attacks. This does not return indices for the "whole dataset" slice. Returns: the attack result. @@ -301,7 +306,9 @@ def run_attacks(attack_input: AttackInputData, logging.info('Will run %s attacks on each of %s slice specifications.', num_attacks, num_slice_specs) for single_slice_spec in input_slice_specs: - attack_input_slice = get_slice(attack_input, single_slice_spec) + attack_input_slice = get_slice( + attack_input, single_slice_spec, return_slice_indices + ) for attack_type in attack_types: logging.info('Running attack: %s', attack_type.name) attack_result = _run_attack(attack_input_slice, attack_type, @@ -313,6 +320,9 @@ def run_attacks(attack_input: AttackInputData, 'positive predictive value=%s', attack_type.name, attack_result.get_auc(), attack_result.get_attacker_advantage(), attack_result.get_ppv()) + if return_slice_indices and not single_slice_spec.entire_dataset: + attack_result.train_indices = attack_input_slice.train_indices + attack_result.test_indices = attack_input_slice.test_indices attack_results.append(attack_result) privacy_report_metadata = _compute_missing_privacy_report_metadata( diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py index 730df27..73fac0c 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py @@ -445,6 +445,23 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase): AttackType.THRESHOLD_ATTACK.value, ) + def test_run_attacks_size_return_indices(self): + result = mia.run_attacks( + get_test_input(100, 100), + SlicingSpec( + entire_dataset=False, + by_class=True, + by_percentiles=True, + by_classification_correctness=True, + ), + (AttackType.LOGISTIC_REGRESSION,), + return_slice_indices=True, + ) + + for attack_result in result.single_attack_results: + self.assertIsNotNone(attack_result.train_indices) + self.assertIsNotNone(attack_result.test_indices) + if __name__ == '__main__': absltest.main()