forked from 626_privacy/tensorflow_privacy
Fix for threshold attacks when logits are not provided.
Don't try to compute number of classes when it's not needed. PiperOrigin-RevId: 344060285
This commit is contained in:
parent
35a8096173
commit
15515cb0f4
1 changed files with 4 additions and 2 deletions
|
@ -154,8 +154,10 @@ def run_attacks(attack_input: AttackInputData,
|
|||
|
||||
if slicing_spec is None:
|
||||
slicing_spec = SlicingSpec(entire_dataset=True)
|
||||
input_slice_specs = get_single_slice_specs(slicing_spec,
|
||||
attack_input.num_classes)
|
||||
num_classes = None
|
||||
if slicing_spec.by_class:
|
||||
num_classes = attack_input.num_classes
|
||||
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
|
||||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
for attack_type in attack_types:
|
||||
|
|
Loading…
Reference in a new issue