Internal change.
PiperOrigin-RevId: 324574332
This commit is contained in:
parent
d5e34b77c8
commit
0a1cbb5b7b
1 changed files with 74 additions and 0 deletions
|
@ -17,7 +17,9 @@
|
|||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
|
||||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
|
@ -97,5 +99,77 @@ class RocCurveTest(absltest.TestCase):
|
|||
self.assertEqual(roc.get_auc(), 1.0)
|
||||
|
||||
|
||||
class SingleAttackResultTest(absltest.TestCase):
|
||||
|
||||
# Only a basic test, as this method calls RocCurve which is tested separately.
|
||||
def test_auc_random_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(roc_curve=roc)
|
||||
|
||||
self.assertEqual(result.get_auc(), 0.5)
|
||||
|
||||
# Only a basic test, as this method calls RocCurve which is tested separately.
|
||||
def test_attacker_advantage_random_classifier(self):
|
||||
roc = RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(roc_curve=roc)
|
||||
|
||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
||||
class AttackResultsTest(absltest.TestCase):
|
||||
|
||||
perfect_classifier_result: SingleAttackResult
|
||||
random_classifier_result: SingleAttackResult
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttackResultsTest, self).__init__(*args, **kwargs)
|
||||
|
||||
# ROC curve of a perfect classifier
|
||||
self.perfect_classifier_result = SingleAttackResult(
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
|
||||
# ROC curve of a random classifier
|
||||
self.random_classifier_result = SingleAttackResult(
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
|
||||
def test_get_result_with_max_auc_first(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
self.assertEqual(results.get_result_with_max_auc(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
def test_get_result_with_max_auc_second(self):
|
||||
results = AttackResults(
|
||||
[self.random_classifier_result, self.perfect_classifier_result])
|
||||
self.assertEqual(results.get_result_with_max_auc(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
def test_get_result_with_max_attacker_advantage_first(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
def test_get_result_with_max_attacker_advantage_second(self):
|
||||
results = AttackResults(
|
||||
[self.random_classifier_result, self.perfect_classifier_result])
|
||||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue