From f318fbb140423bb460bb53a9ddeb433b292db3d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Jul 2020 08:03:25 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 322996754 --- .../data_structures_test.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 23f9ba0..d602c18 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve class AttackInputDataTest(absltest.TestCase): @@ -61,5 +62,40 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData().validate) +class RocCurveTest(absltest.TestCase): + + def test_auc_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([1.0, 0.5, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 0.5) + + def test_auc_perfect_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 1.0) + + def test_attacker_advantage_random_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([1.0, 0.5, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_attacker_advantage(), 0.0) + + def test_attacker_advantage_perfect_classifier(self): + roc = RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2])) + + self.assertEqual(roc.get_auc(), 1.0) + + if __name__ == '__main__': absltest.main()