Internal change.
PiperOrigin-RevId: 325249654
This commit is contained in:
parent
08f960a1af
commit
efca03b593
1 changed files with 14 additions and 1 deletions
|
@ -14,6 +14,8 @@
|
|||
|
||||
# Lint as: python3
|
||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures."""
|
||||
import os
|
||||
import tempfile
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
|
@ -184,7 +186,7 @@ class AttackResultsTest(absltest.TestCase):
|
|||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
self.assertEqual(
|
||||
results.summary(by_slices=True), 'Highest AUC on slice '
|
||||
'SingleSliceSpec' +
|
||||
'SingleSliceSpec' +
|
||||
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
|
||||
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
|
||||
'Highest advantage on ' +
|
||||
|
@ -205,6 +207,17 @@ class AttackResultsTest(absltest.TestCase):
|
|||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||
|
||||
def test_save_load(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
filepath = os.path.join(tmpdirname, 'results.pickle')
|
||||
results.save(filepath)
|
||||
loaded_results = AttackResults.load(filepath)
|
||||
|
||||
self.assertEqual(repr(results), repr(loaded_results))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue