Fix for threshold attacks when logits are not provided.

Don't try to compute number of classes when it's not needed.

PiperOrigin-RevId: 344060285
This commit is contained in:
Vadym Doroshenko 2020-11-24 08:05:47 -08:00 committed by A. Unique TensorFlower
parent 35a8096173
commit 15515cb0f4

View file

@ -154,8 +154,10 @@ def run_attacks(attack_input: AttackInputData,
if slicing_spec is None: if slicing_spec is None:
slicing_spec = SlicingSpec(entire_dataset=True) slicing_spec = SlicingSpec(entire_dataset=True)
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes = None
attack_input.num_classes) if slicing_spec.by_class:
num_classes = attack_input.num_classes
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
for single_slice_spec in input_slice_specs: for single_slice_spec in input_slice_specs:
attack_input_slice = get_slice(attack_input, single_slice_spec) attack_input_slice = get_slice(attack_input, single_slice_spec)
for attack_type in attack_types: for attack_type in attack_types: