forked from 626_privacy/tensorflow_privacy
Introduce concept of "membership scores".
PiperOrigin-RevId: 348443155
This commit is contained in:
parent
b208d9deec
commit
2c810440d9
3 changed files with 49 additions and 1 deletions
|
@ -425,7 +425,31 @@ class SingleAttackResult:
|
|||
slice_spec: SingleSliceSpec
|
||||
|
||||
attack_type: AttackType
|
||||
roc_curve: RocCurve # for drawing and metrics calculation
|
||||
|
||||
# NOTE: roc_curve could theoretically be derived from membership scores.
|
||||
# Currently, we store it explicitly since not all attack types support
|
||||
# membership scores.
|
||||
# TODO(b/175870479): Consider deriving ROC curve from the membership scores.
|
||||
|
||||
# ROC curve representing the accuracy of the attacker
|
||||
roc_curve: RocCurve
|
||||
|
||||
# Membership score is some measure of confidence of this attacker that
|
||||
# a particular sample is a member of the training set.
|
||||
#
|
||||
# This is NOT necessarily probability. The nature of this score depends on
|
||||
# the type of attacker. Scores from different attacker types are not directly
|
||||
# comparable, but can be compared in relative terms (e.g. considering order
|
||||
# imposed by this measure).
|
||||
#
|
||||
|
||||
# Membership scores for the training set samples. For a perfect attacker,
|
||||
# all training samples will have higher scores than test samples.
|
||||
membership_scores_train: np.ndarray = None
|
||||
|
||||
# Membership scores for the test set samples. For a perfect attacker, all
|
||||
# test set samples will have lower scores than the training set samples.
|
||||
membership_scores_test: np.ndarray = None
|
||||
|
||||
def get_attacker_advantage(self):
|
||||
return self.roc_curve.get_attacker_advantage()
|
||||
|
|
|
@ -76,6 +76,12 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
# NOTE: In the current setup we can't obtain membership scores for all
|
||||
# samples, since some of them were used to train the attacker. This can be
|
||||
# fixed by training several attackers to ensure each sample was left out
|
||||
# in exactly one attacker (basically, this means performing cross-validation).
|
||||
# TODO(b/175870479): Implement membership scores for predicted attackers.
|
||||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=attack_type,
|
||||
|
@ -94,6 +100,8 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
membership_scores_train=-attack_input.get_loss_train(),
|
||||
membership_scores_test=-attack_input.get_loss_test(),
|
||||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
|
@ -109,6 +117,8 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
|||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
|
||||
membership_scores_train=-attack_input.get_entropy_train(),
|
||||
membership_scores_test=-attack_input.get_entropy_test(),
|
||||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
|
|
|
@ -62,6 +62,20 @@ class RunAttacksTest(absltest.TestCase):
|
|||
|
||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||
|
||||
def test_run_attack_threshold_sets_membership_scores(self):
|
||||
result = mia._run_attack(
|
||||
get_test_input(100, 50), AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertLen(result.membership_scores_train, 100)
|
||||
self.assertLen(result.membership_scores_test, 50)
|
||||
|
||||
def test_run_attack_threshold_entropy_sets_membership_scores(self):
|
||||
result = mia._run_attack(
|
||||
get_test_input(100, 50), AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||
|
||||
self.assertLen(result.membership_scores_train, 100)
|
||||
self.assertLen(result.membership_scores_test, 50)
|
||||
|
||||
def test_run_attack_threshold_calculates_correct_auc(self):
|
||||
result = mia._run_attack(
|
||||
AttackInputData(
|
||||
|
|
Loading…
Reference in a new issue