Supports slicing for multi-label data.
PiperOrigin-RevId: 523846333
This commit is contained in:
parent
d5e41e20ad
commit
e362f51773
3 changed files with 37 additions and 12 deletions
|
@ -84,9 +84,10 @@ class SlicingSpec:
|
||||||
# When is set to true, one of the slices is the whole dataset.
|
# When is set to true, one of the slices is the whole dataset.
|
||||||
entire_dataset: bool = True
|
entire_dataset: bool = True
|
||||||
|
|
||||||
# Used in classification tasks for slicing by classes. It is assumed that
|
# Used in classification tasks for slicing by classes. When true one slice per
|
||||||
# classes are integers 0, 1, ... number of classes. When true one slice per
|
# each class is generated. Classes can either be
|
||||||
# each class is generated.
|
# - integers 0, 1, ..., (for single label) or
|
||||||
|
# - an array of integers (for multi-label).
|
||||||
by_class: Union[bool, Iterable[int], int] = False
|
by_class: Union[bool, Iterable[int], int] = False
|
||||||
|
|
||||||
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
|
# if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
|
||||||
|
@ -238,8 +239,10 @@ class AttackInputData:
|
||||||
probs_train: Optional[np.ndarray] = None
|
probs_train: Optional[np.ndarray] = None
|
||||||
probs_test: Optional[np.ndarray] = None
|
probs_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
# Contains ground-truth classes. Classes are assumed to be integers starting
|
# Contains ground-truth classes. For single-label classification, classes are
|
||||||
# from 0.
|
# assumed to be integers starting from 0. For multi-label classification,
|
||||||
|
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
|
||||||
|
# (num_examples, num_classes).
|
||||||
labels_train: Optional[np.ndarray] = None
|
labels_train: Optional[np.ndarray] = None
|
||||||
labels_test: Optional[np.ndarray] = None
|
labels_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
@ -290,7 +293,9 @@ class AttackInputData:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Can\'t identify the number of classes as no labels were provided. '
|
'Can\'t identify the number of classes as no labels were provided. '
|
||||||
'Please set labels_train and labels_test')
|
'Please set labels_train and labels_test')
|
||||||
|
if not self.multilabel_data:
|
||||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||||
|
return self.labels_train.shape[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logits_or_probs_train(self):
|
def logits_or_probs_train(self):
|
||||||
|
@ -586,6 +591,8 @@ class AttackInputData:
|
||||||
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
|
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
|
||||||
_is_array_two_dimensional(self.labels_train, 'labels_train')
|
_is_array_two_dimensional(self.labels_train, 'labels_train')
|
||||||
_is_array_two_dimensional(self.labels_test, 'labels_test')
|
_is_array_two_dimensional(self.labels_test, 'labels_test')
|
||||||
|
self.is_multihot_labels(self.labels_train, 'labels_train')
|
||||||
|
self.is_multihot_labels(self.labels_test, 'labels_test')
|
||||||
else:
|
else:
|
||||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||||
|
|
|
@ -72,10 +72,18 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
|
|
||||||
|
|
||||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||||
if data.is_multilabel_data():
|
"""Gets the indices (boolean) for examples belonging to the given class."""
|
||||||
raise ValueError("Slicing by class not supported for multilabel data.")
|
if not data.is_multilabel_data():
|
||||||
idx_train = data.labels_train == class_value
|
idx_train = data.labels_train == class_value
|
||||||
idx_test = data.labels_test == class_value
|
idx_test = data.labels_test == class_value
|
||||||
|
else:
|
||||||
|
if class_value >= data.num_classes:
|
||||||
|
raise ValueError(
|
||||||
|
f"class_value ({class_value}) is larger than the number of classes"
|
||||||
|
" (data.num_classes)."
|
||||||
|
)
|
||||||
|
idx_train = data.labels_train[:, class_value].astype(bool)
|
||||||
|
idx_test = data.labels_test[:, class_value].astype(bool)
|
||||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -358,10 +358,20 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||||
expected.slice_spec = entire_dataset_slice
|
expected.slice_spec = entire_dataset_slice
|
||||||
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
||||||
|
|
||||||
def test_slice_by_class_fails(self):
|
def test_slice_by_class(self):
|
||||||
class_index = 1
|
class_index = 1
|
||||||
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
||||||
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
|
output = get_slice(self.input_data, class_slice)
|
||||||
|
expected_indices_train = np.array([0, 2, 3])
|
||||||
|
expected_indices_test = np.array([1, 2])
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
output.logits_train,
|
||||||
|
self.input_data.logits_train[expected_indices_train],
|
||||||
|
)
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
output.logits_test, self.input_data.logits_test[expected_indices_test]
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch('logging.Logger.info', wraps=logging.Logger)
|
@mock.patch('logging.Logger.info', wraps=logging.Logger)
|
||||||
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
|
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
|
||||||
|
|
Loading…
Reference in a new issue