From e125951c9b390f2dde6bbf0685071f98e2c7da58 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Mon, 27 Mar 2023 17:59:45 -0700 Subject: [PATCH] Sets training set as positive class for sklearn.metrics.roc_curve. sklearn.metrics.roc_curve uses classification rules of the form "score >= threshold ==> predict positive". When calling roc_curve, we used to label the test data as the positive class. This way, TPR = % test examples classified as test, and FPR = % training examples classified as test. The classification rule is "loss >= threshold ==> predict test". For membership inference, TPR is usually defined as % training examples classified as training, and FPR is % test examples classified as training. As training samples usually have lower loss, we typically use rules of the form "loss <= threshold ==> predict training". Therefore, TPR in the 2nd case is actually (1 - FPR) in the 1st case, and FPR in the 2nd case is (1 - TPR) in the 1st case. This mismatch does not affect attacker advantage or AUC, but it can cause problems for PPV. Now, we: - set the training set as the positive class. - for threshold and entropy attacks, set the score to be -loss, so that a higher score corresponds to training data. - negate the thresholds (computed based on -loss) so that they correspond to loss. PiperOrigin-RevId: 519880043 --- .../membership_inference_attack.py | 33 +++++++++---- .../membership_inference_attack_test.py | 46 ++++++++++++++++--- .../membership_inference_attack/models.py | 12 +++-- 3 files changed, 70 insertions(+), 21 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index ef3bdc7..1552b5e 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -120,6 +120,9 @@ def _run_trained_attack( assert not np.any(np.isnan(scores)) # Generate ROC curves with scores. + # Unlike the threshold attack, which uses the loss, we do not negate + # the scores here, because the attacker returns the probability of the + # positive class. fpr, tpr, thresholds = metrics.roc_curve(labels, scores) # 'test_train_ratio' is the ratio of test data size to train data size. It is # used to compute the Positive Predictive Value. @@ -131,7 +134,7 @@ thresholds=thresholds, test_train_ratio=test_train_ratio) - in_train_indices = (labels == 0) + in_train_indices = labels == 1 return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), data_size=prepared_attacker_data.data_size, @@ -154,8 +157,12 @@ def _run_threshold_attack(attack_input: AttackInputData): loss_train = np.sum(loss_train, axis=1) loss_test = np.sum(loss_test, axis=1) fpr, tpr, thresholds = metrics.roc_curve( - np.concatenate((np.zeros(ntrain), np.ones(ntest))), - np.concatenate((loss_train, loss_test))) + np.concatenate((np.ones(ntrain), np.zeros(ntest))), + # roc_curve uses a classification rule of the form + # "score >= threshold ==> predict positive", while training data has lower + # loss, so we negate the loss. + -np.concatenate((loss_train, loss_test)), + ) # 'test_train_ratio' is the ratio of test data size to train data size. It is # used to compute the Positive Predictive Value. 
test_train_ratio = ntest / ntrain @@ -163,8 +170,9 @@ def _run_threshold_attack(attack_input: AttackInputData): roc_curve = RocCurve( tpr=tpr, fpr=fpr, - thresholds=thresholds, - test_train_ratio=test_train_ratio) + thresholds=-thresholds, # negate because we negated the loss + test_train_ratio=test_train_ratio, + ) return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), @@ -182,9 +190,13 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData): 'multilabel data.')) ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size() fpr, tpr, thresholds = metrics.roc_curve( - np.concatenate((np.zeros(ntrain), np.ones(ntest))), - np.concatenate( - (attack_input.get_entropy_train(), attack_input.get_entropy_test()))) + np.concatenate((np.ones(ntrain), np.zeros(ntest))), + # Similar to the loss case, we negate the entropy because training examples + # are expected to have lower entropy. + -np.concatenate( + (attack_input.get_entropy_train(), attack_input.get_entropy_test()) + ), + ) # 'test_train_ratio' is the ratio of test data size to train data size. It is # used to compute the Positive Predictive Value. test_train_ratio = ntest / ntrain @@ -192,8 +204,9 @@ roc_curve = RocCurve( tpr=tpr, fpr=fpr, - thresholds=thresholds, - test_train_ratio=test_train_ratio) + thresholds=-thresholds, # negate because we negated the entropy + test_train_ratio=test_train_ratio, + ) return SingleAttackResult( slice_spec=_get_slice_spec(attack_input), diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py index c3ba19b..730df27 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py @@ -214,9 +214,11 @@ class RunAttacksTest(parameterized.TestCase): result.membership_scores_test, result.membership_scores_test[0], rtol=1e-3) - # Training score should be smaller than test score - self.assertLess(result.membership_scores_train[0], - result.membership_scores_test[0]) + # Training score should be larger than test score, as the training set is + # set to be the positive class. 
+ self.assertGreater( + result.membership_scores_train[0], result.membership_scores_test[0] + ) def test_run_attack_threshold_calculates_correct_auc(self): result = mia._run_attack( @@ -236,12 +238,39 @@ class RunAttacksTest(parameterized.TestCase): np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + @parameterized.parameters( + [AttackType.THRESHOLD_ATTACK], + [AttackType.THRESHOLD_ENTROPY_ATTACK], + ) + def test_calculates_correct_tpr_fpr(self, attack_type): + rng = np.random.RandomState(27) + loss_train = rng.rand(100) + loss_test = rng.rand(50) + 0.1 + result = mia._run_attack( + AttackInputData( + loss_train=loss_train, + loss_test=loss_test, + entropy_train=loss_train, + entropy_test=loss_test, + ), + attack_type, + ) + self.assertEqual(attack_type, result.attack_type) + for tpr, fpr, threshold in zip( + result.roc_curve.tpr, result.roc_curve.fpr, result.roc_curve.thresholds + ): + self.assertAlmostEqual(tpr, np.mean(loss_train <= threshold)) + self.assertAlmostEqual(fpr, np.mean(loss_test <= threshold)) + @mock.patch('sklearn.metrics.roc_curve') def test_run_attack_threshold_entropy_small_tpr_fpr_correct_ppv( self, patched_fn): # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds). - patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001, - 0.0002], [0.2, 0.4, 0.6]) + patched_fn.return_value = ( + np.array([0.2, 0.04, 0.0003]), + np.array([0.1, 0.0001, 0.0002]), + np.array([0.2, 0.4, 0.6]), + ) result = mia._run_attack( AttackInputData( entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]), @@ -380,8 +409,11 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase): def test_run_multilabel_attack_threshold_small_tpr_fpr_correct_ppv( self, patched_fn): # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds). - patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001, - 0.0002], [0.2, 0.4, 0.6]) + patched_fn.return_value = ( + np.array([0.2, 0.04, 0.0003]), + np.array([0.1, 0.0001, 0.0002]), + np.array([0.2, 0.4, 0.6]), + ) result = mia._run_attack( AttackInputData( loss_train=np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6], [0.9, diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py index a05ef46..f74a9ef 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py @@ -76,7 +76,7 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData, ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0] features_all = np.concatenate((attack_input_train, attack_input_test)) - labels_all = np.concatenate((np.zeros(ntrain), np.ones(ntest))) + labels_all = np.concatenate((np.ones(ntrain), np.zeros(ntest))) if attack_input_data.has_nonnull_sample_weights(): sample_weights_all = np.concatenate((attack_input_data.sample_weight_train, attack_input_data.sample_weight_test), @@ -282,13 +282,17 @@ class KNearestNeighborsAttacker(TrainedAttacker): self.model = model -def create_attacker(attack_type, - backend: Optional[str] = None) -> TrainedAttacker: +def create_attacker( + attack_type: data_structures.AttackType, backend: Optional[str] = None +) -> TrainedAttacker: """Returns the corresponding attacker for the provided attack_type.""" # Compare by name instead of the variable itself to support module reload. 
if attack_type.name == data_structures.AttackType.LOGISTIC_REGRESSION.name: return LogisticRegressionAttacker(backend=backend) - if attack_type.name == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name: + if ( + attack_type.name + == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name + ): return MultilayerPerceptronAttacker(backend=backend) if attack_type.name == data_structures.AttackType.RANDOM_FOREST.name: return RandomForestAttacker(backend=backend)
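
Reviewer note: below is a minimal standalone sketch (not part of the patch; the array names and sizes are illustrative, and numpy plus scikit-learn are assumed available) of the sign conventions this change adopts. It reproduces the identity asserted by the new test_calculates_correct_tpr_fpr test and checks that flipping the positive class leaves the AUC unchanged.

import numpy as np
from sklearn import metrics

rng = np.random.RandomState(27)
loss_train = rng.rand(100)       # members: lower loss on average
loss_test = rng.rand(50) + 0.1   # non-members: shifted higher

# New convention: the training set is the positive class (label 1). Since
# roc_curve classifies with "score >= threshold ==> predict positive" and
# members have *lower* loss, the score is the negated loss.
labels = np.concatenate((np.ones(len(loss_train)), np.zeros(len(loss_test))))
scores = -np.concatenate((loss_train, loss_test))
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)

# Negating the thresholds maps them back to the loss domain, where the rule
# reads "loss <= threshold ==> predict member".
for t, tp, fp in zip(-thresholds, tpr, fpr):
    assert np.isclose(tp, np.mean(loss_train <= t))  # TPR over members
    assert np.isclose(fp, np.mean(loss_test <= t))   # FPR over non-members

# Flipping the positive class and negating the scores mirrors the ROC curve
# across both axes, so the AUC matches the old convention (test as positive,
# raw loss as score).
auc_old = metrics.roc_auc_score(1 - labels, -scores)
auc_new = metrics.roc_auc_score(labels, scores)
assert np.isclose(auc_old, auc_new)

As the commit message notes, AUC and attacker advantage are invariant under this flip, while PPV is not, which is what motivated aligning the positive class with the training set.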