diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 23f9ba0..d602c18 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve class AttackInputDataTest(absltest.TestCase): @@ -61,5 +62,40 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData().validate) +class RocCurveTest(absltest.TestCase): + + def test_auc_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([1.0, 0.5, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 0.5) + + def test_auc_perfect_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 1.0) + + def test_attacker_advantage_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([1.0, 0.5, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_attacker_advantage(), 0.0) + + def test_attacker_advantage_perfect_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 1.0) + + if __name__ == '__main__': absltest.main()