forked from 626_privacy/tensorflow_privacy
Internal change.
PiperOrigin-RevId: 322996754
This commit is contained in:
parent
267ea7f90d
commit
f318fbb140
1 changed files with 36 additions and 0 deletions
|
@ -17,6 +17,7 @@
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
|
||||||
|
|
||||||
class AttackInputDataTest(absltest.TestCase):
|
class AttackInputDataTest(absltest.TestCase):
|
||||||
|
@ -61,5 +62,40 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
AttackInputData().validate)
|
AttackInputData().validate)
|
||||||
|
|
||||||
|
|
||||||
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_auc_random_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([1.0, 0.5, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
self.assertEqual(roc.get_auc(), 0.5)
|
||||||
|
|
||||||
|
def test_auc_perfect_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
self.assertEqual(roc.get_auc(), 1.0)
|
||||||
|
|
||||||
|
def test_attacker_advantage_random_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([1.0, 0.5, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
self.assertEqual(roc.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
def test_attacker_advantage_perfect_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
self.assertEqual(roc.get_auc(), 1.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue