Add slice names for custom slices
PiperOrigin-RevId: 544599507
This commit is contained in:
parent
f953e834df
commit
93f5a5249c
2 changed files with 14 additions and 4 deletions
|
@ -20,7 +20,7 @@ import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence
|
from typing import Any, Dict, Iterable, MutableSequence, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -116,6 +116,12 @@ class SlicingSpec:
|
||||||
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
|
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
|
||||||
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
|
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
|
||||||
|
|
||||||
|
# Specifies names for custom slices. The names will be in the output in
|
||||||
|
# SingleAttackResult.slice_spec.value. If provided, the dictionary must
|
||||||
|
# contain names for all custom indices groups from all_custom_train_indices
|
||||||
|
# and all_custom_test_indices.
|
||||||
|
custom_slices_names: Optional[Dict[int, str]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not self.all_custom_train_indices and not self.all_custom_test_indices:
|
if not self.all_custom_train_indices and not self.all_custom_test_indices:
|
||||||
return
|
return
|
||||||
|
|
|
@ -264,9 +264,13 @@ def get_single_slice_specs(
|
||||||
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
||||||
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
||||||
for g in groups:
|
for g in groups:
|
||||||
result.append(
|
if slicing_spec.custom_slices_names is not None:
|
||||||
SingleSliceSpec(SlicingFeature.CUSTOM,
|
if g not in slicing_spec.custom_slices_names:
|
||||||
(custom_train_indices, custom_test_indices, g)))
|
raise ValueError(f"Custom slice={g} is not in custom_slices_names")
|
||||||
|
group_id = slicing_spec.custom_slices_names[g]
|
||||||
|
else:
|
||||||
|
group_id = (custom_train_indices, custom_test_indices, g)
|
||||||
|
result.append(SingleSliceSpec(SlicingFeature.CUSTOM, group_id))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue