From 0c691d0b4d8ac1591e299f52ccda0f7520e12f15 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Wed, 15 Feb 2023 23:36:52 -0800 Subject: [PATCH] Returns None for getting max results when results are empty. PiperOrigin-RevId: 510054673 --- .../membership_inference_attack/data_structures.py | 14 +++++++++++--- .../data_structures_test.py | 6 ++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index 3e47aad..2d8f3b6 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -1071,8 +1071,10 @@ class AttackResults: slice_dict[slice_str].single_attack_results.append(attack_result) return slice_dict - def get_result_with_max_auc(self) -> SingleAttackResult: + def get_result_with_max_auc(self) -> Optional[SingleAttackResult]: """Get the result with maximum AUC for all attacks and slices.""" + if not self.single_attack_results: + return None aucs = [result.get_auc() for result in self.single_attack_results] if min(aucs) < 0.4: @@ -1081,14 +1083,20 @@ class AttackResults: return self.single_attack_results[np.argmax(aucs)] - def get_result_with_max_attacker_advantage(self) -> SingleAttackResult: + def get_result_with_max_attacker_advantage( + self, + ) -> Optional[SingleAttackResult]: """Get the result with maximum advantage for all attacks and slices.""" + if not self.single_attack_results: + return None return self.single_attack_results[np.argmax([ result.get_attacker_advantage() for result in self.single_attack_results ])] - def get_result_with_max_ppv(self) -> SingleAttackResult: + def get_result_with_max_ppv(self) -> Optional[SingleAttackResult]: """Gets the result with max positive predictive value for all attacks and slices.""" + if not self.single_attack_results: + return None return self.single_attack_results[np.argmax( [result.get_ppv() for result in self.single_attack_results])] diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index feeb3fe..ec99000 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -839,6 +839,12 @@ class AttackResultsTest(absltest.TestCase): }) pd.testing.assert_frame_equal(df, df_expected) + def test_get_max_empty_results(self): + results = AttackResults([]) + self.assertIsNone(results.get_result_with_max_attacker_advantage()) + self.assertIsNone(results.get_result_with_max_auc()) + self.assertIsNone(results.get_result_with_max_ppv()) + if __name__ == '__main__': absltest.main()