forked from 626_privacy/tensorflow_privacy
add test cases for privacy risk score
This commit is contained in:
parent
d80df35e85
commit
bf65f55382
2 changed files with 24 additions and 0 deletions
|
@ -29,6 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleRiskScoreResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
|
||||
|
@ -288,6 +289,20 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
||||
class SingleRiskScoreResultTest(absltest.TestCase):
|
||||
|
||||
# Only a basic test to check the attack by setting a threshold on risk score.
|
||||
def test_attack_with_varied_thresholds(self):
|
||||
|
||||
result = SingleRiskScoreResult(
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]),
|
||||
test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3]))
|
||||
|
||||
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625])
|
||||
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1])
|
||||
|
||||
|
||||
class AttackResultsCollectionTest(absltest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
@ -196,6 +196,15 @@ class RunAttacksTest(absltest.TestCase):
|
|||
np.testing.assert_almost_equal(
|
||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||
|
||||
def test_run_compute_privacy_risk_score_correct_score(self):
|
||||
result = mia._compute_privacy_risk_score(
|
||||
AttackInputData(
|
||||
loss_train=np.array([1, 1, 1, 10, 100]),
|
||||
loss_test=np.array([10, 100, 100, 1000, 10000])))
|
||||
|
||||
np.testing.assert_almost_equal(result.train_risk_scores, [1,1,1,0.5,0.33], decimal=2)
|
||||
np.testing.assert_almost_equal(result.test_risk_scores, [0.5,0.33,0.33,0,0], decimal=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue