Add slice names for custom slices
PiperOrigin-RevId: 544599507
This commit is contained in:
parent
f953e834df
commit
93f5a5249c
2 changed files with 14 additions and 4 deletions
|
@ -20,7 +20,7 @@ import glob
|
|||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence
|
||||
from typing import Any, Dict, Iterable, MutableSequence, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -116,6 +116,12 @@ class SlicingSpec:
|
|||
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
|
||||
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
|
||||
|
||||
# Specifies names for custom slices. The names will be in the output in
|
||||
# SingleAttackResult.slice_spec.value. If provided, the dictionary must
|
||||
# contain names for all custom indices groups from all_custom_train_indices
|
||||
# and all_custom_test_indices.
|
||||
custom_slices_names: Optional[Dict[int, str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.all_custom_train_indices and not self.all_custom_test_indices:
|
||||
return
|
||||
|
|
|
@ -264,9 +264,13 @@ def get_single_slice_specs(
|
|||
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
||||
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
||||
for g in groups:
|
||||
result.append(
|
||||
SingleSliceSpec(SlicingFeature.CUSTOM,
|
||||
(custom_train_indices, custom_test_indices, g)))
|
||||
if slicing_spec.custom_slices_names is not None:
|
||||
if g not in slicing_spec.custom_slices_names:
|
||||
raise ValueError(f"Custom slice={g} is not in custom_slices_names")
|
||||
group_id = slicing_spec.custom_slices_names[g]
|
||||
else:
|
||||
group_id = (custom_train_indices, custom_test_indices, g)
|
||||
result.append(SingleSliceSpec(SlicingFeature.CUSTOM, group_id))
|
||||
return result
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue