diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 1140611..5849f73 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -425,7 +425,31 @@ class SingleAttackResult: slice_spec: SingleSliceSpec attack_type: AttackType - roc_curve: RocCurve # for drawing and metrics calculation + + # NOTE: roc_curve could theoretically be derived from membership scores. + # Currently, we store it explicitly since not all attack types support + # membership scores. + # TODO(b/175870479): Consider deriving ROC curve from the membership scores. + + # ROC curve representing the accuracy of the attacker + roc_curve: RocCurve + + # Membership score is some measure of confidence of this attacker that + # a particular sample is a member of the training set. + # + # This is NOT necessarily probability. The nature of this score depends on + # the type of attacker. Scores from different attacker types are not directly + # comparable, but can be compared in relative terms (e.g. considering order + # imposed by this measure). + # + + # Membership scores for the training set samples. For a perfect attacker, + # all training samples will have higher scores than test samples. + membership_scores_train: np.ndarray = None + + # Membership scores for the test set samples. For a perfect attacker, all + # test set samples will have lower scores than the training set samples. + membership_scores_test: np.ndarray = None def get_attacker_advantage(self): return self.roc_curve.get_attacker_advantage() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 369482b..fe9a588 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -76,6 +76,12 @@ def _run_trained_attack(attack_input: AttackInputData, roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) + # NOTE: In the current setup we can't obtain membership scores for all + # samples, since some of them were used to train the attacker. This can be + # fixed by training several attackers to ensure each sample was left out + # in exactly one attacker (basically, this means performing cross-validation). + # TODO(b/175870479): Implement membership scores for predicted attackers. + return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), attack_type=attack_type, @@ -94,6 +100,8 @@ def _run_threshold_attack(attack_input: AttackInputData): return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), attack_type=AttackType.THRESHOLD_ATTACK, + membership_scores_train=-attack_input.get_loss_train(), + membership_scores_test=-attack_input.get_loss_test(), roc_curve=roc_curve) @@ -109,6 +117,8 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData): return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK, + membership_scores_train=-attack_input.get_entropy_train(), + membership_scores_test=-attack_input.get_entropy_test(), roc_curve=roc_curve) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index 06f7672..86dd918 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -62,6 +62,20 @@ class RunAttacksTest(absltest.TestCase): self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK) + def test_run_attack_threshold_sets_membership_scores(self): + result = mia._run_attack( + get_test_input(100, 50), AttackType.THRESHOLD_ATTACK) + + self.assertLen(result.membership_scores_train, 100) + self.assertLen(result.membership_scores_test, 50) + + def test_run_attack_threshold_entropy_sets_membership_scores(self): + result = mia._run_attack( + get_test_input(100, 50), AttackType.THRESHOLD_ENTROPY_ATTACK) + + self.assertLen(result.membership_scores_train, 100) + self.assertLen(result.membership_scores_test, 50) + def test_run_attack_threshold_calculates_correct_auc(self): result = mia._run_attack( AttackInputData(