Supports slicing for multi-label data.

PiperOrigin-RevId: 523846333
This commit is contained in:
Shuang Song 2023-04-12 17:13:44 -07:00 committed by A. Unique TensorFlower
parent d5e41e20ad
commit e362f51773
3 changed files with 37 additions and 12 deletions

View file

@ -84,9 +84,10 @@ class SlicingSpec:
# When is set to true, one of the slices is the whole dataset.
entire_dataset: bool = True
# Used in classification tasks for slicing by classes. It is assumed that
# classes are integers 0, 1, ... number of classes. When true one slice per
# each class is generated.
# Used in classification tasks for slicing by classes. When true one slice per
# each class is generated. Classes can either be
# - integers 0, 1, ..., (for single label) or
# - an array of integers (for multi-label).
by_class: Union[bool, Iterable[int], int] = False
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
@ -238,8 +239,10 @@ class AttackInputData:
probs_train: Optional[np.ndarray] = None
probs_test: Optional[np.ndarray] = None
# Contains ground-truth classes. Classes are assumed to be integers starting
# from 0.
# Contains ground-truth classes. For single-label classification, classes are
# assumed to be integers starting from 0. For multi-label classification,
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
# (num_examples, num_classes).
labels_train: Optional[np.ndarray] = None
labels_test: Optional[np.ndarray] = None
@ -290,7 +293,9 @@ class AttackInputData:
raise ValueError(
'Can\'t identify the number of classes as no labels were provided. '
'Please set labels_train and labels_test')
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
if not self.multilabel_data:
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
return self.labels_train.shape[1]
@property
def logits_or_probs_train(self):
@ -586,6 +591,8 @@ class AttackInputData:
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
_is_array_two_dimensional(self.labels_train, 'labels_train')
_is_array_two_dimensional(self.labels_test, 'labels_test')
self.is_multihot_labels(self.labels_train, 'labels_train')
self.is_multihot_labels(self.labels_test, 'labels_test')
else:
_is_array_one_dimensional(self.loss_train, 'loss_train')
_is_array_one_dimensional(self.loss_test, 'loss_test')

View file

@ -72,10 +72,18 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
if data.is_multilabel_data():
raise ValueError("Slicing by class not supported for multilabel data.")
idx_train = data.labels_train == class_value
idx_test = data.labels_test == class_value
"""Gets the indices (boolean) for examples belonging to the given class."""
if not data.is_multilabel_data():
idx_train = data.labels_train == class_value
idx_test = data.labels_test == class_value
else:
if class_value >= data.num_classes:
raise ValueError(
f"class_value ({class_value}) is larger than the number of classes"
" (data.num_classes)."
)
idx_train = data.labels_train[:, class_value].astype(bool)
idx_test = data.labels_test[:, class_value].astype(bool)
return _slice_data_by_indices(data, idx_train, idx_test)

View file

@ -358,10 +358,20 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
expected.slice_spec = entire_dataset_slice
self.assertTrue(_are_all_fields_equal(output, self.input_data))
def test_slice_by_class_fails(self):
def test_slice_by_class(self):
class_index = 1
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
output = get_slice(self.input_data, class_slice)
expected_indices_train = np.array([0, 2, 3])
expected_indices_test = np.array([1, 2])
np.testing.assert_array_equal(
output.logits_train,
self.input_data.logits_train[expected_indices_train],
)
np.testing.assert_array_equal(
output.logits_test, self.input_data.logits_test[expected_indices_test]
)
@mock.patch('logging.Logger.info', wraps=logging.Logger)
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):