Sets training set as positive class for sklearn.metrics.roc_curve.

sklearn.metrics.roc_curve uses classification rules of the form "score >= threshold ==> predict positive".
When calling roc_curve, we used to label the test data as the positive class. With that labeling, TPR = % of test examples classified as test, FPR = % of training examples classified as test, and the classification rule was "loss >= threshold ==> predict test".
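
For concreteness, a minimal sketch of the old labeling (variable names are illustrative, not quotes of the library code):

import numpy as np
from sklearn import metrics

loss_train = np.array([0.1, 0.2, 0.3])   # members tend to have lower loss
loss_test = np.array([0.25, 0.4, 0.5])

labels = np.concatenate((np.zeros(3), np.ones(3)))  # old: test = positive
scores = np.concatenate((loss_train, loss_test))
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
# For each thresholds[i]:
#   tpr[i] = fraction of test examples with loss >= thresholds[i]
#   fpr[i] = fraction of training examples with loss >= thresholds[i]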

For membership inference, TPR is usually defined as the % of training examples classified as training, and FPR as the % of test examples classified as training.
Because training samples usually have lower loss, the rule usually takes the form "loss <= threshold ==> predict training".
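
Evaluating that rule at a given cutoff amounts to (a sketch, reusing loss_train / loss_test from above):

threshold = 0.3  # an example cutoff in loss units
tpr = np.mean(loss_train <= threshold)  # % of training classified as training
fpr = np.mean(loss_test <= threshold)   # % of test classified as training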

Therefore, the TPR under the 2nd convention is actually (1 - FPR) under the 1st, and the FPR under the 2nd is (1 - TPR) under the 1st.
This mismatch affects neither the attacker advantage (since (1 - FPR) - (1 - TPR) = TPR - FPR) nor the AUC, but it does produce an incorrect PPV.
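
To make the PPV issue concrete, a worked example using the identity PPV = TPR / (TPR + r * FPR) for test/train ratio r = ntest / ntrain (plain Bayes reasoning with training as the positive class, stated here as an assumption rather than a quote of the library's implementation):

tpr1, fpr1 = 0.9, 0.2            # convention 1: test = positive class
tpr2, fpr2 = 1 - fpr1, 1 - tpr1  # convention 2: training = positive => 0.8, 0.1

r = 1.0  # equal-sized train and test sets

advantage1 = tpr1 - fpr1  # 0.7
advantage2 = tpr2 - fpr2  # 0.7, unchanged

ppv = tpr2 / (tpr2 + r * fpr2)      # ~0.889, the correct value
ppv_bad = tpr1 / (tpr1 + r * fpr1)  # ~0.818, what the old labeling yields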

Now, we:
- set the training set as the positive class.
- for the threshold and entropy attacks, set the score to -loss, so that a higher score corresponds to training data.
- negate the thresholds (which were computed from -loss) so that they correspond to losses, as sketched below.
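
A minimal sketch of the new convention (names are illustrative; loss_train / loss_test as above):

ntrain, ntest = len(loss_train), len(loss_test)
labels = np.concatenate((np.ones(ntrain), np.zeros(ntest)))  # training = positive
# Negate the loss so that roc_curve's rule "score >= threshold" matches
# "loss <= -threshold": a higher score now means "more likely a member".
fpr, tpr, raw_thresholds = metrics.roc_curve(
    labels, -np.concatenate((loss_train, loss_test)))
thresholds = -raw_thresholds  # back in loss units: loss <= threshold => training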

PiperOrigin-RevId: 519880043
commit e125951c9b (parent 7796369d8b)
Author: Shuang Song
Committed-by: A. Unique TensorFlower
Date: 2023-03-27 17:59:45 -07:00
3 changed files with 70 additions and 21 deletions


@@ -120,6 +120,9 @@ def _run_trained_attack(
   assert not np.any(np.isnan(scores))

   # Generate ROC curves with scores.
+  # Different from the threshold attacker which uses the loss, we do not negate
+  # the scores here, because the attacker returns the probability of the
+  # positive class.
   fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
   # 'test_train_ratio' is the ratio of test data size to train data size. It is
   # used to compute the Positive Predictive Value.
@@ -131,7 +134,7 @@ def _run_trained_attack(
       thresholds=thresholds,
       test_train_ratio=test_train_ratio)

-  in_train_indices = (labels == 0)
+  in_train_indices = labels == 1
   return SingleAttackResult(
       slice_spec=_get_slice_spec(attack_input),
       data_size=prepared_attacker_data.data_size,
@@ -154,8 +157,12 @@ def _run_threshold_attack(attack_input: AttackInputData):
     loss_train = np.sum(loss_train, axis=1)
     loss_test = np.sum(loss_test, axis=1)
   fpr, tpr, thresholds = metrics.roc_curve(
-      np.concatenate((np.zeros(ntrain), np.ones(ntest))),
-      np.concatenate((loss_train, loss_test)))
+      np.concatenate((np.ones(ntrain), np.zeros(ntest))),
+      # roc_curve uses classification rules of the form
+      # "score >= threshold ==> predict positive", while training data has
+      # lower loss, so we negate the loss.
+      -np.concatenate((loss_train, loss_test)),
+  )
   # 'test_train_ratio' is the ratio of test data size to train data size. It is
   # used to compute the Positive Predictive Value.
   test_train_ratio = ntest / ntrain
@@ -163,8 +170,9 @@ def _run_threshold_attack(attack_input: AttackInputData):
   roc_curve = RocCurve(
       tpr=tpr,
       fpr=fpr,
-      thresholds=thresholds,
-      test_train_ratio=test_train_ratio)
+      thresholds=-thresholds,  # negate because we negated the loss
+      test_train_ratio=test_train_ratio,
+  )

   return SingleAttackResult(
       slice_spec=_get_slice_spec(attack_input),
@@ -182,9 +190,13 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
         'multilabel data.'))
   ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
   fpr, tpr, thresholds = metrics.roc_curve(
-      np.concatenate((np.zeros(ntrain), np.ones(ntest))),
-      np.concatenate(
-          (attack_input.get_entropy_train(), attack_input.get_entropy_test())))
+      np.concatenate((np.ones(ntrain), np.zeros(ntest))),
+      # As with the loss, we negate the entropy because training examples are
+      # expected to have lower entropy.
+      -np.concatenate(
+          (attack_input.get_entropy_train(), attack_input.get_entropy_test())
+      ),
+  )
   # 'test_train_ratio' is the ratio of test data size to train data size. It is
   # used to compute the Positive Predictive Value.
   test_train_ratio = ntest / ntrain
@@ -192,8 +204,9 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
   roc_curve = RocCurve(
       tpr=tpr,
       fpr=fpr,
-      thresholds=thresholds,
-      test_train_ratio=test_train_ratio)
+      thresholds=-thresholds,  # negate because we negated the entropy
+      test_train_ratio=test_train_ratio,
+  )

   return SingleAttackResult(
       slice_spec=_get_slice_spec(attack_input),


@@ -214,9 +214,11 @@ class RunAttacksTest(parameterized.TestCase):
         result.membership_scores_test,
         result.membership_scores_test[0],
         rtol=1e-3)
-    # Training score should be smaller than test score
-    self.assertLess(result.membership_scores_train[0],
-                    result.membership_scores_test[0])
+    # Training score should be larger than test score, as the training set is
+    # now the positive class.
+    self.assertGreater(
+        result.membership_scores_train[0], result.membership_scores_test[0]
+    )

   def test_run_attack_threshold_calculates_correct_auc(self):
     result = mia._run_attack(
@@ -236,12 +238,39 @@
     np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)

+  @parameterized.parameters(
+      [AttackType.THRESHOLD_ATTACK],
+      [AttackType.THRESHOLD_ENTROPY_ATTACK],
+  )
+  def test_calculates_correct_tpr_fpr(self, attack_type):
+    rng = np.random.RandomState(27)
+    loss_train = rng.rand(100)
+    loss_test = rng.rand(50) + 0.1
+    result = mia._run_attack(
+        AttackInputData(
+            loss_train=loss_train,
+            loss_test=loss_test,
+            entropy_train=loss_train,
+            entropy_test=loss_test,
+        ),
+        attack_type,
+    )
+    self.assertEqual(attack_type, result.attack_type)
+    for tpr, fpr, threshold in zip(
+        result.roc_curve.tpr, result.roc_curve.fpr, result.roc_curve.thresholds
+    ):
+      self.assertAlmostEqual(tpr, np.mean(loss_train <= threshold))
+      self.assertAlmostEqual(fpr, np.mean(loss_test <= threshold))
+
   @mock.patch('sklearn.metrics.roc_curve')
   def test_run_attack_threshold_entropy_small_tpr_fpr_correct_ppv(
       self, patched_fn):
     # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
-    patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001,
-                                                     0.0002], [0.2, 0.4, 0.6])
+    patched_fn.return_value = (
+        np.array([0.2, 0.04, 0.0003]),
+        np.array([0.1, 0.0001, 0.0002]),
+        np.array([0.2, 0.4, 0.6]),
+    )
     result = mia._run_attack(
         AttackInputData(
             entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
@@ -380,8 +409,11 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
   def test_run_multilabel_attack_threshold_small_tpr_fpr_correct_ppv(
       self, patched_fn):
     # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
-    patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001,
-                                                     0.0002], [0.2, 0.4, 0.6])
+    patched_fn.return_value = (
+        np.array([0.2, 0.04, 0.0003]),
+        np.array([0.1, 0.0001, 0.0002]),
+        np.array([0.2, 0.4, 0.6]),
+    )
     result = mia._run_attack(
         AttackInputData(
             loss_train=np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6], [0.9,


@@ -76,7 +76,7 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData,
   ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
   features_all = np.concatenate((attack_input_train, attack_input_test))
-  labels_all = np.concatenate((np.zeros(ntrain), np.ones(ntest)))
+  labels_all = np.concatenate((np.ones(ntrain), np.zeros(ntest)))

   if attack_input_data.has_nonnull_sample_weights():
     sample_weights_all = np.concatenate((attack_input_data.sample_weight_train,
                                          attack_input_data.sample_weight_test),
@@ -282,13 +282,17 @@ class KNearestNeighborsAttacker(TrainedAttacker):
     self.model = model


-def create_attacker(attack_type,
-                    backend: Optional[str] = None) -> TrainedAttacker:
+def create_attacker(
+    attack_type: data_structures.AttackType, backend: Optional[str] = None
+) -> TrainedAttacker:
   """Returns the corresponding attacker for the provided attack_type."""
   # Compare by name instead of the variable itself to support module reload.
   if attack_type.name == data_structures.AttackType.LOGISTIC_REGRESSION.name:
     return LogisticRegressionAttacker(backend=backend)
-  if attack_type.name == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name:
+  if (
+      attack_type.name
+      == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name
+  ):
     return MultilayerPerceptronAttacker(backend=backend)
   if attack_type.name == data_structures.AttackType.RANDOM_FOREST.name:
     return RandomForestAttacker(backend=backend)