diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 3d18648..4df75c8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -154,8 +154,10 @@ def run_attacks(attack_input: AttackInputData, if slicing_spec is None: slicing_spec = SlicingSpec(entire_dataset=True) - input_slice_specs = get_single_slice_specs(slicing_spec, - attack_input.num_classes) + num_classes = None + if slicing_spec.by_class: + num_classes = attack_input.num_classes + input_slice_specs = get_single_slice_specs(slicing_spec, num_classes) for single_slice_spec in input_slice_specs: attack_input_slice = get_slice(attack_input, single_slice_spec) for attack_type in attack_types: