Internal change.
PiperOrigin-RevId: 325210305
This commit is contained in:
parent
e91c820b2a
commit
08f960a1af
1 changed files with 36 additions and 4 deletions
|
@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||||
|
|
||||||
|
|
||||||
class AttackInputDataTest(absltest.TestCase):
|
class AttackInputDataTest(absltest.TestCase):
|
||||||
|
@ -106,7 +108,9 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
thresholds=np.array([0, 1, 2]))
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
result = SingleAttackResult(
|
result = SingleAttackResult(
|
||||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
roc_curve=roc,
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
self.assertEqual(result.get_auc(), 0.5)
|
self.assertEqual(result.get_auc(), 0.5)
|
||||||
|
|
||||||
|
@ -118,7 +122,9 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
thresholds=np.array([0, 1, 2]))
|
thresholds=np.array([0, 1, 2]))
|
||||||
|
|
||||||
result = SingleAttackResult(
|
result = SingleAttackResult(
|
||||||
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK)
|
roc_curve=roc,
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
@ -133,7 +139,7 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
# ROC curve of a perfect classifier
|
# ROC curve of a perfect classifier
|
||||||
self.perfect_classifier_result = SingleAttackResult(
|
self.perfect_classifier_result = SingleAttackResult(
|
||||||
slice_spec=None,
|
slice_spec=SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True),
|
||||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
roc_curve=RocCurve(
|
roc_curve=RocCurve(
|
||||||
tpr=np.array([0.0, 1.0, 1.0]),
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
|
@ -142,7 +148,7 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
# ROC curve of a random classifier
|
# ROC curve of a random classifier
|
||||||
self.random_classifier_result = SingleAttackResult(
|
self.random_classifier_result = SingleAttackResult(
|
||||||
slice_spec=None,
|
slice_spec=SingleSliceSpec(None),
|
||||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
roc_curve=RocCurve(
|
roc_curve=RocCurve(
|
||||||
tpr=np.array([0.0, 0.5, 1.0]),
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
@ -173,6 +179,32 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
self.assertEqual(results.get_result_with_max_attacker_advantage(),
|
||||||
self.perfect_classifier_result)
|
self.perfect_classifier_result)
|
||||||
|
|
||||||
|
def test_summary_by_slices(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
self.assertEqual(
|
||||||
|
results.summary(by_slices=True), 'Highest AUC on slice '
|
||||||
|
'SingleSliceSpec' +
|
||||||
|
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
|
||||||
|
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
|
||||||
|
'Highest advantage on ' +
|
||||||
|
'slice SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED=True) ' +
|
||||||
|
'achieved by AttackType.THRESHOLD_ATTACK with an advantage of 1.0\n' +
|
||||||
|
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
|
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
|
||||||
|
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
|
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||||
|
|
||||||
|
def test_summary_without_slices(self):
|
||||||
|
results = AttackResults(
|
||||||
|
[self.perfect_classifier_result, self.random_classifier_result])
|
||||||
|
self.assertEqual(
|
||||||
|
results.summary(by_slices=False),
|
||||||
|
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
|
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
|
||||||
|
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
|
||||||
|
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue