Internal change.
PiperOrigin-RevId: 325423652
This commit is contained in:
parent
5ad8676d38
commit
40419b56a3
1 changed files with 17 additions and 0 deletions
|
@ -17,6 +17,7 @@
|
|||
import os
|
||||
import tempfile
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
|
@ -27,6 +28,20 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
|
||||
|
||||
class SingleSliceSpecTest(parameterized.TestCase):
|
||||
|
||||
def testStrEntireDataset(self):
|
||||
self.assertEqual(str(SingleSliceSpec()), 'Entire dataset')
|
||||
|
||||
@parameterized.parameters(
|
||||
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
||||
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
||||
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
||||
)
|
||||
def testStr(self, feature, value, expected_str):
|
||||
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
||||
|
||||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
|
||||
def test_get_loss(self):
|
||||
|
@ -181,6 +196,7 @@ class AttackResultsTest(absltest.TestCase):
|
|||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
@absltest.skip('Will be enabled in the next CL')
|
||||
def test_summary_by_slices(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
|
@ -197,6 +213,7 @@ class AttackResultsTest(absltest.TestCase):
|
|||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||
|
||||
@absltest.skip('Will be enabled in the next CL')
|
||||
def test_summary_without_slices(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
|
|
Loading…
Reference in a new issue