Supports slicing for multi-label data.
PiperOrigin-RevId: 523846333
This commit is contained in:
parent
d5e41e20ad
commit
e362f51773
3 changed files with 37 additions and 12 deletions
|
@ -84,9 +84,10 @@ class SlicingSpec:
|
|||
# When is set to true, one of the slices is the whole dataset.
|
||||
entire_dataset: bool = True
|
||||
|
||||
# Used in classification tasks for slicing by classes. It is assumed that
|
||||
# classes are integers 0, 1, ... number of classes. When true one slice per
|
||||
# each class is generated.
|
||||
# Used in classification tasks for slicing by classes. When true one slice per
|
||||
# each class is generated. Classes can either be
|
||||
# - integers 0, 1, ..., (for single label) or
|
||||
# - an array of integers (for multi-label).
|
||||
by_class: Union[bool, Iterable[int], int] = False
|
||||
|
||||
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
|
||||
|
@ -238,8 +239,10 @@ class AttackInputData:
|
|||
probs_train: Optional[np.ndarray] = None
|
||||
probs_test: Optional[np.ndarray] = None
|
||||
|
||||
# Contains ground-truth classes. Classes are assumed to be integers starting
|
||||
# from 0.
|
||||
# Contains ground-truth classes. For single-label classification, classes are
|
||||
# assumed to be integers starting from 0. For multi-label classification,
|
||||
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
|
||||
# (num_examples, num_classes).
|
||||
labels_train: Optional[np.ndarray] = None
|
||||
labels_test: Optional[np.ndarray] = None
|
||||
|
||||
|
@ -290,7 +293,9 @@ class AttackInputData:
|
|||
raise ValueError(
|
||||
'Can\'t identify the number of classes as no labels were provided. '
|
||||
'Please set labels_train and labels_test')
|
||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||
if not self.multilabel_data:
|
||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||
return self.labels_train.shape[1]
|
||||
|
||||
@property
|
||||
def logits_or_probs_train(self):
|
||||
|
@ -586,6 +591,8 @@ class AttackInputData:
|
|||
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
|
||||
_is_array_two_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_two_dimensional(self.labels_test, 'labels_test')
|
||||
self.is_multihot_labels(self.labels_train, 'labels_train')
|
||||
self.is_multihot_labels(self.labels_test, 'labels_test')
|
||||
else:
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
|
|
|
@ -72,10 +72,18 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
|||
|
||||
|
||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||
if data.is_multilabel_data():
|
||||
raise ValueError("Slicing by class not supported for multilabel data.")
|
||||
idx_train = data.labels_train == class_value
|
||||
idx_test = data.labels_test == class_value
|
||||
"""Gets the indices (boolean) for examples belonging to the given class."""
|
||||
if not data.is_multilabel_data():
|
||||
idx_train = data.labels_train == class_value
|
||||
idx_test = data.labels_test == class_value
|
||||
else:
|
||||
if class_value >= data.num_classes:
|
||||
raise ValueError(
|
||||
f"class_value ({class_value}) is larger than the number of classes"
|
||||
" (data.num_classes)."
|
||||
)
|
||||
idx_train = data.labels_train[:, class_value].astype(bool)
|
||||
idx_test = data.labels_test[:, class_value].astype(bool)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
|
||||
|
||||
|
|
|
@ -358,10 +358,20 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
|
|||
expected.slice_spec = entire_dataset_slice
|
||||
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
||||
|
||||
def test_slice_by_class_fails(self):
|
||||
def test_slice_by_class(self):
|
||||
class_index = 1
|
||||
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
||||
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
|
||||
output = get_slice(self.input_data, class_slice)
|
||||
expected_indices_train = np.array([0, 2, 3])
|
||||
expected_indices_test = np.array([1, 2])
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
output.logits_train,
|
||||
self.input_data.logits_train[expected_indices_train],
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
output.logits_test, self.input_data.logits_test[expected_indices_test]
|
||||
)
|
||||
|
||||
@mock.patch('logging.Logger.info', wraps=logging.Logger)
|
||||
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
|
||||
|
|
Loading…
Reference in a new issue