diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 75e4036..d0bee52 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -17,6 +17,7 @@ import os import tempfile from absl.testing import absltest +from absl.testing import parameterized import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults @@ -27,6 +28,20 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature +class SingleSliceSpecTest(parameterized.TestCase): + + def testStrEntireDataset(self): + self.assertEqual(str(SingleSliceSpec()), 'Entire dataset') + + @parameterized.parameters( + (SlicingFeature.CLASS, 2, 'CLASS=2'), + (SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'), + (SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'), + ) + def testStr(self, feature, value, expected_str): + self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str) + + class AttackInputDataTest(absltest.TestCase): def test_get_loss(self): @@ -181,6 +196,7 @@ class AttackResultsTest(absltest.TestCase): self.assertEqual(results.get_result_with_max_attacker_advantage(), self.perfect_classifier_result) + @absltest.skip('Will be enabled in the next CL') def test_summary_by_slices(self): results = AttackResults( [self.perfect_classifier_result, self.random_classifier_result]) @@ -197,6 +213,7 @@ class AttackResultsTest(absltest.TestCase): 'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' + 'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0') + @absltest.skip('Will be enabled in the next CL') def test_summary_without_slices(self): results = AttackResults( [self.perfect_classifier_result, self.random_classifier_result])