diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 3409328..0f9ffbe 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature class AttackInputDataTest(absltest.TestCase): @@ -106,7 +108,9 @@ class SingleAttackResultTest(absltest.TestCase): thresholds=np.array([0, 1, 2])) result = SingleAttackResult( - roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) + roc_curve=roc, + slice_spec=SingleSliceSpec(None), + attack_type=AttackType.THRESHOLD_ATTACK) self.assertEqual(result.get_auc(), 0.5) @@ -118,7 +122,9 @@ class SingleAttackResultTest(absltest.TestCase): thresholds=np.array([0, 1, 2])) result = SingleAttackResult( - roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) + roc_curve=roc, + slice_spec=SingleSliceSpec(None), + attack_type=AttackType.THRESHOLD_ATTACK) self.assertEqual(result.get_attacker_advantage(), 0.0) @@ -133,7 +139,7 @@ class AttackResultsTest(absltest.TestCase): # ROC curve of a perfect classifier self.perfect_classifier_result = SingleAttackResult( - slice_spec=None, + slice_spec=SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True), attack_type=AttackType.THRESHOLD_ATTACK, roc_curve=RocCurve( tpr=np.array([0.0, 1.0, 1.0]), @@ -142,7 +148,7 @@ class AttackResultsTest(absltest.TestCase): # ROC curve of a random classifier self.random_classifier_result = SingleAttackResult( - slice_spec=None, + slice_spec=SingleSliceSpec(None), attack_type=AttackType.THRESHOLD_ATTACK, roc_curve=RocCurve( tpr=np.array([0.0, 0.5, 1.0]), @@ -173,6 +179,32 @@ class AttackResultsTest(absltest.TestCase): self.assertEqual(results.get_result_with_max_attacker_advantage(), self.perfect_classifier_result) + def test_summary_by_slices(self): + results = AttackResults( + [self.perfect_classifier_result, self.random_classifier_result]) + self.assertEqual( + results.summary(by_slices=True), 'Highest AUC on slice ' + 'SingleSliceSpec' + + '(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' + + 'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' + + 'Highest advantage on ' + + 'slice SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED=True) ' + + 'achieved by AttackType.THRESHOLD_ATTACK with an advantage of 1.0\n' + + 'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' + + 'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' + + 'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' + + 'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0') + + def test_summary_without_slices(self): + results = AttackResults( + [self.perfect_classifier_result, self.random_classifier_result]) + self.assertEqual( + results.summary(by_slices=False), + 'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' + + 'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' + + 'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' + + 'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0') + if __name__ == '__main__': absltest.main()