Introduce concept of "membership scores".

PiperOrigin-RevId: 348443155
This commit is contained in:
Yurii Sushko 2020-12-21 03:42:35 -08:00 committed by A. Unique TensorFlower
parent b208d9deec
commit 2c810440d9
3 changed files with 49 additions and 1 deletions

View file

@ -425,7 +425,31 @@ class SingleAttackResult:
slice_spec: SingleSliceSpec slice_spec: SingleSliceSpec
attack_type: AttackType attack_type: AttackType
roc_curve: RocCurve # for drawing and metrics calculation
# NOTE: roc_curve could theoretically be derived from membership scores.
# Currently, we store it explicitly since not all attack types support
# membership scores.
# TODO(b/175870479): Consider deriving ROC curve from the membership scores.
# ROC curve representing the accuracy of the attacker
roc_curve: RocCurve
# Membership score is some measure of confidence of this attacker that
# a particular sample is a member of the training set.
#
# This is NOT necessarily probability. The nature of this score depends on
# the type of attacker. Scores from different attacker types are not directly
# comparable, but can be compared in relative terms (e.g. considering order
# imposed by this measure).
#
# Membership scores for the training set samples. For a perfect attacker,
# all training samples will have higher scores than test samples.
membership_scores_train: np.ndarray = None
# Membership scores for the test set samples. For a perfect attacker, all
# test set samples will have lower scores than the training set samples.
membership_scores_test: np.ndarray = None
def get_attacker_advantage(self): def get_attacker_advantage(self):
return self.roc_curve.get_attacker_advantage() return self.roc_curve.get_attacker_advantage()

View file

@ -76,6 +76,12 @@ def _run_trained_attack(attack_input: AttackInputData,
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
# NOTE: In the current setup we can't obtain membership scores for all
# samples, since some of them were used to train the attacker. This can be
# fixed by training several attackers to ensure each sample was left out
# in exactly one attacker (basically, this means performing cross-validation).
# TODO(b/175870479): Implement membership scores for predicted attackers.
return SingleAttackResult( return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=_get_slice_spec(attack_input),
attack_type=attack_type, attack_type=attack_type,
@ -94,6 +100,8 @@ def _run_threshold_attack(attack_input: AttackInputData):
return SingleAttackResult( return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=_get_slice_spec(attack_input),
attack_type=AttackType.THRESHOLD_ATTACK, attack_type=AttackType.THRESHOLD_ATTACK,
membership_scores_train=-attack_input.get_loss_train(),
membership_scores_test=-attack_input.get_loss_test(),
roc_curve=roc_curve) roc_curve=roc_curve)
@ -109,6 +117,8 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
return SingleAttackResult( return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=_get_slice_spec(attack_input),
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK, attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
membership_scores_train=-attack_input.get_entropy_train(),
membership_scores_test=-attack_input.get_entropy_test(),
roc_curve=roc_curve) roc_curve=roc_curve)

View file

@ -62,6 +62,20 @@ class RunAttacksTest(absltest.TestCase):
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK) self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
def test_run_attack_threshold_sets_membership_scores(self):
result = mia._run_attack(
get_test_input(100, 50), AttackType.THRESHOLD_ATTACK)
self.assertLen(result.membership_scores_train, 100)
self.assertLen(result.membership_scores_test, 50)
def test_run_attack_threshold_entropy_sets_membership_scores(self):
result = mia._run_attack(
get_test_input(100, 50), AttackType.THRESHOLD_ENTROPY_ATTACK)
self.assertLen(result.membership_scores_train, 100)
self.assertLen(result.membership_scores_test, 50)
def test_run_attack_threshold_calculates_correct_auc(self): def test_run_attack_threshold_calculates_correct_auc(self):
result = mia._run_attack( result = mia._run_attack(
AttackInputData( AttackInputData(