Internal change.
PiperOrigin-RevId: 324591262
This commit is contained in:
parent
0a1cbb5b7b
commit
29651216cd
1 changed files with 17 additions and 14 deletions
|
@ -18,6 +18,7 @@ from absl.testing import absltest
|
|||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
|
||||
|
@ -29,23 +30,20 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
|
||||
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
|
||||
labels_train=np.array([1, 0]),
|
||||
labels_test=np.array([0, 1])
|
||||
)
|
||||
labels_test=np.array([0, 1]))
|
||||
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
||||
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
||||
|
||||
def test_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1.0, 3.0, 6.0]),
|
||||
loss_test=np.array([1.0, 4.0, 6.0]))
|
||||
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0])
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
|
||||
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
|
||||
[1.0, 3.0, 6.0])
|
||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||
[1.0, 4.0, 6.0])
|
||||
|
||||
def test_validator(self):
|
||||
self.assertRaises(ValueError,
|
||||
|
@ -60,8 +58,7 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
AttackInputData(labels_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(loss_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData().validate)
|
||||
self.assertRaises(ValueError, AttackInputData().validate)
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
@ -108,7 +105,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(roc_curve=roc)
|
||||
result = SingleAttackResult(
|
||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.get_auc(), 0.5)
|
||||
|
||||
|
@ -119,7 +117,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(roc_curve=roc)
|
||||
result = SingleAttackResult(
|
||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
@ -134,6 +133,8 @@ class AttackResultsTest(absltest.TestCase):
|
|||
|
||||
# ROC curve of a perfect classifier
|
||||
self.perfect_classifier_result = SingleAttackResult(
|
||||
slice_spec=None,
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
|
@ -141,6 +142,8 @@ class AttackResultsTest(absltest.TestCase):
|
|||
|
||||
# ROC curve of a random classifier
|
||||
self.random_classifier_result = SingleAttackResult(
|
||||
slice_spec=None,
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
|
|
Loading…
Reference in a new issue