Internal change.
PiperOrigin-RevId: 325249654
This commit is contained in:
parent
08f960a1af
commit
efca03b593
1 changed files with 14 additions and 1 deletions
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures."""
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
@ -184,7 +186,7 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
[self.perfect_classifier_result, self.random_classifier_result])
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
results.summary(by_slices=True), 'Highest AUC on slice '
|
results.summary(by_slices=True), 'Highest AUC on slice '
|
||||||
'SingleSliceSpec' +
|
'SingleSliceSpec' +
|
||||||
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
|
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
|
||||||
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
|
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
|
||||||
'Highest advantage on ' +
|
'Highest advantage on ' +
|
||||||
|
@ -205,6 +207,17 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||||
|
|
||||||
|
def test_save_load(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
filepath = os.path.join(tmpdirname, 'results.pickle')
|
||||||
|
results.save(filepath)
|
||||||
|
loaded_results = AttackResults.load(filepath)
|
||||||
|
|
||||||
|
self.assertEqual(repr(results), repr(loaded_results))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue