Internal change.

PiperOrigin-RevId: 322996754
This commit is contained in:
A. Unique TensorFlower 2020-07-24 08:03:25 -07:00
parent 267ea7f90d
commit f318fbb140

View file

@ -17,6 +17,7 @@
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
class AttackInputDataTest(absltest.TestCase): class AttackInputDataTest(absltest.TestCase):
@ -61,5 +62,40 @@ class AttackInputDataTest(absltest.TestCase):
AttackInputData().validate) AttackInputData().validate)
class RocCurveTest(absltest.TestCase):
def test_auc_random_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([1.0, 0.5, 0.0]),
thresholds=np.array([0, 1, 2]))
self.assertEqual(roc.get_auc(), 0.5)
def test_auc_perfect_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2]))
self.assertEqual(roc.get_auc(), 1.0)
def test_attacker_advantage_random_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([1.0, 0.5, 0.0]),
thresholds=np.array([0, 1, 2]))
self.assertEqual(roc.get_attacker_advantage(), 0.0)
def test_attacker_advantage_perfect_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2]))
self.assertEqual(roc.get_auc(), 1.0)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()