update code

Liwei Song 2020-12-16 15:47:15 -05:00
parent bcee3f7a09
commit a4d108f270
3 changed files with 44 additions and 32 deletions


@@ -462,21 +462,28 @@ class SingleRiskScoreResult:
     and further compute precision and recall values.
     We skip the threshold value if it is larger than every sample's privacy risk score.
     """
+    fpr, tpr, thresholds = metrics.roc_curve(
+        np.concatenate((np.ones(len(self.train_risk_scores)),
+                        np.zeros(len(self.test_risk_scores)))),
+        np.concatenate((self.train_risk_scores, self.test_risk_scores)),
+        drop_intermediate=False)
     precision_list = []
     recall_list = []
     meaningful_threshold_list = []
+    max_risk_score = max(self.train_risk_scores.max(), self.test_risk_scores.max())
     for threshold in threshold_list:
-      true_positive_normalized = np.sum(self.train_risk_scores>=threshold)/(len(self.train_risk_scores)+0.0)
-      false_positive_normalized = np.sum(self.test_risk_scores>=threshold)/(len(self.test_risk_scores)+0.0)
-      if true_positive_normalized+false_positive_normalized>0:
-        meaningful_threshold_list.append(threshold)
-        precision_list.append(true_positive_normalized/(true_positive_normalized+false_positive_normalized+0.0))
-        recall_list.append(true_positive_normalized)
+      if threshold <= max_risk_score:
+        idx = np.argwhere(thresholds>=threshold)[-1][0]
+        meaningful_threshold_list.append(threshold)
+        precision_list.append(tpr[idx]/(tpr[idx]+fpr[idx]))
+        recall_list.append(tpr[idx])
     return np.array(meaningful_threshold_list), np.array(precision_list), np.array(recall_list)
 
-  def collect_results(self, threshold_list=np.array([1,0.9,0.8,0.7,0.6,0.5])):
+  def collect_results(self, threshold_list):
     """ The privacy risk score (from 0 to 1) represents each sample's probability of being in the training set.
-    Here, we choose a list of threshold values from 0.5 (uncertain of training or test) to 1 (100% certain of training)
+    Usually, we choose a list of threshold values from 0.5 (uncertain of training or test) to 1 (100% certain of training)
     to compute corresponding attack precision and recall.
     """
     meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(threshold_list)
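The rewritten attack_with_varied_thresholds reads precision and recall off a single ROC curve instead of re-scanning the score arrays for every threshold. A minimal standalone sketch of the same computation, using the score arrays from the updated unit test further down (as before the change, counts are normalized by each set's size, so precision = tpr/(tpr+fpr)):

    # Standalone sketch of the new ROC-based threshold attack.
    import numpy as np
    from sklearn import metrics

    train_risk_scores = np.array([0.91, 1, 0.92, 0.82, 0.75])
    test_risk_scores = np.array([0.81, 0.7, 0.75, 0.25, 0.3])

    # Label training samples 1 and test samples 0, then build one ROC curve.
    labels = np.concatenate((np.ones(len(train_risk_scores)),
                             np.zeros(len(test_risk_scores))))
    scores = np.concatenate((train_risk_scores, test_risk_scores))
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores, drop_intermediate=False)

    for threshold in [0.8, 0.7]:
        # Smallest ROC threshold that is still >= the requested threshold.
        idx = np.argwhere(thresholds >= threshold)[-1][0]
        print(threshold, tpr[idx] / (tpr[idx] + fpr[idx]), tpr[idx])
    # Prints precision/recall pairs (0.8, 0.8) and (0.625, 1.0),
    # matching the expected values in the updated unit test.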
@@ -496,14 +503,13 @@ class RiskScoreResults:
   risk_score_results: Iterable[SingleRiskScoreResult]
 
-  def summary(self):
+  def summary(self, threshold_list):
     """ return the summary of privacy risk score analysis on all given data slices
     """
     summary = []
     for single_result in self.risk_score_results:
-      single_summary = single_result.collect_results()
-      for line in single_summary:
-        summary.append(line)
+      single_summary = single_result.collect_results(threshold_list)
+      summary.extend(single_summary)
     return '\n'.join(summary)
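The inner loop over single_summary is collapsed into list.extend, which appends each element of an iterable and is equivalent here because collect_results returns a flat list of lines. A tiny sketch (the summary strings are purely illustrative; their exact wording is not shown in this diff):

    summary = []
    single_summary = ['precision = 0.8', 'recall = 0.8']  # illustrative lines only
    # Equivalent to: for line in single_summary: summary.append(line)
    summary.extend(single_summary)
    print('\n'.join(summary))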


@@ -229,8 +229,8 @@ class SingleRiskScoreResultTest(absltest.TestCase):
         train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]),
         test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3]))
-    self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[1].tolist(), [0.8,0.625])
-    self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[2].tolist(), [0.8,1])
+    self.assertEqual(result.attack_with_varied_thresholds(threshold_list=np.array([0.8,0.7]))[1].tolist(), [0.8,0.625])
+    self.assertEqual(result.attack_with_varied_thresholds(threshold_list=np.array([0.8,0.7]))[2].tolist(), [0.8,1])
 
 class AttackResultsCollectionTest(absltest.TestCase):
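These expected values can be verified by hand: at threshold 0.8, four of the five train scores and one of the five test scores (0.81) are at least 0.8, giving recall 4/5 = 0.8 and precision 0.8/(0.8 + 0.2) = 0.8; at threshold 0.7, all five train scores and three of the five test scores qualify, giving recall 1 and precision 1/(1 + 0.6) = 0.625.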


@@ -176,10 +176,13 @@ def run_attacks(attack_input: AttackInputData,
 def _compute_privacy_risk_score(attack_input: AttackInputData,
                                 num_bins: int = 15) -> SingleRiskScoreResult:
-  """compute each individual point's likelihood of being a member (https://arxiv.org/abs/2003.10595)
+  """Computes each individual point's likelihood of being a member (https://arxiv.org/abs/2003.10595).
+
+  For an individual sample, its privacy risk score is computed as the posterior probability of being in the training set
+  after observing its prediction output by the target machine learning model.
 
   Args:
     attack_input: input data for computing privacy risk scores
-    num_bins: the number of bins used to compute the training/test histogram; we set the default as 15
+    num_bins: the number of bins used to compute the training/test histogram
 
   Returns:
     privacy risk score results
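Concretely, the combined histogram built below implements this posterior with equal priors on membership: per bin, the risk score is train_hist / (train_hist + test_hist). A small sketch with made-up bin frequencies:

    import numpy as np

    # Hypothetical per-bin frequencies of train and test loss values (each sums to 1).
    train_hist = np.array([0.5, 0.3, 0.2])
    test_hist = np.array([0.1, 0.3, 0.6])

    # With equal priors P(member) = P(non-member) = 0.5, Bayes' rule reduces to:
    risk_score_per_bin = train_hist / (train_hist + test_hist)
    print(risk_score_per_bin)  # [0.8333... 0.5 0.25]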
@@ -188,28 +191,31 @@ def _compute_privacy_risk_score(attack_input: AttackInputData,
   # If the loss or the entropy is provided, just use it;
   # Otherwise, call the function to compute the loss (you can also choose to compute entropy)
   if attack_input.loss_train is not None and attack_input.loss_test is not None:
-    train_values, test_values = attack_input.loss_train, attack_input.loss_test
+    train_values = attack_input.loss_train
+    test_values = attack_input.loss_test
   elif attack_input.entropy_train is not None and attack_input.entropy_test is not None:
-    train_values, test_values = attack_input.entropy_train, attack_input.entropy_test
+    train_values = attack_input.entropy_train
+    test_values = attack_input.entropy_test
   else:
-    train_values, test_values = attack_input.get_loss_train(), attack_input.get_loss_test()
+    train_values = attack_input.get_loss_train()
+    test_values = attack_input.get_loss_test()
 
   # Compute the histogram in the log scale
   small_value = 1e-10
-  train_log_values = np.log(np.maximum(train_values, small_value))
-  test_log_values = np.log(np.maximum(test_values, small_value))
-  min_log_value = np.amin(np.concatenate((train_log_values, test_log_values)))
-  max_log_value = np.amax(np.concatenate((train_log_values, test_log_values)))
-  bins_hist = np.linspace(min_log_value, max_log_value, num_bins+1)
-  train_hist, _ = np.histogram(train_log_values, bins=bins_hist)
-  train_hist = train_hist/(len(train_log_values)+0.0)
-  train_hist_indices = np.fmin(np.digitize(train_log_values, bins=bins_hist),num_bins)-1
-  test_hist, _ = np.histogram(test_log_values, bins=bins_hist)
-  test_hist = test_hist/(len(test_log_values)+0.0)
-  test_hist_indices = np.fmin(np.digitize(test_log_values, bins=bins_hist),num_bins)-1
+  train_values = np.maximum(train_values, small_value)
+  test_values = np.maximum(test_values, small_value)
+  min_value = min(train_values.min(), test_values.min())
+  max_value = max(train_values.max(), test_values.max())
+  bins_hist = np.logspace(np.log10(min_value), np.log10(max_value), num_bins+1)
+  train_hist, _ = np.histogram(train_values, bins=bins_hist)
+  train_hist = train_hist/(len(train_values)+0.0)
+  train_hist_indices = np.fmin(np.digitize(train_values, bins=bins_hist),num_bins)-1
+  test_hist, _ = np.histogram(test_values, bins=bins_hist)
+  test_hist = test_hist/(len(test_values)+0.0)
+  test_hist_indices = np.fmin(np.digitize(test_values, bins=bins_hist),num_bins)-1
 
   combined_hist = train_hist+test_hist
   combined_hist[combined_hist==0] = small_value
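The binning change keeps the loss values in their original scale and builds geometrically spaced bin edges with np.logspace, instead of linearly spaced edges over log-transformed values; both constructions yield the same bins. A quick check of that equivalence (the values are made up):

    import numpy as np

    values = np.array([0.01, 0.1, 0.5, 2.0, 9.0])
    num_bins = 15

    # Old approach: linear edges over log-transformed values, mapped back.
    log_edges = np.linspace(np.log(values.min()), np.log(values.max()), num_bins + 1)

    # New approach: log-spaced edges over the raw values.
    raw_edges = np.logspace(np.log10(values.min()), np.log10(values.max()), num_bins + 1)

    print(np.allclose(np.exp(log_edges), raw_edges))  # True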
@@ -224,8 +230,8 @@ def _compute_privacy_risk_score(attack_input: AttackInputData,
       test_risk_scores=test_risk_scores)
 
-def privacy_risk_score_analysis(attack_input: AttackInputData,
-                                slicing_spec: SlicingSpec = None) -> RiskScoreResults:
+def run_privacy_risk_score_analysis(attack_input: AttackInputData,
+                                    slicing_spec: SlicingSpec = None) -> RiskScoreResults:
   """Perform privacy risk score analysis on all given slice types