diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index 3e4e8d3..93f87f0 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -20,7 +20,7 @@ import glob import logging import os import pickle -from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence +from typing import Any, Dict, Iterable, MutableSequence, Optional, Sequence, Union import numpy as np import pandas as pd @@ -116,6 +116,12 @@ class SlicingSpec: all_custom_train_indices: Optional[Sequence[np.ndarray]] = None all_custom_test_indices: Optional[Sequence[np.ndarray]] = None + # Specifies names for custom slices. The names will be in the output in + # SingleAttackResult.slice_spec.value. If provided, the dictionary must + # contain names for all custom indices groups from all_custom_train_indices + # and all_custom_test_indices. + custom_slices_names: Optional[Dict[int, str]] = None + def __post_init__(self): if not self.all_custom_train_indices and not self.all_custom_test_indices: return diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index 95b609a..ab486d4 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ -264,9 +264,13 @@ def get_single_slice_specs( f"Too many groups ({groups.size}) for slicing by custom indices. " f"Should be no more than {_MAX_NUM_OF_SLICES}.") for g in groups: - result.append( - SingleSliceSpec(SlicingFeature.CUSTOM, - (custom_train_indices, custom_test_indices, g))) + if slicing_spec.custom_slices_names is not None: + if g not in slicing_spec.custom_slices_names: + raise ValueError(f"Custom slice={g} is not in custom_slices_names") + group_id = slicing_spec.custom_slices_names[g] + else: + group_id = (custom_train_indices, custom_test_indices, g) + result.append(SingleSliceSpec(SlicingFeature.CUSTOM, group_id)) return result