forked from 626_privacy/tensorflow_privacy
Returns None for getting max results when results are empty.
PiperOrigin-RevId: 510054673
This commit is contained in:
parent
13534e5159
commit
0c691d0b4d
2 changed files with 17 additions and 3 deletions
|
@ -1071,8 +1071,10 @@ class AttackResults:
|
|||
slice_dict[slice_str].single_attack_results.append(attack_result)
|
||||
return slice_dict
|
||||
|
||||
def get_result_with_max_auc(self) -> SingleAttackResult:
|
||||
def get_result_with_max_auc(self) -> Optional[SingleAttackResult]:
|
||||
"""Get the result with maximum AUC for all attacks and slices."""
|
||||
if not self.single_attack_results:
|
||||
return None
|
||||
aucs = [result.get_auc() for result in self.single_attack_results]
|
||||
|
||||
if min(aucs) < 0.4:
|
||||
|
@ -1081,14 +1083,20 @@ class AttackResults:
|
|||
|
||||
return self.single_attack_results[np.argmax(aucs)]
|
||||
|
||||
def get_result_with_max_attacker_advantage(self) -> SingleAttackResult:
|
||||
def get_result_with_max_attacker_advantage(
|
||||
self,
|
||||
) -> Optional[SingleAttackResult]:
|
||||
"""Get the result with maximum advantage for all attacks and slices."""
|
||||
if not self.single_attack_results:
|
||||
return None
|
||||
return self.single_attack_results[np.argmax([
|
||||
result.get_attacker_advantage() for result in self.single_attack_results
|
||||
])]
|
||||
|
||||
def get_result_with_max_ppv(self) -> SingleAttackResult:
|
||||
def get_result_with_max_ppv(self) -> Optional[SingleAttackResult]:
|
||||
"""Gets the result with max positive predictive value for all attacks and slices."""
|
||||
if not self.single_attack_results:
|
||||
return None
|
||||
return self.single_attack_results[np.argmax(
|
||||
[result.get_ppv() for result in self.single_attack_results])]
|
||||
|
||||
|
|
|
@ -839,6 +839,12 @@ class AttackResultsTest(absltest.TestCase):
|
|||
})
|
||||
pd.testing.assert_frame_equal(df, df_expected)
|
||||
|
||||
def test_get_max_empty_results(self):
|
||||
results = AttackResults([])
|
||||
self.assertIsNone(results.get_result_with_max_attacker_advantage())
|
||||
self.assertIsNone(results.get_result_with_max_auc())
|
||||
self.assertIsNone(results.get_result_with_max_ppv())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue