diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 0f9ffbe..75e4036 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -14,6 +14,8 @@ # Lint as: python3 """Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures.""" +import os +import tempfile from absl.testing import absltest import numpy as np from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData @@ -184,7 +186,7 @@ class AttackResultsTest(absltest.TestCase): [self.perfect_classifier_result, self.random_classifier_result]) self.assertEqual( results.summary(by_slices=True), 'Highest AUC on slice ' - 'SingleSliceSpec' + + 'SingleSliceSpec' + '(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' + 'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' + 'Highest advantage on ' + @@ -205,6 +207,17 @@ class AttackResultsTest(absltest.TestCase): 'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' + 'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0') + def test_save_load(self): + results = AttackResults( + [self.perfect_classifier_result, self.random_classifier_result]) + + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, 'results.pickle') + results.save(filepath) + loaded_results = AttackResults.load(filepath) + + self.assertEqual(repr(results), repr(loaded_results)) + if __name__ == '__main__': absltest.main()