forked from 626_privacy/tensorflow_privacy
Internal change.
PiperOrigin-RevId: 325210305
This commit is contained in:
parent
e91c820b2a
commit
08f960a1af
1 changed files with 36 additions and 4 deletions
|
@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
|
||||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
|
@ -106,7 +108,9 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(
|
||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
roc_curve=roc,
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.get_auc(), 0.5)
|
||||
|
||||
|
@ -118,7 +122,9 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
thresholds=np.array([0, 1, 2]))
|
||||
|
||||
result = SingleAttackResult(
|
||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
roc_curve=roc,
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
@ -133,7 +139,7 @@ class AttackResultsTest(absltest.TestCase):
|
|||
|
||||
# ROC curve of a perfect classifier
|
||||
self.perfect_classifier_result = SingleAttackResult(
|
||||
slice_spec=None,
|
||||
slice_spec=SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
|
@ -142,7 +148,7 @@ class AttackResultsTest(absltest.TestCase):
|
|||
|
||||
# ROC curve of a random classifier
|
||||
self.random_classifier_result = SingleAttackResult(
|
||||
slice_spec=None,
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
|
@ -173,6 +179,32 @@ class AttackResultsTest(absltest.TestCase):
|
|||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||
self.perfect_classifier_result)
|
||||
|
||||
def test_summary_by_slices(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
self.assertEqual(
|
||||
results.summary(by_slices=True), 'Highest AUC on slice '
|
||||
'SingleSliceSpec' +
|
||||
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
|
||||
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
|
||||
'Highest advantage on ' +
|
||||
'slice SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED=True) ' +
|
||||
'achieved by AttackType.THRESHOLD_ATTACK with an advantage of 1.0\n' +
|
||||
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
|
||||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||
|
||||
def test_summary_without_slices(self):
|
||||
results = AttackResults(
|
||||
[self.perfect_classifier_result, self.random_classifier_result])
|
||||
self.assertEqual(
|
||||
results.summary(by_slices=False),
|
||||
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
|
||||
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue