Internal change.

PiperOrigin-RevId: 324591262
This commit is contained in:
A. Unique TensorFlower 2020-08-03 06:20:40 -07:00 committed by Steve Chien
parent 0a1cbb5b7b
commit 29651216cd

View file

@ -18,6 +18,7 @@ from absl.testing import absltest
import numpy as np import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
@ -29,23 +30,20 @@ class AttackInputDataTest(absltest.TestCase):
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]), logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]), logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
labels_train=np.array([1, 0]), labels_train=np.array([1, 0]),
labels_test=np.array([0, 1]) labels_test=np.array([0, 1]))
)
np.testing.assert_equal( np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
attack_input.get_loss_train().tolist(), [0.5, 0.2]) np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
np.testing.assert_equal(
attack_input.get_loss_test().tolist(), [0.2, 0.5])
def test_get_loss_explicitly_provided(self): def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData( attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]), loss_train=np.array([1.0, 3.0, 6.0]),
loss_test=np.array([1.0, 4.0, 6.0])) loss_test=np.array([1.0, 4.0, 6.0]))
np.testing.assert_equal( np.testing.assert_equal(attack_input.get_loss_train().tolist(),
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0]) [1.0, 3.0, 6.0])
np.testing.assert_equal( np.testing.assert_equal(attack_input.get_loss_test().tolist(),
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) [1.0, 4.0, 6.0])
def test_validator(self): def test_validator(self):
self.assertRaises(ValueError, self.assertRaises(ValueError,
@ -60,8 +58,7 @@ class AttackInputDataTest(absltest.TestCase):
AttackInputData(labels_test=np.array([])).validate) AttackInputData(labels_test=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(loss_test=np.array([])).validate) AttackInputData(loss_test=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError, AttackInputData().validate)
AttackInputData().validate)
class RocCurveTest(absltest.TestCase): class RocCurveTest(absltest.TestCase):
@ -108,7 +105,8 @@ class SingleAttackResultTest(absltest.TestCase):
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc) result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_auc(), 0.5) self.assertEqual(result.get_auc(), 0.5)
@ -119,7 +117,8 @@ class SingleAttackResultTest(absltest.TestCase):
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]))
result = SingleAttackResult(roc_curve=roc) result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_attacker_advantage(), 0.0) self.assertEqual(result.get_attacker_advantage(), 0.0)
@ -134,6 +133,8 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a perfect classifier # ROC curve of a perfect classifier
self.perfect_classifier_result = SingleAttackResult( self.perfect_classifier_result = SingleAttackResult(
slice_spec=None,
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
@ -141,6 +142,8 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a random classifier # ROC curve of a random classifier
self.random_classifier_result = SingleAttackResult( self.random_classifier_result = SingleAttackResult(
slice_spec=None,
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),