update risk score analysis

This commit is contained in:
Liwei Song 2020-12-17 15:18:02 -05:00
parent fd0ae811a6
commit b1993344cf

View file

@ -481,7 +481,7 @@ class SingleRiskScoreResult:
return np.array(meaningful_threshold_list), np.array(precision_list), np.array(recall_list) return np.array(meaningful_threshold_list), np.array(precision_list), np.array(recall_list)
def collect_results(self, threshold_list): def collect_results(self, threshold_list, return_roc_results=True):
""" The privacy risk score (from 0 to 1) represents each sample's probability of being in the training set. """ The privacy risk score (from 0 to 1) represents each sample's probability of being in the training set.
Usually, we choose a list of threshold values from 0.5 (uncertain of training or test) to 1 (100% certain of training) Usually, we choose a list of threshold values from 0.5 (uncertain of training or test) to 1 (100% certain of training)
to compute corresponding attack precision and recall. to compute corresponding attack precision and recall.
@ -493,6 +493,14 @@ class SingleRiskScoreResult:
for i in range(len(meaningful_threshold_list)): for i in range(len(meaningful_threshold_list)):
summary.append(' with %.5f as the threshold on privacy risk score, the precision-recall pair is (%.5f, %.5f)' % summary.append(' with %.5f as the threshold on privacy risk score, the precision-recall pair is (%.5f, %.5f)' %
(meaningful_threshold_list[i], precision_list[i], recall_list[i])) (meaningful_threshold_list[i], precision_list[i], recall_list[i]))
if return_roc_results:
fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.ones(len(self.train_risk_scores)),
np.zeros(len(self.test_risk_scores)))),
np.concatenate((self.train_risk_scores, self.test_risk_scores)))
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
summary.append(' thresholding on privacy risk score achieved an AUC of %.2f' %(roc_curve.get_auc()))
summary.append(' thresholding on privacy risk score achieved an advantage of %.2f' %(roc_curve.get_attacker_advantage()))
return summary return summary