forked from 626_privacy/tensorflow_privacy
Introduce concept of "membership scores".
PiperOrigin-RevId: 348443155
This commit is contained in:
parent
b208d9deec
commit
2c810440d9
3 changed files with 49 additions and 1 deletions
|
@ -425,7 +425,31 @@ class SingleAttackResult:
|
||||||
slice_spec: SingleSliceSpec
|
slice_spec: SingleSliceSpec
|
||||||
|
|
||||||
attack_type: AttackType
|
attack_type: AttackType
|
||||||
roc_curve: RocCurve # for drawing and metrics calculation
|
|
||||||
|
# NOTE: roc_curve could theoretically be derived from membership scores.
|
||||||
|
# Currently, we store it explicitly since not all attack types support
|
||||||
|
# membership scores.
|
||||||
|
# TODO(b/175870479): Consider deriving ROC curve from the membership scores.
|
||||||
|
|
||||||
|
# ROC curve representing the accuracy of the attacker
|
||||||
|
roc_curve: RocCurve
|
||||||
|
|
||||||
|
# Membership score is some measure of confidence of this attacker that
|
||||||
|
# a particular sample is a member of the training set.
|
||||||
|
#
|
||||||
|
# This is NOT necessarily probability. The nature of this score depends on
|
||||||
|
# the type of attacker. Scores from different attacker types are not directly
|
||||||
|
# comparable, but can be compared in relative terms (e.g. considering order
|
||||||
|
# imposed by this measure).
|
||||||
|
#
|
||||||
|
|
||||||
|
# Membership scores for the training set samples. For a perfect attacker,
|
||||||
|
# all training samples will have higher scores than test samples.
|
||||||
|
membership_scores_train: np.ndarray = None
|
||||||
|
|
||||||
|
# Membership scores for the test set samples. For a perfect attacker, all
|
||||||
|
# test set samples will have lower scores than the training set samples.
|
||||||
|
membership_scores_test: np.ndarray = None
|
||||||
|
|
||||||
def get_attacker_advantage(self):
|
def get_attacker_advantage(self):
|
||||||
return self.roc_curve.get_attacker_advantage()
|
return self.roc_curve.get_attacker_advantage()
|
||||||
|
|
|
@ -76,6 +76,12 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
|
|
||||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||||
|
|
||||||
|
# NOTE: In the current setup we can't obtain membership scores for all
|
||||||
|
# samples, since some of them were used to train the attacker. This can be
|
||||||
|
# fixed by training several attackers to ensure each sample was left out
|
||||||
|
# in exactly one attacker (basically, this means performing cross-validation).
|
||||||
|
# TODO(b/175870479): Implement membership scores for predicted attackers.
|
||||||
|
|
||||||
return SingleAttackResult(
|
return SingleAttackResult(
|
||||||
slice_spec=_get_slice_spec(attack_input),
|
slice_spec=_get_slice_spec(attack_input),
|
||||||
attack_type=attack_type,
|
attack_type=attack_type,
|
||||||
|
@ -94,6 +100,8 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
||||||
return SingleAttackResult(
|
return SingleAttackResult(
|
||||||
slice_spec=_get_slice_spec(attack_input),
|
slice_spec=_get_slice_spec(attack_input),
|
||||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
|
membership_scores_train=-attack_input.get_loss_train(),
|
||||||
|
membership_scores_test=-attack_input.get_loss_test(),
|
||||||
roc_curve=roc_curve)
|
roc_curve=roc_curve)
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,6 +117,8 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
||||||
return SingleAttackResult(
|
return SingleAttackResult(
|
||||||
slice_spec=_get_slice_spec(attack_input),
|
slice_spec=_get_slice_spec(attack_input),
|
||||||
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
|
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
|
||||||
|
membership_scores_train=-attack_input.get_entropy_train(),
|
||||||
|
membership_scores_test=-attack_input.get_entropy_test(),
|
||||||
roc_curve=roc_curve)
|
roc_curve=roc_curve)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,20 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
|
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_sets_membership_scores(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
get_test_input(100, 50), AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
|
self.assertLen(result.membership_scores_train, 100)
|
||||||
|
self.assertLen(result.membership_scores_test, 50)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_sets_membership_scores(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
get_test_input(100, 50), AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
|
self.assertLen(result.membership_scores_train, 100)
|
||||||
|
self.assertLen(result.membership_scores_test, 50)
|
||||||
|
|
||||||
def test_run_attack_threshold_calculates_correct_auc(self):
|
def test_run_attack_threshold_calculates_correct_auc(self):
|
||||||
result = mia._run_attack(
|
result = mia._run_attack(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
|
|
Loading…
Reference in a new issue