Add slice names for custom slices

PiperOrigin-RevId: 544599507
This commit is contained in:
Vadym Doroshenko 2023-06-30 02:32:29 -07:00 committed by A. Unique TensorFlower
parent f953e834df
commit 93f5a5249c
2 changed files with 14 additions and 4 deletions

View file

@ -20,7 +20,7 @@ import glob
import logging
import os
import pickle
from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence
from typing import Any, Dict, Iterable, MutableSequence, Optional, Sequence, Union
import numpy as np
import pandas as pd
@ -116,6 +116,12 @@ class SlicingSpec:
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
# Specifies names for custom slices. The names will be in the output in
# SingleAttackResult.slice_spec.value. If provided, the dictionary must
# contain names for all custom indices groups from all_custom_train_indices
# and all_custom_test_indices.
custom_slices_names: Optional[Dict[int, str]] = None
def __post_init__(self):
if not self.all_custom_train_indices and not self.all_custom_test_indices:
return

View file

@ -264,9 +264,13 @@ def get_single_slice_specs(
f"Too many groups ({groups.size}) for slicing by custom indices. "
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
for g in groups:
result.append(
SingleSliceSpec(SlicingFeature.CUSTOM,
(custom_train_indices, custom_test_indices, g)))
if slicing_spec.custom_slices_names is not None:
if g not in slicing_spec.custom_slices_names:
raise ValueError(f"Custom slice={g} is not in custom_slices_names")
group_id = slicing_spec.custom_slices_names[g]
else:
group_id = (custom_train_indices, custom_test_indices, g)
result.append(SingleSliceSpec(SlicingFeature.CUSTOM, group_id))
return result