From e362f5177351f430b808f25c2c23e7fcb78c0b15 Mon Sep 17 00:00:00 2001
From: Shuang Song
Date: Wed, 12 Apr 2023 17:13:44 -0700
Subject: [PATCH] Supports slicing for multi-label data.

PiperOrigin-RevId: 523846333
---
Review notes (ignored by `git am`): line structure of the patch restored;
the out-of-range error in _slice_by_class now interpolates data.num_classes
and matches the `>=` guard; its docstring describes the returned slice.

 .../data_structures.py                        | 19 +++++++++++++------
 .../dataset_slicing.py                        | 16 ++++++++++++----
 .../dataset_slicing_test.py                   | 14 ++++++++++++--
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
index 2d8f3b6..d15fdb6 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
@@ -84,9 +84,10 @@ class SlicingSpec:
   # When is set to true, one of the slices is the whole dataset.
   entire_dataset: bool = True
 
-  # Used in classification tasks for slicing by classes. It is assumed that
-  # classes are integers 0, 1, ... number of classes. When true one slice per
-  # each class is generated.
+  # Used in classification tasks for slicing by classes. When true one slice per
+  # each class is generated. Classes can either be
+  # - integers 0, 1, ..., (for single label) or
+  # - an array of integers (for multi-label).
   by_class: Union[bool, Iterable[int], int] = False
 
   # if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
@@ -238,8 +239,10 @@ class AttackInputData:
   probs_train: Optional[np.ndarray] = None
   probs_test: Optional[np.ndarray] = None
 
-  # Contains ground-truth classes. Classes are assumed to be integers starting
-  # from 0.
+  # Contains ground-truth classes. For single-label classification, classes are
+  # assumed to be integers starting from 0. For multi-label classification,
+  # label is assumed to be multi-hot, i.e., labels is a binary array of shape
+  # (num_examples, num_classes).
   labels_train: Optional[np.ndarray] = None
   labels_test: Optional[np.ndarray] = None
 
@@ -290,7 +293,9 @@ class AttackInputData:
       raise ValueError(
           'Can\'t identify the number of classes as no labels were provided. '
           'Please set labels_train and labels_test')
-    return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
+    if not self.multilabel_data:
+      return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
+    return self.labels_train.shape[1]
 
   @property
   def logits_or_probs_train(self):
@@ -586,6 +591,8 @@ class AttackInputData:
       _is_array_two_dimensional(self.entropy_test, 'entropy_test')
       _is_array_two_dimensional(self.labels_train, 'labels_train')
       _is_array_two_dimensional(self.labels_test, 'labels_test')
+      self.is_multihot_labels(self.labels_train, 'labels_train')
+      self.is_multihot_labels(self.labels_test, 'labels_test')
     else:
       _is_array_one_dimensional(self.loss_train, 'loss_train')
       _is_array_one_dimensional(self.loss_test, 'loss_test')
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py
index 2678c82..f9e6134 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py
@@ -72,10 +72,18 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
 
 
 def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
-  if data.is_multilabel_data():
-    raise ValueError("Slicing by class not supported for multilabel data.")
-  idx_train = data.labels_train == class_value
-  idx_test = data.labels_test == class_value
+  """Returns the slice of `data` with examples belonging to the given class."""
+  if not data.is_multilabel_data():
+    idx_train = data.labels_train == class_value
+    idx_test = data.labels_test == class_value
+  else:
+    if class_value >= data.num_classes:
+      raise ValueError(
+          f"class_value ({class_value}) must be smaller than the number of"
+          f" classes ({data.num_classes})."
+      )
+    idx_train = data.labels_train[:, class_value].astype(bool)
+    idx_test = data.labels_test[:, class_value].astype(bool)
   return _slice_data_by_indices(data, idx_train, idx_test)
 
 
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py
index db1ddca..9ad1172 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py
@@ -358,10 +358,20 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
     expected.slice_spec = entire_dataset_slice
     self.assertTrue(_are_all_fields_equal(output, self.input_data))
 
-  def test_slice_by_class_fails(self):
+  def test_slice_by_class(self):
     class_index = 1
     class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
-    self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
+    output = get_slice(self.input_data, class_slice)
+    expected_indices_train = np.array([0, 2, 3])
+    expected_indices_test = np.array([1, 2])
+
+    np.testing.assert_array_equal(
+        output.logits_train,
+        self.input_data.logits_train[expected_indices_train],
+    )
+    np.testing.assert_array_equal(
+        output.logits_test, self.input_data.logits_test[expected_indices_test]
+    )
 
   @mock.patch('logging.Logger.info', wraps=logging.Logger)
   def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):