diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 02f1c44..2601b1e 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -532,7 +532,7 @@ class SingleAttackResult: @dataclass class SingleRiskScoreResult: """Results from computing privacy risk scores. - this part is quite preliminary: it shows how to leverage privacy risk score to perform attacks with thresholding on risk score + this part shows how to leverage privacy risk score to perform attacks with thresholding on risk score """ # Data slice this result was calculated for. @@ -543,6 +543,10 @@ class SingleRiskScoreResult: test_risk_scores: np.ndarray def attack_with_varied_thresholds(self, threshold_list): + """ For each threshold value, we count how many training and test samples with privacy risk scores larger than the threshold + and further compute precision and recall values. + We skip the threshold value if it is larger than every sample's privacy risk score. + """ precision_list = [] recall_list = [] meaningful_threshold_list = [] @@ -553,9 +557,13 @@ class SingleRiskScoreResult: meaningful_threshold_list.append(threshold) precision_list.append(true_positive_normalized/(true_positive_normalized+false_positive_normalized+0.0)) recall_list.append(true_positive_normalized) - return meaningful_threshold_list, precision_list, recall_list + return np.array(meaningful_threshold_list), np.array(precision_list), np.array(recall_list) - def print_results(self, threshold_list=[1,0.9,0.8,0.7,0.6,0.5]): + def print_results(self, threshold_list=np.array([1,0.9,0.8,0.7,0.6,0.5])): + """ The privacy risk score (from 0 to 1) represents each sample's probability of being in the training set. + Here, we choose a list of threshold values from 0.5 (uncertain of training or test) to 1 (100% certain of training) + to compute corresponding attack precision and recall. + """ meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(threshold_list) for i in range(len(meaningful_threshold_list)): print(f"with {meaningful_threshold_list[i]} as the threshold on privacy risk score, the precision-recall pair is {(precision_list[i], recall_list[i])}")