Internal change.

PiperOrigin-RevId: 324591262
This commit is contained in:
A. Unique TensorFlower 2020-08-03 06:20:40 -07:00 committed by Steve Chien
parent 0a1cbb5b7b
commit 29651216cd

View file

@ -18,6 +18,7 @@ from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
@ -29,23 +30,20 @@ class AttackInputDataTest(absltest.TestCase):
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
labels_train=np.array([1, 0]),
labels_test=np.array([0, 1])
)
labels_test=np.array([0, 1]))
np.testing.assert_equal(
attack_input.get_loss_train().tolist(), [0.5, 0.2])
np.testing.assert_equal(
attack_input.get_loss_test().tolist(), [0.2, 0.5])
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]),
loss_test=np.array([1.0, 4.0, 6.0]))
np.testing.assert_equal(
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0])
np.testing.assert_equal(
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
[1.0, 3.0, 6.0])
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
[1.0, 4.0, 6.0])
def test_validator(self):
self.assertRaises(ValueError,
@ -60,8 +58,7 @@ class AttackInputDataTest(absltest.TestCase):
AttackInputData(labels_test=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(loss_test=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData().validate)
self.assertRaises(ValueError, AttackInputData().validate)
class RocCurveTest(absltest.TestCase):
@ -108,7 +105,8 @@ class SingleAttackResultTest(absltest.TestCase):
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc)
result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_auc(), 0.5)
@ -119,7 +117,8 @@ class SingleAttackResultTest(absltest.TestCase):
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc)
result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_attacker_advantage(), 0.0)
@ -134,6 +133,8 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a perfect classifier
self.perfect_classifier_result = SingleAttackResult(
slice_spec=None,
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]),
@ -141,6 +142,8 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a random classifier
self.random_classifier_result = SingleAttackResult(
slice_spec=None,
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]),