Fix for threshold attacks when logits are not provided.
Don't try to compute number of classes when it's not needed. PiperOrigin-RevId: 344060285
This commit is contained in:
parent
35a8096173
commit
15515cb0f4
1 changed files with 4 additions and 2 deletions
|
@ -154,8 +154,10 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
|
|
||||||
if slicing_spec is None:
|
if slicing_spec is None:
|
||||||
slicing_spec = SlicingSpec(entire_dataset=True)
|
slicing_spec = SlicingSpec(entire_dataset=True)
|
||||||
input_slice_specs = get_single_slice_specs(slicing_spec,
|
num_classes = None
|
||||||
attack_input.num_classes)
|
if slicing_spec.by_class:
|
||||||
|
num_classes = attack_input.num_classes
|
||||||
|
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
|
||||||
for single_slice_spec in input_slice_specs:
|
for single_slice_spec in input_slice_specs:
|
||||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||||
for attack_type in attack_types:
|
for attack_type in attack_types:
|
||||||
|
|
Loading…
Reference in a new issue