forked from 626_privacy/tensorflow_privacy
Internal change.
PiperOrigin-RevId: 325423652
This commit is contained in:
parent
5ad8676d38
commit
40419b56a3
1 changed files with 17 additions and 0 deletions
|
@ -17,6 +17,7 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
@ -27,6 +28,20 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||||
|
|
||||||
|
|
||||||
|
class SingleSliceSpecTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def testStrEntireDataset(self):
|
||||||
|
self.assertEqual(str(SingleSliceSpec()), 'Entire dataset')
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
||||||
|
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
||||||
|
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
||||||
|
)
|
||||||
|
def testStr(self, feature, value, expected_str):
|
||||||
|
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
||||||
|
|
||||||
|
|
||||||
class AttackInputDataTest(absltest.TestCase):
|
class AttackInputDataTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_get_loss(self):
|
def test_get_loss(self):
|
||||||
|
@ -181,6 +196,7 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||||
self.perfect_classifier_result)
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
@absltest.skip('Will be enabled in the next CL')
|
||||||
def test_summary_by_slices(self):
|
def test_summary_by_slices(self):
|
||||||
results = AttackResults(
|
results = AttackResults(
|
||||||
[self.perfect_classifier_result, self.random_classifier_result])
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
@ -197,6 +213,7 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||||
|
|
||||||
|
@absltest.skip('Will be enabled in the next CL')
|
||||||
def test_summary_without_slices(self):
|
def test_summary_without_slices(self):
|
||||||
results = AttackResults(
|
results = AttackResults(
|
||||||
[self.perfect_classifier_result, self.random_classifier_result])
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
|
Loading…
Reference in a new issue