Internal change.
PiperOrigin-RevId: 322996754
This commit is contained in:
parent
267ea7f90d
commit
f318fbb140
1 changed files with 36 additions and 0 deletions
|
@ -17,6 +17,7 @@
|
|||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
|
||||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
|
@ -61,5 +62,40 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
AttackInputData().validate)
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
||||
def test_auc_random_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([1.0, 0.5, 0.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
self.assertEqual(roc.get_auc(), 0.5)
|
||||
|
||||
def test_auc_perfect_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
self.assertEqual(roc.get_auc(), 1.0)
|
||||
|
||||
def test_attacker_advantage_random_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([1.0, 0.5, 0.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
self.assertEqual(roc.get_attacker_advantage(), 0.0)
|
||||
|
||||
def test_attacker_advantage_perfect_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
self.assertEqual(roc.get_auc(), 1.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue