diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index eb1d8db..76df7fc 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -29,6 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleRiskScoreResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature @@ -287,7 +288,21 @@ class SingleAttackResultTest(absltest.TestCase): self.assertEqual(result.get_attacker_advantage(), 0.0) + +class SingleRiskScoreResultTest(absltest.TestCase): + # Only a basic test to check the attack by setting a threshold on risk score. + def test_attack_with_varied_thresholds(self): + + result = SingleRiskScoreResult( + slice_spec=SingleSliceSpec(None), + train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]), + test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3])) + + self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625]) + self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1]) + + class AttackResultsCollectionTest(absltest.TestCase): def __init__(self, *args, **kwargs): diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index 4c80f49..fd4db2b 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -196,6 +196,15 @@ class RunAttacksTest(absltest.TestCase): np.testing.assert_almost_equal( seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) + def test_run_compute_privacy_risk_score_correct_score(self): + result = mia._compute_privacy_risk_score( + AttackInputData( + loss_train=np.array([1, 1, 1, 10, 100]), + loss_test=np.array([10, 100, 100, 1000, 10000]))) + + np.testing.assert_almost_equal(result.train_risk_scores, [1,1,1,0.5,0.33], decimal=2) + np.testing.assert_almost_equal(result.test_risk_scores, [0.5,0.33,0.33,0,0], decimal=2) + if __name__ == '__main__': absltest.main()