add test cases for privacy risk score

This commit is contained in:
Liwei Song 2020-12-02 21:00:44 -05:00
parent d80df35e85
commit bf65f55382
2 changed files with 24 additions and 0 deletions

View file

@ -29,6 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleRiskScoreResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@ -287,7 +288,21 @@ class SingleAttackResultTest(absltest.TestCase):
self.assertEqual(result.get_attacker_advantage(), 0.0) self.assertEqual(result.get_attacker_advantage(), 0.0)
class SingleRiskScoreResultTest(absltest.TestCase):
# Only a basic test to check the attack by setting a threshold on risk score.
def test_attack_with_varied_thresholds(self):
result = SingleRiskScoreResult(
slice_spec=SingleSliceSpec(None),
train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]),
test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3]))
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625])
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1])
class AttackResultsCollectionTest(absltest.TestCase): class AttackResultsCollectionTest(absltest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View file

@ -196,6 +196,15 @@ class RunAttacksTest(absltest.TestCase):
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
def test_run_compute_privacy_risk_score_correct_score(self):
result = mia._compute_privacy_risk_score(
AttackInputData(
loss_train=np.array([1, 1, 1, 10, 100]),
loss_test=np.array([10, 100, 100, 1000, 10000])))
np.testing.assert_almost_equal(result.train_risk_scores, [1,1,1,0.5,0.33], decimal=2)
np.testing.assert_almost_equal(result.test_risk_scores, [0.5,0.33,0.33,0,0], decimal=2)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()