From 2b5d5b6ef559c81f1be49719dee8ce7da57a6255 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sat, 16 Jul 2022 16:30:17 -0700
Subject: [PATCH] Add Positive Predictive Value as a metric for membership
 attack model performance on imbalanced data.

PiperOrigin-RevId: 461390184
---
 .../data_structures.py                  | 115 ++++++++++++-
 .../data_structures_test.py             | 156 ++++++++++++++----
 .../keras_evaluation_test.py            |   2 +-
 .../membership_inference_attack.py      |  36 +++-
 .../membership_inference_attack_test.py |  67 ++++++++
 .../privacy_report_test.py              |   6 +-
 .../tf_estimator_evaluation_test.py     |   4 +-
 7 files changed, 338 insertions(+), 48 deletions(-)

diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
index 02658b1..c124656 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
@@ -26,7 +26,10 @@ import numpy as np
 import pandas as pd
 from scipy import special
 from sklearn import metrics
-import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils as utils
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
+
+# The minimum TPR or FPR below which they are considered equal to zero.
+_ABSOLUTE_TOLERANCE = 1e-3
 
 ENTIRE_DATASET_SLICE_STR = 'Entire dataset'
 
@@ -116,7 +119,6 @@ class AttackType(enum.Enum):
   K_NEAREST_NEIGHBORS = 'knn'
   THRESHOLD_ATTACK = 'threshold'
   THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
-  TF_LOGISTIC_REGRESSION = 'tf_lr'
 
   @property
   def is_trained_attack(self):
@@ -133,6 +135,7 @@ class PrivacyMetric(enum.Enum):
   """An enum for the supported privacy risk metrics."""
   AUC = 'AUC'
   ATTACKER_ADVANTAGE = 'Attacker advantage'
+  PPV = 'Positive predictive value'
 
   def __str__(self):
     """Returns 'AUC' instead of PrivacyMetric.AUC."""
@@ -627,6 +630,11 @@ class RocCurve:
   # False positive rates based on thresholds
   fpr: np.ndarray
 
+  # Ratio of test to train set size.
+  # In Jayaraman et al. (https://arxiv.org/pdf/2005.10881.pdf) it is referred
+  # to as 'gamma' (see Table 1 for the definition).
+  test_train_ratio: np.float64
+
   def get_auc(self):
     """Calculates area under curve (aka AUC)."""
     return metrics.auc(self.fpr, self.tpr)
@@ -643,12 +651,69 @@ class RocCurve:
     """
     return max(np.abs(self.tpr - self.fpr))
 
+  def get_ppv(self) -> float:
+    """Calculates the Positive Predictive Value of the membership attacker.
+
+    The Positive Predictive Value (PPV) is the proportion of positive
+    predictions that are true positives. It can be expressed as
+    PPV = TP / (TP + FP). Jayaraman et al.
+    (https://arxiv.org/pdf/2005.10881.pdf) suggest it as a suitable metric for
+    membership attacks on datasets where the number of samples from the
+    training set and the number of samples from the test set are very
+    different. These are referred to as imbalanced datasets.
+
+    Returns:
+      A single float for the Positive Predictive Value.
+    """
+    num = np.asarray(self.tpr)
+    den = num + self.test_train_ratio * np.asarray(self.fpr)
+    # There is a special case when both `num` and `den` are 0. Both are 0 only
+    # when TPR and FPR are both 0, since test_train_ratio is strictly positive
+    # (we exclude the case where there is no test set). TPR = 0 means that all
+    # positive (train set) examples are misclassified, and FPR = 0 means that
+    # all negative (test set) examples are correctly classified.
+    # Consider that when TPR and FPR are both close to 0, TPR ~ FPR. Call this
+    # value 'R'. The expression for PPV can then be rewritten as
+    # PPV = R / (R + test_train_ratio * R) = 1 / (1 + test_train_ratio).
+    # We can sanity-check this expression: when test_train_ratio = 0, i.e.
+    # there is no test set, PPV = 1 (every positive prediction is a train
+    # example). When test_train_ratio >> 1, i.e. the test set is much larger
+    # than the train set, PPV ~ 0 (positive predictions are almost always
+    # false positives). When test_train_ratio = 1, the test and train sets are
+    # the same size, and PPV = 0.5 (random guessing). In the normal case, both
+    # `num` and `den` are nonzero, and PPV is simply the ratio of `num` to
+    # `den`.
+
+    # Find where `tpr` and `fpr` are both 0.
+    tpr_is_0 = np.isclose(self.tpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
+    fpr_is_0 = np.isclose(self.fpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
+    tpr_and_fpr_both_0 = np.logical_and(tpr_is_0, fpr_is_0)
+    # PPV when TPR and FPR are both zero, per the derivation above.
+    ppv_when_tpr_fpr_both_0 = 1. / (1. + self.test_train_ratio)
+    # PPV when at least one of TPR and FPR is nonzero.
+    ppv_when_one_of_tpr_fpr_not_0 = np.divide(
+        num, den, out=np.zeros_like(den), where=den != 0)
+    return np.max(
+        np.where(tpr_and_fpr_both_0, ppv_when_tpr_fpr_both_0,
+                 ppv_when_one_of_tpr_fpr_not_0))
+
   def __str__(self):
-    """Returns AUC and advantage metrics."""
+    """Returns AUC, advantage and PPV metrics."""
     return '\n'.join([
         'RocCurve(',
         '  AUC: %.2f' % self.get_auc(),
-        '  Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
+        '  Attacker advantage: %.2f' % self.get_attacker_advantage(),
+        '  Positive predictive value: %.2f' % self.get_ppv(), ')'
     ])
 
@@ -695,6 +760,11 @@ class SingleAttackResult:
   def get_attacker_advantage(self):
     return self.roc_curve.get_attacker_advantage()
 
+  def get_ppv(self) -> float:
+    if self.data_size.ntrain == 0:
+      raise ValueError('Size of the training data cannot be zero.')
+    return self.roc_curve.get_ppv()
+
   def get_auc(self):
     return self.roc_curve.get_auc()
 
@@ -707,7 +777,8 @@ class SingleAttackResult:
         (self.data_size.ntrain, self.data_size.ntest),
         '  AttackType: %s' % str(self.attack_type),
         '  AUC: %.2f' % self.get_auc(),
-        '  Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
+        '  Attacker advantage: %.2f' % self.get_attacker_advantage(),
+        '  Positive predictive value: %.2f' % self.get_ppv(), ')'
     ])
 
@@ -791,7 +862,14 @@ class SingleMembershipProbabilityResult:
             np.zeros(len(self.test_membership_probs)))),
         np.concatenate(
             (self.train_membership_probs, self.test_membership_probs)))
-    roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+    ntrain = np.shape(self.train_membership_probs)[0]
+    ntest = np.shape(self.test_membership_probs)[0]
+    test_train_ratio = ntest / ntrain
+    roc_curve = RocCurve(
+        tpr=tpr,
+        fpr=fpr,
+        thresholds=thresholds,
+        test_train_ratio=test_train_ratio)
     summary.append(
         '  thresholding on membership probability achieved an AUC of %.2f' %
         (roc_curve.get_auc()))
@@ -860,6 +938,7 @@ class AttackResults:
     data_size_test = []
     attack_types = []
     advantages = []
+    ppvs = []
     aucs = []
 
     for attack_result in self.single_attack_results:
@@ -874,6 +953,7 @@ class AttackResults:
       data_size_test.append(attack_result.data_size.ntest)
       attack_types.append(str(attack_result.attack_type))
       advantages.append(float(attack_result.get_attacker_advantage()))
+      ppvs.append(float(attack_result.get_ppv()))
       aucs.append(float(attack_result.get_auc()))
 
     df = pd.DataFrame({
@@ -883,6 +963,7 @@ class AttackResults:
         str(AttackResultsDFColumns.DATA_SIZE_TEST): data_size_test,
         str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
         str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
+        str(PrivacyMetric.PPV): ppvs,
         str(PrivacyMetric.AUC): aucs
     })
     return df
@@ -918,6 +999,14 @@ class AttackResults:
             max_advantage_result_all.get_attacker_advantage(),
             max_advantage_result_all.slice_spec))
 
+    max_ppv_result_all = self.get_result_with_max_ppv()
+    summary.append(
+        '  %s (with %d training and %d test examples) achieved a positive '
+        'predictive value of %.2f on slice %s' %
+        (max_ppv_result_all.attack_type, max_ppv_result_all.data_size.ntrain,
+         max_ppv_result_all.data_size.ntest, max_ppv_result_all.get_ppv(),
+         max_ppv_result_all.slice_spec))
+
     slice_dict = self._group_results_by_slice()
 
     if by_slices and len(slice_dict.keys()) > 1:
@@ -937,6 +1026,12 @@ class AttackResults:
               max_advantage_result.data_size.ntrain,
               max_auc_result.data_size.ntest,
              max_advantage_result.get_attacker_advantage()))
+      max_ppv_result = results.get_result_with_max_ppv()
+      summary.append(
+          '  %s (with %d training and %d test examples) achieved a positive '
+          'predictive value of %.2f' %
+          (max_ppv_result.attack_type, max_ppv_result.data_size.ntrain,
+           max_ppv_result.data_size.ntest, max_ppv_result.get_ppv()))
 
     return '\n'.join(summary)
 
@@ -966,6 +1061,11 @@ class AttackResults:
         result.get_attacker_advantage() for result in self.single_attack_results
     ])]
 
+  def get_result_with_max_ppv(self) -> SingleAttackResult:
+    """Gets the result with max positive predictive value for all attacks and slices."""
+    return self.single_attack_results[np.argmax(
+        [result.get_ppv() for result in self.single_attack_results])]
+
   def save(self, filepath):
     """Saves self to a pickle file."""
     with open(filepath, 'wb') as out:
@@ -1035,6 +1135,7 @@ def get_flattened_attack_metrics(results: AttackResults):
       attack_metrics += ['adv', 'auc']
       values += [
           float(attack_result.get_attacker_advantage()),
-          float(attack_result.get_auc())
+          float(attack_result.get_auc()),
+          float(attack_result.get_ppv()),
       ]
   return types, slices, attack_metrics, values
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
index f807a62..cb4369c 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py
@@ -14,6 +14,7 @@
 
 import os
 import tempfile
+from unittest import mock
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -394,13 +395,14 @@ class AttackInputDataTest(parameterized.TestCase):
         '"force_multilabel_data" is True but "multilabel_data" is False.')
 
 
-class RocCurveTest(absltest.TestCase):
+class RocCurveTest(parameterized.TestCase):
 
   def test_auc_random_classifier(self):
     roc = RocCurve(
         tpr=np.array([0.0, 0.5, 1.0]),
         fpr=np.array([0.0, 0.5, 1.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     self.assertEqual(roc.get_auc(), 0.5)
 
@@ -408,7 +410,8 @@ class RocCurveTest(absltest.TestCase):
     roc = RocCurve(
         tpr=np.array([0.0, 1.0, 1.0]),
         fpr=np.array([1.0, 1.0, 0.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     self.assertEqual(roc.get_auc(), 1.0)
 
@@ -416,7 +419,8 @@ class RocCurveTest(absltest.TestCase):
     roc = RocCurve(
         tpr=np.array([0.0, 0.5, 1.0]),
         fpr=np.array([0.0, 0.5, 1.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     self.assertEqual(roc.get_attacker_advantage(), 0.0)
 
@@ -424,10 +428,59 @@ class RocCurveTest(absltest.TestCase):
     roc = RocCurve(
         tpr=np.array([0.0, 1.0, 1.0]),
         fpr=np.array([1.0, 1.0, 0.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     self.assertEqual(roc.get_auc(), 1.0)
 
+  def test_ppv_random_classifier(self):
+    roc = RocCurve(
+        tpr=np.array([0.0, 0.5, 1.0]),
+        fpr=np.array([0.0, 0.5, 1.0]),
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
+
+    self.assertEqual(roc.get_ppv(), 0.5)
+
+  def test_ppv_perfect_classifier(self):
+    roc = RocCurve(
+        tpr=np.array([0.0, 1.0, 1.0]),
+        fpr=np.array([1.0, 1.0, 0.0]),
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
+
+    self.assertEqual(roc.get_ppv(), 1.0)
+
+  # Parameters to test: test-train ratio, expected PPV.
+  @parameterized.named_parameters(
+      ('test_train_ratio_small', 0.001, 1.0),
+      ('test_train_ratio_large', 1000.0, 0.0),
+  )
+  @mock.patch(
+      'tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures._ABSOLUTE_TOLERANCE',
+      1e-4)
+  def test_ppv_perfect_classifier_when_tpr_fpr_small(self, test_train_ratio,
+                                                     expected_ppv):
+    roc = RocCurve(
+        tpr=np.array([0.00001, 0.0001, 0.002]),
+        fpr=np.array([0.00002, 0.0002, 0.002]),
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=test_train_ratio)
+
+    np.testing.assert_allclose(roc.get_ppv(), expected_ppv, atol=1e-3)
+
+  @mock.patch(
+      'tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures._ABSOLUTE_TOLERANCE',
+      1e-4)
+  def test_ppv_random_classifier_when_tpr_fpr_small_and_test_train_is_1(self):
+    roc = RocCurve(
+        tpr=np.array([0.00001, 0.0001, 0.002]),
+        fpr=np.array([0.00002, 0.0002, 0.002]),
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
+
+    np.testing.assert_allclose(roc.get_ppv(), 0.5, atol=1e-3)
+
 
 class SingleAttackResultTest(absltest.TestCase):
 
@@ -436,7 +489,8 @@ class SingleAttackResultTest(absltest.TestCase):
     roc = RocCurve(
         tpr=np.array([0.0, 0.5, 1.0]),
         fpr=np.array([0.0, 0.5, 1.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     result = SingleAttackResult(
         roc_curve=roc,
@@ -451,7 +505,8 @@ class SingleAttackResultTest(absltest.TestCase):
     roc = RocCurve(
         tpr=np.array([0.0, 0.5, 1.0]),
         fpr=np.array([0.0, 0.5, 1.0]),
-        thresholds=np.array([0, 1, 2]))
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
 
     result = SingleAttackResult(
         roc_curve=roc,
@@ -461,6 +516,22 @@ class SingleAttackResultTest(absltest.TestCase):
 
     self.assertEqual(result.get_attacker_advantage(), 0.0)
 
+  # Only a basic test, as this method calls RocCurve which is tested
+  # separately.
+  def test_ppv_random_classifier(self):
+    roc = RocCurve(
+        tpr=np.array([0.0, 0.5, 1.0]),
+        fpr=np.array([0.0, 0.5, 1.0]),
+        thresholds=np.array([0, 1, 2]),
+        test_train_ratio=1.0)
+
+    result = SingleAttackResult(
+        roc_curve=roc,
+        slice_spec=SingleSliceSpec(None),
+        attack_type=AttackType.THRESHOLD_ATTACK,
+        data_size=DataSize(ntrain=1, ntest=1))
+
+    self.assertEqual(result.get_ppv(), 0.5)
+
 
 class SingleMembershipProbabilityResultTest(absltest.TestCase):
 
@@ -492,7 +563,8 @@ class AttackResultsCollectionTest(absltest.TestCase):
         roc_curve=RocCurve(
             tpr=np.array([0.0, 0.5, 1.0]),
             fpr=np.array([0.0, 0.5, 1.0]),
-            thresholds=np.array([0, 1, 2])),
+            thresholds=np.array([0, 1, 2]),
+            test_train_ratio=1.0),
         data_size=DataSize(ntrain=1, ntest=1))
 
     self.results_epoch_10 = AttackResults(
@@ -552,7 +624,8 @@ class AttackResultsTest(absltest.TestCase):
         roc_curve=RocCurve(
             tpr=np.array([0.0, 1.0, 1.0]),
             fpr=np.array([1.0, 1.0, 0.0]),
-            thresholds=np.array([0, 1, 2])),
+            thresholds=np.array([0, 1, 2]),
+            test_train_ratio=1.0),
         data_size=DataSize(ntrain=1, ntest=1))
 
     # ROC curve of a random classifier
@@ -562,7 +635,8 @@ class AttackResultsTest(absltest.TestCase):
         roc_curve=RocCurve(
             tpr=np.array([0.0, 0.5, 1.0]),
             fpr=np.array([0.0, 0.5, 1.0]),
-            thresholds=np.array([0, 1, 2])),
+            thresholds=np.array([0, 1, 2]),
+            test_train_ratio=1.0),
         data_size=DataSize(ntrain=1, ntest=1))
 
   def test_get_result_with_max_auc_first(self):
@@ -589,37 +663,58 @@ class AttackResultsTest(absltest.TestCase):
     self.assertEqual(results.get_result_with_max_attacker_advantage(),
                      self.perfect_classifier_result)
 
+  def test_get_result_with_max_positive_predictive_value_first(self):
+    results = AttackResults(
+        [self.perfect_classifier_result, self.random_classifier_result])
+    self.assertEqual(results.get_result_with_max_ppv(),
+                     self.perfect_classifier_result)
+
+  def test_get_result_with_max_positive_predictive_value_second(self):
+    results = AttackResults(
+        [self.random_classifier_result, self.perfect_classifier_result])
+    self.assertEqual(results.get_result_with_max_ppv(),
+                     self.perfect_classifier_result)
+
   def test_summary_by_slices(self):
     results = AttackResults(
         [self.perfect_classifier_result, self.random_classifier_result])
-    self.assertEqual(
+    self.assertSequenceEqual(
         results.summary(by_slices=True),
-        'Best-performing attacks over all slices\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n\n' +
-        'Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' AUC of 1.00\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' advantage of 1.00\n\n' +
-        'Best-performing attacks over slice: "Entire dataset"\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' AUC of 0.50\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' advantage of 0.00')
+        'Best-performing attacks over all slices\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
+        ' positive predictive value of 1.00 on slice CORRECTLY_CLASSIFIED='
+        'True\n\n'
+        'Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' AUC of 1.00\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' advantage of 1.00\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
+        ' positive predictive value of 1.00\n\n'
+        'Best-performing attacks over slice: "Entire dataset"\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' AUC of 0.50\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' advantage of 0.00\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
+        ' positive predictive value of 0.50')
 
   def test_summary_without_slices(self):
     results = AttackResults(
         [self.perfect_classifier_result, self.random_classifier_result])
-    self.assertEqual(
+    self.assertSequenceEqual(
         results.summary(by_slices=False),
-        'Best-performing attacks over all slices\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' +
-        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
-        ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True')
+        'Best-performing attacks over all slices\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
+        ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
+        '  THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
+        ' positive predictive value of 1.00 on slice CORRECTLY_CLASSIFIED=True')
 
   def test_save_load(self):
     results = AttackResults(
@@ -645,6 +740,7 @@ class AttackResultsTest(absltest.TestCase):
         'test size': [1, 1],
         'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
         'Attacker advantage': [1.0, 0.0],
+        'Positive predictive value': [1.0, 0.5],
         'AUC': [1.0, 0.5]
     })
     pd.testing.assert_frame_equal(df, df_expected)
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py
index 1c4760d..2dace60 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py
@@ -66,7 +66,7 @@ class UtilsTest(absltest.TestCase):
     self.assertLen(att_types, 2)
     self.assertLen(att_slices, 2)
     self.assertLen(att_metrics, 2)
-    self.assertLen(att_values, 2)
+    self.assertLen(att_values, 3)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
index acbbf94..86197f9 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
@@ -98,7 +98,15 @@ def _run_trained_attack(attack_input: AttackInputData,
 
   # Generate ROC curves with scores.
   fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
-  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+  # 'test_train_ratio' is the ratio of test data size to train data size. It is
+  # used to compute the Positive Predictive Value.
+  test_train_ratio = (prepared_attacker_data.data_size.ntest /
+                      prepared_attacker_data.data_size.ntrain)
+  roc_curve = RocCurve(
+      tpr=tpr,
+      fpr=fpr,
+      thresholds=thresholds,
+      test_train_ratio=test_train_ratio)
 
   in_train_indices = (labels == 0)
   return SingleAttackResult(
@@ -125,8 +133,15 @@ def _run_threshold_attack(attack_input: AttackInputData):
   fpr, tpr, thresholds = metrics.roc_curve(
       np.concatenate((np.zeros(ntrain), np.ones(ntest))),
       np.concatenate((loss_train, loss_test)))
+  # 'test_train_ratio' is the ratio of test data size to train data size. It is
+  # used to compute the Positive Predictive Value.
+  test_train_ratio = ntest / ntrain
 
-  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+  roc_curve = RocCurve(
+      tpr=tpr,
+      fpr=fpr,
+      thresholds=thresholds,
+      test_train_ratio=test_train_ratio)
 
   return SingleAttackResult(
       slice_spec=_get_slice_spec(attack_input),
@@ -147,8 +162,15 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
       np.concatenate((np.zeros(ntrain), np.ones(ntest))),
       np.concatenate(
           (attack_input.get_entropy_train(), attack_input.get_entropy_test())))
+  # 'test_train_ratio' is the ratio of test data size to train data size. It is
+  # used to compute the Positive Predictive Value.
+  test_train_ratio = ntest / ntrain
 
-  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+  roc_curve = RocCurve(
+      tpr=tpr,
+      fpr=fpr,
+      thresholds=thresholds,
+      test_train_ratio=test_train_ratio)
 
   return SingleAttackResult(
       slice_spec=_get_slice_spec(attack_input),
@@ -250,9 +272,11 @@ def run_attacks(attack_input: AttackInputData,
                                   balance_attacker_training, min_num_samples,
                                   backend)
     if attack_result is not None:
-      logging.info('%s attack had an AUC=%s and attacker advantage=%s',
-                   attack_type.name, attack_result.get_auc(),
-                   attack_result.get_attacker_advantage())
+      logging.info(
+          '%s attack had an AUC=%s, attacker advantage=%s and '
+          'positive predictive value=%s', attack_type.name,
+          attack_result.get_auc(), attack_result.get_attacker_advantage(),
+          attack_result.get_ppv())
       attack_results.append(attack_result)
 
   privacy_report_metadata = _compute_missing_privacy_report_metadata(
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
index fc33af0..40c2e7b 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from unittest import mock
+
 from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
@@ -87,6 +89,19 @@ class RunAttacksTest(parameterized.TestCase):
 
     self.assertLen(result.single_attack_results, 2)
 
+  @parameterized.named_parameters(
+      ('low_ratio', 100, 10),
+      ('ratio_1', 100, 100),
+      ('high_ratio', 100, 1000),
+  )
+  def test_test_train_ratio(self, ntrain, ntest):
+    test_input = get_test_input(ntrain, ntest)
+    expected_test_train_ratio = ntest / ntrain
+    calculated_test_train_ratio = (
+        test_input.get_test_size() / test_input.get_train_size())
+
+    self.assertEqual(expected_test_train_ratio, calculated_test_train_ratio)
+
   def test_run_attacks_parallel_backend(self):
     result = mia.run_attacks(
         get_multilabel_test_input(100, 100),
@@ -180,6 +195,26 @@ class RunAttacksTest(parameterized.TestCase):
 
     np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
 
+  @mock.patch('sklearn.metrics.roc_curve')
+  def test_run_attack_threshold_entropy_small_tpr_fpr_correct_ppv(
+      self, patched_fn):
+    # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
+    patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001, 0.0002],
+                               [0.2, 0.4, 0.6])
+    result = mia._run_attack(
+        AttackInputData(
+            entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
+            entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6]),
+            force_multilabel_data=False), AttackType.THRESHOLD_ENTROPY_ATTACK)
+    # PPV = TPR / (TPR + test_train_ratio * FPR), except when both TPR and FPR
+    # are close to 0, in which case PPV = 1 / (1 + test_train_ratio).
+    # With the above values, TPR / (TPR + test_train_ratio * FPR) =
+    # 0.1 / (0.1 + (6/6) * 0.2) = 0.333,
+    # 0.0001 / (0.0001 + (6/6) * 0.04) = 0.002494,
+    # and 1 / (1 + (6/6)) = 0.5. So PPV is the max of these three values,
+    # namely 0.5.
+    np.testing.assert_almost_equal(result.roc_curve.get_ppv(), 0.5, decimal=2)
+
   def test_run_attack_by_slice(self):
     result = mia.run_attacks(
         get_test_input(100, 100), SlicingSpec(by_class=True),
@@ -281,6 +316,38 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
         mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6)
     self.assertIsNone(mia._get_accuracy(None, labels))
 
+  def test_run_multilabel_attack_threshold_calculates_correct_ppv(self):
+    result = mia._run_attack(
+        AttackInputData(
+            loss_train=np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6],
+                                 [0.9, 0.6]]),
+            loss_test=np.array([[1.1, 1.2], [1.3, 0.4], [1.5, 1.6]]),
+            force_multilabel_data=True), AttackType.THRESHOLD_ATTACK)
+
+    np.testing.assert_almost_equal(result.roc_curve.get_ppv(), 1.0, decimal=2)
+
+  @mock.patch('sklearn.metrics.roc_curve')
+  def test_run_multilabel_attack_threshold_small_tpr_fpr_correct_ppv(
+      self, patched_fn):
+    # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
+    patched_fn.return_value = ([0.2, 0.04, 0.0003], [0.1, 0.0001, 0.0002],
+                               [0.2, 0.4, 0.6])
+    result = mia._run_attack(
+        AttackInputData(
+            loss_train=np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6],
+                                 [0.9, 0.6]]),
+            loss_test=np.array([[1.1, 1.2], [1.3, 0.4], [1.5, 1.6]]),
+            force_multilabel_data=True), AttackType.THRESHOLD_ATTACK)
+    # PPV = TPR / (TPR + test_train_ratio * FPR), except when both TPR and FPR
+    # are close to 0, in which case PPV = 1 / (1 + test_train_ratio).
+    # With the above values, TPR / (TPR + test_train_ratio * FPR) =
+    # 0.1 / (0.1 + (3/4) * 0.2) = 0.4,
+    # 0.0001 / (0.0001 + (3/4) * 0.04) = 0.003322,
+    # and 1 / (1 + 0.75) = 0.57142. So PPV is the max of these three values,
+    # namely 0.57142.
+    np.testing.assert_almost_equal(
+        result.roc_curve.get_ppv(), 0.57142, decimal=2)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/privacy_report_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/privacy_report_test.py
index adb7d7c..19e90d0 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/privacy_report_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/privacy_report_test.py
@@ -38,7 +38,8 @@ class PrivacyReportTest(absltest.TestCase):
         roc_curve=RocCurve(
             tpr=np.array([0.0, 0.5, 1.0]),
             fpr=np.array([0.0, 0.5, 1.0]),
-            thresholds=np.array([0, 1, 2])),
+            thresholds=np.array([0, 1, 2]),
+            test_train_ratio=1.0),
         data_size=DataSize(ntrain=1, ntest=1))
 
     # Classifier that achieves an AUC of 1.0.
@@ -48,7 +49,8 @@ class PrivacyReportTest(absltest.TestCase):
         roc_curve=RocCurve(
             tpr=np.array([0.0, 1.0, 1.0]),
             fpr=np.array([1.0, 1.0, 0.0]),
-            thresholds=np.array([0, 1, 2])),
+            thresholds=np.array([0, 1, 2]),
+            test_train_ratio=1.0),
         data_size=DataSize(ntrain=1, ntest=1))
 
     self.results_epoch_0 = AttackResults(
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py
index 66d337e..cdef4f5 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py
@@ -90,7 +90,7 @@ class UtilsTest(absltest.TestCase):
     self.assertLen(att_types, 2)
     self.assertLen(att_slices, 2)
     self.assertLen(att_metrics, 2)
-    self.assertLen(att_values, 2)
+    self.assertLen(att_values, 3)  # Attacker Advantage, AUC, PPV
 
   def test_run_attack_on_tf_estimator_model(self):
     """Test the attack on the final models."""
@@ -110,7 +110,7 @@ class UtilsTest(absltest.TestCase):
     self.assertLen(att_types, 2)
     self.assertLen(att_slices, 2)
    self.assertLen(att_metrics, 2)
-    self.assertLen(att_values, 2)
+    self.assertLen(att_values, 3)  # Attacker Advantage, AUC, PPV
 
 
 if __name__ == '__main__':
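
For illustration (not part of the patch): below is a minimal, self-contained
sketch of the PPV metric this change implements. The `ppv` helper is
hypothetical; it mirrors the logic of `RocCurve.get_ppv` above, under the
assumption that the default tolerance `_ABSOLUTE_TOLERANCE = 1e-3` from
data_structures.py is used.

import numpy as np

_ABSOLUTE_TOLERANCE = 1e-3  # Same default tolerance as data_structures.py.


def ppv(tpr, fpr, test_train_ratio):
  """Max over thresholds of TPR / (TPR + gamma * FPR), gamma = ntest/ntrain.

  Where TPR and FPR are both ~0 at a threshold, the ratio is replaced by
  1 / (1 + gamma), following Jayaraman et al. (arXiv:2005.10881).
  """
  tpr = np.asarray(tpr, dtype=float)
  fpr = np.asarray(fpr, dtype=float)
  den = tpr + test_train_ratio * fpr
  both_zero = np.logical_and(
      np.isclose(tpr, 0.0, atol=_ABSOLUTE_TOLERANCE),
      np.isclose(fpr, 0.0, atol=_ABSOLUTE_TOLERANCE))
  # Safe elementwise division; entries where den == 0 are left as 0.
  ratio = np.divide(tpr, den, out=np.zeros_like(den), where=den != 0)
  return float(np.max(
      np.where(both_zero, 1.0 / (1.0 + test_train_ratio), ratio)))


# A random attacker on a balanced dataset (gamma = 1) scores 0.5.
print(ppv([0.0, 0.5, 1.0], [0.0, 0.5, 1.0], 1.0))   # 0.5
# The same attacker with a test set 10x the train set drops to 1/11.
print(ppv([0.0, 0.5, 1.0], [0.0, 0.5, 1.0], 10.0))  # ~0.0909

Taking the max over thresholds matches get_ppv's behavior of reporting the
best PPV the attacker can achieve at any decision threshold, and the second
example shows why the metric is informative on imbalanced data: the same ROC
curve yields a much weaker attack when the test set dominates.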