diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 38ba60a..3409328 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -18,6 +18,7 @@ from absl.testing import absltest import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult @@ -29,23 +30,20 @@ class AttackInputDataTest(absltest.TestCase): logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]), logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]), labels_train=np.array([1, 0]), - labels_test=np.array([0, 1]) - ) + labels_test=np.array([0, 1])) - np.testing.assert_equal( - attack_input.get_loss_train().tolist(), [0.5, 0.2]) - np.testing.assert_equal( - attack_input.get_loss_test().tolist(), [0.2, 0.5]) + np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2]) + np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5]) def test_get_loss_explicitly_provided(self): attack_input = AttackInputData( loss_train=np.array([1.0, 3.0, 6.0]), loss_test=np.array([1.0, 4.0, 6.0])) - np.testing.assert_equal( - attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0]) - np.testing.assert_equal( - attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + np.testing.assert_equal(attack_input.get_loss_train().tolist(), + [1.0, 3.0, 6.0]) + np.testing.assert_equal(attack_input.get_loss_test().tolist(), + [1.0, 4.0, 6.0]) def test_validator(self): self.assertRaises(ValueError, @@ -60,8 +58,7 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData(labels_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(loss_test=np.array([])).validate) - self.assertRaises(ValueError, - AttackInputData().validate) + self.assertRaises(ValueError, AttackInputData().validate) class RocCurveTest(absltest.TestCase): @@ -108,7 +105,8 @@ class SingleAttackResultTest(absltest.TestCase): fpr=np.array([0.0, 0.5, 1.0]), thresholds=np.array([0, 1, 2])) - result = SingleAttackResult(roc_curve=roc) + result = SingleAttackResult( + roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) self.assertEqual(result.get_auc(), 0.5) @@ -119,7 +117,8 @@ class SingleAttackResultTest(absltest.TestCase): fpr=np.array([0.0, 0.5, 1.0]), thresholds=np.array([0, 1, 2])) - result = SingleAttackResult(roc_curve=roc) + result = SingleAttackResult( + roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) self.assertEqual(result.get_attacker_advantage(), 0.0) @@ -134,6 +133,8 @@ class AttackResultsTest(absltest.TestCase): # ROC curve of a perfect classifier self.perfect_classifier_result = SingleAttackResult( + slice_spec=None, + attack_type=AttackType.THRESHOLD_ATTACK, roc_curve=RocCurve( tpr=np.array([0.0, 1.0, 1.0]), fpr=np.array([1.0, 1.0, 0.0]), @@ -141,6 +142,8 @@ class AttackResultsTest(absltest.TestCase): # ROC curve of a random classifier self.random_classifier_result = SingleAttackResult( + slice_spec=None, + attack_type=AttackType.THRESHOLD_ATTACK, roc_curve=RocCurve( tpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),