Internal change.
PiperOrigin-RevId: 324574332
This commit is contained in:
parent
d5e34b77c8
commit
0a1cbb5b7b
1 changed files with 74 additions and 0 deletions
|
@ -17,7 +17,9 @@
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
|
||||||
|
|
||||||
class AttackInputDataTest(absltest.TestCase):
|
class AttackInputDataTest(absltest.TestCase):
|
||||||
|
@ -97,5 +99,77 @@ class RocCurveTest(absltest.TestCase):
|
||||||
self.assertEqual(roc.get_auc(), 1.0)
|
self.assertEqual(roc.get_auc(), 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleAttackResultTest(absltest.TestCase):
|
||||||
|
|
||||||
|
# Only a basic test, as this method calls RocCurve which is tested separately.
|
||||||
|
def test_auc_random_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
result = SingleAttackResult(roc_curve=roc)
|
||||||
|
|
||||||
|
self.assertEqual(result.get_auc(), 0.5)
|
||||||
|
|
||||||
|
# Only a basic test, as this method calls RocCurve which is tested separately.
|
||||||
|
def test_attacker_advantage_random_classifier(self):
|
||||||
|
roc = RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
|
result = SingleAttackResult(roc_curve=roc)
|
||||||
|
|
||||||
|
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
|
perfect_classifier_result: SingleAttackResult
|
||||||
|
random_classifier_result: SingleAttackResult
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttackResultsTest, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# ROC curve of a perfect classifier
|
||||||
|
self.perfect_classifier_result = SingleAttackResult(
|
||||||
|
roc_curve=RocCurve(
|
||||||
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
# ROC curve of a random classifier
|
||||||
|
self.random_classifier_result = SingleAttackResult(
|
||||||
|
roc_curve=RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
def test_get_result_with_max_auc_first(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
self.assertEqual(results.get_result_with_max_auc(),
|
||||||
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
def test_get_result_with_max_auc_second(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.random_classifier_result, self.perfect_classifier_result])
|
||||||
|
self.assertEqual(results.get_result_with_max_auc(),
|
||||||
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
def test_get_result_with_max_attacker_advantage_first(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||||
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
def test_get_result_with_max_attacker_advantage_second(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.random_classifier_result, self.perfect_classifier_result])
|
||||||
|
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||||
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue