diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index a763ebe..38ba60a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -17,7 +17,9 @@ from absl.testing import absltest import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult class AttackInputDataTest(absltest.TestCase): @@ -97,5 +99,77 @@ class RocCurveTest(absltest.TestCase): self.assertEqual(roc.get_auc(), 1.0) +class SingleAttackResultTest(absltest.TestCase): + + # Only a basic test, as this method calls RocCurve which is tested separately. + def test_auc_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([0.0, 0.5, 1.0]), + thresholds=np.array([0, 1, 2])) + + result = SingleAttackResult(roc_curve=roc) + + self.assertEqual(result.get_auc(), 0.5) + + # Only a basic test, as this method calls RocCurve which is tested separately. + def test_attacker_advantage_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([0.0, 0.5, 1.0]), + thresholds=np.array([0, 1, 2])) + + result = SingleAttackResult(roc_curve=roc) + + self.assertEqual(result.get_attacker_advantage(), 0.0) + + +class AttackResultsTest(absltest.TestCase): + + perfect_classifier_result: SingleAttackResult + random_classifier_result: SingleAttackResult + + def __init__(self, *args, **kwargs): + super(AttackResultsTest, self).__init__(*args, **kwargs) + + # ROC curve of a perfect classifier + self.perfect_classifier_result = SingleAttackResult( + roc_curve=RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2]))) + + # ROC curve of a random classifier + self.random_classifier_result = SingleAttackResult( + roc_curve=RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([0.0, 0.5, 1.0]), + thresholds=np.array([0, 1, 2]))) + + def test_get_result_with_max_auc_first(self): + results = AttackResults( + [self.perfect_classifier_result, self.random_classifier_result]) + self.assertEqual(results.get_result_with_max_auc(), + self.perfect_classifier_result) + + def test_get_result_with_max_auc_second(self): + results = AttackResults( + [self.random_classifier_result, self.perfect_classifier_result]) + self.assertEqual(results.get_result_with_max_auc(), + self.perfect_classifier_result) + + def test_get_result_with_max_attacker_advantage_first(self): + results = AttackResults( + [self.perfect_classifier_result, self.random_classifier_result]) + self.assertEqual(results.get_result_with_max_attacker_advantage(), + self.perfect_classifier_result) + + def test_get_result_with_max_attacker_advantage_second(self): + results = AttackResults( + [self.random_classifier_result, self.perfect_classifier_result]) + self.assertEqual(results.get_result_with_max_attacker_advantage(), + self.perfect_classifier_result) + + if __name__ == '__main__': absltest.main()