diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index 3e47aad..2d8f3b6 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -1071,8 +1071,10 @@ class AttackResults: slice_dict[slice_str].single_attack_results.append(attack_result) return slice_dict - def get_result_with_max_auc(self) -> SingleAttackResult: + def get_result_with_max_auc(self) -> Optional[SingleAttackResult]: """Get the result with maximum AUC for all attacks and slices.""" + if not self.single_attack_results: + return None aucs = [result.get_auc() for result in self.single_attack_results] if min(aucs) < 0.4: @@ -1081,14 +1083,20 @@ class AttackResults: return self.single_attack_results[np.argmax(aucs)] - def get_result_with_max_attacker_advantage(self) -> SingleAttackResult: + def get_result_with_max_attacker_advantage( + self, + ) -> Optional[SingleAttackResult]: """Get the result with maximum advantage for all attacks and slices.""" + if not self.single_attack_results: + return None return self.single_attack_results[np.argmax([ result.get_attacker_advantage() for result in self.single_attack_results ])] - def get_result_with_max_ppv(self) -> SingleAttackResult: + def get_result_with_max_ppv(self) -> Optional[SingleAttackResult]: """Gets the result with max positive predictive value for all attacks and slices.""" + if not self.single_attack_results: + return None return self.single_attack_results[np.argmax( [result.get_ppv() for result in self.single_attack_results])] diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index feeb3fe..ec99000 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -839,6 +839,12 @@ class AttackResultsTest(absltest.TestCase): }) pd.testing.assert_frame_equal(df, df_expected) + def test_get_max_empty_results(self): + results = AttackResults([]) + self.assertIsNone(results.get_result_with_max_attacker_advantage()) + self.assertIsNone(results.get_result_with_max_auc()) + self.assertIsNone(results.get_result_with_max_ppv()) + if __name__ == '__main__': absltest.main()