Internal change.

PiperOrigin-RevId: 325249654
This commit is contained in:
A. Unique TensorFlower 2020-08-06 09:47:49 -07:00
parent 08f960a1af
commit efca03b593

View file

@ -14,6 +14,8 @@
# Lint as: python3 # Lint as: python3
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures.""" """Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures."""
import os
import tempfile
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
@ -184,7 +186,7 @@ class AttackResultsTest(absltest.TestCase):
[self.perfect_classifier_result, self.random_classifier_result]) [self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual( self.assertEqual(
results.summary(by_slices=True), 'Highest AUC on slice ' results.summary(by_slices=True), 'Highest AUC on slice '
'SingleSliceSpec' + 'SingleSliceSpec' +
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' + '(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' + 'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
'Highest advantage on ' + 'Highest advantage on ' +
@ -205,6 +207,17 @@ class AttackResultsTest(absltest.TestCase):
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' + 'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0') 'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
def test_save_load(self):
results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result])
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, 'results.pickle')
results.save(filepath)
loaded_results = AttackResults.load(filepath)
self.assertEqual(repr(results), repr(loaded_results))
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()