Returns None for getting max results when results are empty.

PiperOrigin-RevId: 510054673
This commit is contained in:
Shuang Song 2023-02-15 23:36:52 -08:00 committed by A. Unique TensorFlower
parent 13534e5159
commit 0c691d0b4d
2 changed files with 17 additions and 3 deletions

View file

@ -1071,8 +1071,10 @@ class AttackResults:
slice_dict[slice_str].single_attack_results.append(attack_result) slice_dict[slice_str].single_attack_results.append(attack_result)
return slice_dict return slice_dict
def get_result_with_max_auc(self) -> SingleAttackResult: def get_result_with_max_auc(self) -> Optional[SingleAttackResult]:
"""Get the result with maximum AUC for all attacks and slices.""" """Get the result with maximum AUC for all attacks and slices."""
if not self.single_attack_results:
return None
aucs = [result.get_auc() for result in self.single_attack_results] aucs = [result.get_auc() for result in self.single_attack_results]
if min(aucs) < 0.4: if min(aucs) < 0.4:
@ -1081,14 +1083,20 @@ class AttackResults:
return self.single_attack_results[np.argmax(aucs)] return self.single_attack_results[np.argmax(aucs)]
def get_result_with_max_attacker_advantage(self) -> SingleAttackResult: def get_result_with_max_attacker_advantage(
self,
) -> Optional[SingleAttackResult]:
"""Get the result with maximum advantage for all attacks and slices.""" """Get the result with maximum advantage for all attacks and slices."""
if not self.single_attack_results:
return None
return self.single_attack_results[np.argmax([ return self.single_attack_results[np.argmax([
result.get_attacker_advantage() for result in self.single_attack_results result.get_attacker_advantage() for result in self.single_attack_results
])] ])]
def get_result_with_max_ppv(self) -> SingleAttackResult: def get_result_with_max_ppv(self) -> Optional[SingleAttackResult]:
"""Gets the result with max positive predictive value for all attacks and slices.""" """Gets the result with max positive predictive value for all attacks and slices."""
if not self.single_attack_results:
return None
return self.single_attack_results[np.argmax( return self.single_attack_results[np.argmax(
[result.get_ppv() for result in self.single_attack_results])] [result.get_ppv() for result in self.single_attack_results])]

View file

@ -839,6 +839,12 @@ class AttackResultsTest(absltest.TestCase):
}) })
pd.testing.assert_frame_equal(df, df_expected) pd.testing.assert_frame_equal(df, df_expected)
def test_get_max_empty_results(self):
results = AttackResults([])
self.assertIsNone(results.get_result_with_max_attacker_advantage())
self.assertIsNone(results.get_result_with_max_auc())
self.assertIsNone(results.get_result_with_max_ppv())
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()