Internal change.

PiperOrigin-RevId: 324574332
This commit is contained in:
A. Unique TensorFlower 2020-08-03 03:49:26 -07:00 committed by Steve Chien
parent d5e34b77c8
commit 0a1cbb5b7b

View file

@ -17,7 +17,9 @@
from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
class AttackInputDataTest(absltest.TestCase):
@ -97,5 +99,77 @@ class RocCurveTest(absltest.TestCase):
self.assertEqual(roc.get_auc(), 1.0)
class SingleAttackResultTest(absltest.TestCase):
# Only a basic test, as this method calls RocCurve which is tested separately.
def test_auc_random_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc)
self.assertEqual(result.get_auc(), 0.5)
# Only a basic test, as this method calls RocCurve which is tested separately.
def test_attacker_advantage_random_classifier(self):
roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc)
self.assertEqual(result.get_attacker_advantage(), 0.0)
class AttackResultsTest(absltest.TestCase):
perfect_classifier_result: SingleAttackResult
random_classifier_result: SingleAttackResult
def __init__(self, *args, **kwargs):
super(AttackResultsTest, self).__init__(*args, **kwargs)
# ROC curve of a perfect classifier
self.perfect_classifier_result = SingleAttackResult(
roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2])))
# ROC curve of a random classifier
self.random_classifier_result = SingleAttackResult(
roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])))
def test_get_result_with_max_auc_first(self):
results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual(results.get_result_with_max_auc(),
self.perfect_classifier_result)
def test_get_result_with_max_auc_second(self):
results = AttackResults(
[self.random_classifier_result, self.perfect_classifier_result])
self.assertEqual(results.get_result_with_max_auc(),
self.perfect_classifier_result)
def test_get_result_with_max_attacker_advantage_first(self):
results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual(results.get_result_with_max_attacker_advantage(),
self.perfect_classifier_result)
def test_get_result_with_max_attacker_advantage_second(self):
results = AttackResults(
[self.random_classifier_result, self.perfect_classifier_result])
self.assertEqual(results.get_result_with_max_attacker_advantage(),
self.perfect_classifier_result)
if __name__ == '__main__':
absltest.main()