Internal change.

PiperOrigin-RevId: 325210305
This commit is contained in:
A. Unique TensorFlower 2020-08-06 05:25:32 -07:00
parent e91c820b2a
commit 08f960a1af

View file

@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
class AttackInputDataTest(absltest.TestCase): class AttackInputDataTest(absltest.TestCase):
@ -106,7 +108,9 @@ class SingleAttackResultTest(absltest.TestCase):
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]))
result = SingleAttackResult( result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) roc_curve=roc,
slice_spec=SingleSliceSpec(None),
attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_auc(), 0.5) self.assertEqual(result.get_auc(), 0.5)
@ -118,7 +122,9 @@ class SingleAttackResultTest(absltest.TestCase):
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]))
result = SingleAttackResult( result = SingleAttackResult(
roc_curve=roc, slice_spec=None, attack_type=AttackType.THRESHOLD_ATTACK) roc_curve=roc,
slice_spec=SingleSliceSpec(None),
attack_type=AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.get_attacker_advantage(), 0.0) self.assertEqual(result.get_attacker_advantage(), 0.0)
@ -133,7 +139,7 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a perfect classifier # ROC curve of a perfect classifier
self.perfect_classifier_result = SingleAttackResult( self.perfect_classifier_result = SingleAttackResult(
slice_spec=None, slice_spec=SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True),
attack_type=AttackType.THRESHOLD_ATTACK, attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
@ -142,7 +148,7 @@ class AttackResultsTest(absltest.TestCase):
# ROC curve of a random classifier # ROC curve of a random classifier
self.random_classifier_result = SingleAttackResult( self.random_classifier_result = SingleAttackResult(
slice_spec=None, slice_spec=SingleSliceSpec(None),
attack_type=AttackType.THRESHOLD_ATTACK, attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
@ -173,6 +179,32 @@ class AttackResultsTest(absltest.TestCase):
self.assertEqual(results.get_result_with_max_attacker_advantage(), self.assertEqual(results.get_result_with_max_attacker_advantage(),
self.perfect_classifier_result) self.perfect_classifier_result)
def test_summary_by_slices(self):
results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual(
results.summary(by_slices=True), 'Highest AUC on slice '
'SingleSliceSpec' +
'(SlicingFeature.CORRECTLY_CLASSIFIED=True) achieved by ' +
'AttackType.THRESHOLD_ATTACK with an AUC of 1.0\n' +
'Highest advantage on ' +
'slice SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED=True) ' +
'achieved by AttackType.THRESHOLD_ATTACK with an advantage of 1.0\n' +
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
def test_summary_without_slices(self):
results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual(
results.summary(by_slices=False),
'Highest AUC on slice SingleSliceSpec(Entire dataset) achieved ' +
'by AttackType.THRESHOLD_ATTACK with an AUC of 0.5\n' +
'Highest advantage on slice SingleSliceSpec(Entire dataset) achieved ' +
'by AttackType.THRESHOLD_ATTACK with an advantage of 0.0')
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()