Returns None for getting max results when results are empty.
PiperOrigin-RevId: 510054673
This commit is contained in:
parent
13534e5159
commit
0c691d0b4d
2 changed files with 17 additions and 3 deletions
|
@ -1071,8 +1071,10 @@ class AttackResults:
|
||||||
slice_dict[slice_str].single_attack_results.append(attack_result)
|
slice_dict[slice_str].single_attack_results.append(attack_result)
|
||||||
return slice_dict
|
return slice_dict
|
||||||
|
|
||||||
def get_result_with_max_auc(self) -> SingleAttackResult:
|
def get_result_with_max_auc(self) -> Optional[SingleAttackResult]:
|
||||||
"""Get the result with maximum AUC for all attacks and slices."""
|
"""Get the result with maximum AUC for all attacks and slices."""
|
||||||
|
if not self.single_attack_results:
|
||||||
|
return None
|
||||||
aucs = [result.get_auc() for result in self.single_attack_results]
|
aucs = [result.get_auc() for result in self.single_attack_results]
|
||||||
|
|
||||||
if min(aucs) < 0.4:
|
if min(aucs) < 0.4:
|
||||||
|
@ -1081,14 +1083,20 @@ class AttackResults:
|
||||||
|
|
||||||
return self.single_attack_results[np.argmax(aucs)]
|
return self.single_attack_results[np.argmax(aucs)]
|
||||||
|
|
||||||
def get_result_with_max_attacker_advantage(self) -> SingleAttackResult:
|
def get_result_with_max_attacker_advantage(
|
||||||
|
self,
|
||||||
|
) -> Optional[SingleAttackResult]:
|
||||||
"""Get the result with maximum advantage for all attacks and slices."""
|
"""Get the result with maximum advantage for all attacks and slices."""
|
||||||
|
if not self.single_attack_results:
|
||||||
|
return None
|
||||||
return self.single_attack_results[np.argmax([
|
return self.single_attack_results[np.argmax([
|
||||||
result.get_attacker_advantage() for result in self.single_attack_results
|
result.get_attacker_advantage() for result in self.single_attack_results
|
||||||
])]
|
])]
|
||||||
|
|
||||||
def get_result_with_max_ppv(self) -> SingleAttackResult:
|
def get_result_with_max_ppv(self) -> Optional[SingleAttackResult]:
|
||||||
"""Gets the result with max positive predictive value for all attacks and slices."""
|
"""Gets the result with max positive predictive value for all attacks and slices."""
|
||||||
|
if not self.single_attack_results:
|
||||||
|
return None
|
||||||
return self.single_attack_results[np.argmax(
|
return self.single_attack_results[np.argmax(
|
||||||
[result.get_ppv() for result in self.single_attack_results])]
|
[result.get_ppv() for result in self.single_attack_results])]
|
||||||
|
|
||||||
|
|
|
@ -839,6 +839,12 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
})
|
})
|
||||||
pd.testing.assert_frame_equal(df, df_expected)
|
pd.testing.assert_frame_equal(df, df_expected)
|
||||||
|
|
||||||
|
def test_get_max_empty_results(self):
|
||||||
|
results = AttackResults([])
|
||||||
|
self.assertIsNone(results.get_result_with_max_attacker_advantage())
|
||||||
|
self.assertIsNone(results.get_result_with_max_auc())
|
||||||
|
self.assertIsNone(results.get_result_with_max_ppv())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue