From 930c4d13e87e870cc84982a37eb544f9e846d592 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Apr 2022 11:36:21 -0700 Subject: [PATCH] Handle the case when the data comes from a multilabel classification problem but the provided samples happen to have just one positive label per sample. PiperOrigin-RevId: 445468067 --- .../data_structures.py | 11 ++++++ .../data_structures_test.py | 35 +++++++++++++++++++ .../membership_inference_attack.py | 8 +++++ 3 files changed, 54 insertions(+) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index 36812d5..02658b1 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -235,6 +235,13 @@ class AttackInputData: # corresponding class is absent from the example, and 1s where the # corresponding class is present. multilabel_data: Optional[bool] = None + # In some corner cases, the provided data comes from a multi-label + # classification model, but the samples all happen to have just 1 label. In + # that case, the `is_multilabel_data()` test will return a `False` value. The + # attack models will expect 1D input, which will throw an exception. Handle + # this case by letting the user set a flag that forces the input data to be + # treated as multilabel data. + force_multilabel_data: bool = False @property def num_classes(self): @@ -450,6 +457,10 @@ class AttackInputData: Raises: ValueError if the dimensionality of the train and test data are not equal. """ + # If 'force_multilabel_data' is set, then assume multilabel data going + # forward. + if self.force_multilabel_data: + self.multilabel_data = True # If the data has already been checked for multihot encoded labels, then # return the result of the evaluation. if self.multilabel_data is not None: diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index 80d7a74..f807a62 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -358,6 +358,41 @@ class AttackInputDataTest(parameterized.TestCase): np.testing.assert_equal(attack_input.get_loss_test().tolist(), np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]])) + def test_validate_with_force_multilabel_false(self): + attack_input = AttackInputData( + probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + probs_test=np.array([[0.8, 0.7, 0.9]]), + labels_train=np.array([[0, 0, 1], [0, 1, 0]]), + labels_test=np.array([[1, 0, 0]])) + self.assertRaisesRegex(ValueError, + r'should be a one dimensional numpy array.', + attack_input.validate) + + def test_validate_with_force_multilabel_true(self): + attack_input = AttackInputData( + probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + probs_test=np.array([[0.8, 0.7, 0.9]]), + labels_train=np.array([[0, 0, 1], [0, 1, 0]]), + labels_test=np.array([[1, 0, 0]]), + force_multilabel_data=True) + try: + attack_input.validate() + except ValueError: + # For a 'ValueError' exception the test should record a failure. All + # other exceptions are errors. + self.fail('ValueError not raised by validate().') + + def test_multilabel_data_true_with_force_multilabel_true(self): + attack_input = AttackInputData( + probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + probs_test=np.array([[0.8, 0.7, 0.9]]), + labels_train=np.array([[0, 0, 1], [0, 1, 0]]), + labels_test=np.array([[1, 0, 0]]), + force_multilabel_data=True) + self.assertTrue( + attack_input.multilabel_data, + '"force_multilabel_data" is True but "multilabel_data" is False.') + class RocCurveTest(absltest.TestCase): diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index 5fe1149..1660497 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -217,6 +217,7 @@ def run_attacks(attack_input: AttackInputData, """ attack_input.validate() attack_results = [] + attack_types = list(attack_types) if slicing_spec is None: slicing_spec = SlicingSpec(entire_dataset=True) @@ -224,6 +225,10 @@ def run_attacks(attack_input: AttackInputData, if slicing_spec.by_class: num_classes = attack_input.num_classes input_slice_specs = get_single_slice_specs(slicing_spec, num_classes) + num_slice_specs = len(input_slice_specs) + num_attacks = len(attack_types) + logging.info('Will run %s attacks on each of %s slice specifications.', + num_attacks, num_slice_specs) for single_slice_spec in input_slice_specs: attack_input_slice = get_slice(attack_input, single_slice_spec) for attack_type in attack_types: @@ -231,6 +236,9 @@ def run_attacks(attack_input: AttackInputData, attack_result = _run_attack(attack_input_slice, attack_type, balance_attacker_training, min_num_samples) if attack_result is not None: + logging.info('%s attack had an AUC=%s and attacker advantage=%s', + attack_type.name, attack_result.get_auc(), + attack_result.get_attacker_advantage()) attack_results.append(attack_result) privacy_report_metadata = _compute_missing_privacy_report_metadata(