From 15515cb0f49376f72062481b351c8b8a03b04887 Mon Sep 17 00:00:00 2001 From: Vadym Doroshenko Date: Tue, 24 Nov 2020 08:05:47 -0800 Subject: [PATCH] Fix for threshold attacks when logits are not provided. Don't try to compute number of classes when it's not needed. PiperOrigin-RevId: 344060285 --- .../membership_inference_attack.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 3d18648..4df75c8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -154,8 +154,10 @@ def run_attacks(attack_input: AttackInputData, if slicing_spec is None: slicing_spec = SlicingSpec(entire_dataset=True) - input_slice_specs = get_single_slice_specs(slicing_spec, - attack_input.num_classes) + num_classes = None + if slicing_spec.by_class: + num_classes = attack_input.num_classes + input_slice_specs = get_single_slice_specs(slicing_spec, num_classes) for single_slice_spec in input_slice_specs: attack_input_slice = get_slice(attack_input, single_slice_spec) for attack_type in attack_types: