forked from 626_privacy/tensorflow_privacy
add test cases for privacy risk score
This commit is contained in:
parent
d80df35e85
commit
bf65f55382
2 changed files with 24 additions and 0 deletions
|
@ -29,6 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleRiskScoreResult
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||||
|
|
||||||
|
@ -288,6 +289,20 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleRiskScoreResultTest(absltest.TestCase):
|
||||||
|
|
||||||
|
# Only a basic test to check the attack by setting a threshold on risk score.
|
||||||
|
def test_attack_with_varied_thresholds(self):
|
||||||
|
|
||||||
|
result = SingleRiskScoreResult(
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]),
|
||||||
|
test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3]))
|
||||||
|
|
||||||
|
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625])
|
||||||
|
self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1])
|
||||||
|
|
||||||
|
|
||||||
class AttackResultsCollectionTest(absltest.TestCase):
|
class AttackResultsCollectionTest(absltest.TestCase):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
|
@ -196,6 +196,15 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
np.testing.assert_almost_equal(
|
np.testing.assert_almost_equal(
|
||||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||||
|
|
||||||
|
def test_run_compute_privacy_risk_score_correct_score(self):
|
||||||
|
result = mia._compute_privacy_risk_score(
|
||||||
|
AttackInputData(
|
||||||
|
loss_train=np.array([1, 1, 1, 10, 100]),
|
||||||
|
loss_test=np.array([10, 100, 100, 1000, 10000])))
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(result.train_risk_scores, [1,1,1,0.5,0.33], decimal=2)
|
||||||
|
np.testing.assert_almost_equal(result.test_risk_scores, [0.5,0.33,0.33,0,0], decimal=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue