Internal change.
PiperOrigin-RevId: 324591262
This commit is contained in:
parent
0a1cbb5b7b
commit
29651216cd
1 changed files with 17 additions and 14 deletions
|
@ -18,6 +18,7 @@ from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
|
||||||
|
@ -29,23 +30,20 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
|
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
|
||||||
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
|
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
|
||||||
labels_train=np.array([1, 0]),
|
labels_train=np.array([1, 0]),
|
||||||
labels_test=np.array([0, 1])
|
labels_test=np.array([0, 1]))
|
||||||
)
|
|
||||||
|
|
||||||
np.testing.assert_equal(
|
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
||||||
attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
||||||
np.testing.assert_equal(
|
|
||||||
attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
|
||||||
|
|
||||||
def test_get_loss_explicitly_provided(self):
|
def test_get_loss_explicitly_provided(self):
|
||||||
attack_input = AttackInputData(
|
attack_input = AttackInputData(
|
||||||
loss_train=np.array([1.0, 3.0, 6.0]),
|
loss_train=np.array([1.0, 3.0, 6.0]),
|
||||||
loss_test=np.array([1.0, 4.0, 6.0]))
|
loss_test=np.array([1.0, 4.0, 6.0]))
|
||||||
|
|
||||||
np.testing.assert_equal(
|
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
|
||||||
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0])
|
[1.0, 3.0, 6.0])
|
||||||
np.testing.assert_equal(
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||||
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
|
[1.0, 4.0, 6.0])
|
||||||
|
|
||||||
def test_validator(self):
|
def test_validator(self):
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
|
@ -60,8 +58,7 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
AttackInputData(labels_test=np.array([])).validate)
|
AttackInputData(labels_test=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(loss_test=np.array([])).validate)
|
AttackInputData(loss_test=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError, AttackInputData().validate)
|
||||||
AttackInputData().validate)
|
|
||||||
|
|
||||||
|
|
||||||
class RocCurveTest(absltest.TestCase):
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
@ -108,7 +105,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
fpr=np.array([0.0, 0.5, 1.0]),
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
thresholds=np.array([0, 1, 2]))
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
result = SingleAttackResult(roc_curve=roc)
|
result = SingleAttackResult(
|
||||||
|
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
self.assertEqual(result.get_auc(), 0.5)
|
self.assertEqual(result.get_auc(), 0.5)
|
||||||
|
|
||||||
|
@ -119,7 +117,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
fpr=np.array([0.0, 0.5, 1.0]),
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
thresholds=np.array([0, 1, 2]))
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
result = SingleAttackResult(roc_curve=roc)
|
result = SingleAttackResult(
|
||||||
|
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
@ -134,6 +133,8 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
# ROC curve of a perfect classifier
|
# ROC curve of a perfect classifier
|
||||||
self.perfect_classifier_result = SingleAttackResult(
|
self.perfect_classifier_result = SingleAttackResult(
|
||||||
|
slice_spec=None,
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
roc_curve=RocCurve(
|
roc_curve=RocCurve(
|
||||||
tpr=np.array([0.0, 1.0, 1.0]),
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
fpr=np.array([1.0, 1.0, 0.0]),
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
|
@ -141,6 +142,8 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
# ROC curve of a random classifier
|
# ROC curve of a random classifier
|
||||||
self.random_classifier_result = SingleAttackResult(
|
self.random_classifier_result = SingleAttackResult(
|
||||||
|
slice_spec=None,
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
roc_curve=RocCurve(
|
roc_curve=RocCurve(
|
||||||
tpr=np.array([0.0, 0.5, 1.0]),
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
fpr=np.array([0.0, 0.5, 1.0]),
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
|
Loading…
Reference in a new issue