forked from 626_privacy/tensorflow_privacy
Merge pull request #146 from lwsong:master
PiperOrigin-RevId: 348448249
commit a3b64fd8f5
5 changed files with 1506 additions and 0 deletions
File diff suppressed because one or more lines are too long
@@ -468,6 +468,111 @@ class SingleAttackResult:
    ])


@dataclass
class SingleMembershipProbabilityResult:
  """Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).

  This class shows how to leverage membership probabilities to perform attacks
  by thresholding on them.
  """

  # Data slice this result was calculated for.
  slice_spec: SingleSliceSpec

  train_membership_probs: np.ndarray

  test_membership_probs: np.ndarray

  def attack_with_varied_thresholds(self, threshold_list):
    """Performs an attack with the specified thresholds.

    For each threshold value, we count how many training and test samples have
    membership probabilities larger than the threshold, and compute the
    corresponding precision and recall values. A threshold value is skipped if
    it is larger than every sample's membership probability.

    Args:
      threshold_list: List of provided thresholds.

    Returns:
      Arrays of meaningful thresholds and the corresponding precision and
      recall values.
    """
    fpr, tpr, thresholds = metrics.roc_curve(
        np.concatenate((np.ones(len(self.train_membership_probs)),
                        np.zeros(len(self.test_membership_probs)))),
        np.concatenate(
            (self.train_membership_probs, self.test_membership_probs)),
        drop_intermediate=False)

    precision_list = []
    recall_list = []
    meaningful_threshold_list = []
    max_prob = max(self.train_membership_probs.max(),
                   self.test_membership_probs.max())
    for threshold in threshold_list:
      if threshold <= max_prob:
        idx = np.argwhere(thresholds >= threshold)[-1][0]
        meaningful_threshold_list.append(threshold)
        precision_list.append(tpr[idx] / (tpr[idx] + fpr[idx]))
        recall_list.append(tpr[idx])

    return np.array(meaningful_threshold_list), np.array(
        precision_list), np.array(recall_list)

  def collect_results(self, threshold_list, return_roc_results=True):
    """Collects attack results over the given membership probability thresholds.

    The membership probability (from 0 to 1) represents each sample's
    probability of being in the training set. Usually, we choose a list of
    threshold values from 0.5 (uncertain of training or test) to 1 (100%
    certain of training) to compute the corresponding attack precision and
    recall.

    Args:
      threshold_list: List of provided thresholds.
      return_roc_results: Whether to also report ROC-based results (AUC and
        attacker advantage).

    Returns:
      A list of summary strings.
    """
    meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(
        threshold_list)
    summary = []
    summary.append('\nMembership probability analysis over slice: \"%s\"' %
                   str(self.slice_spec))
    for i in range(len(meaningful_threshold_list)):
      summary.append(
          ' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)'
          % (meaningful_threshold_list[i], precision_list[i], recall_list[i]))
    if return_roc_results:
      fpr, tpr, thresholds = metrics.roc_curve(
          np.concatenate((np.ones(len(self.train_membership_probs)),
                          np.zeros(len(self.test_membership_probs)))),
          np.concatenate(
              (self.train_membership_probs, self.test_membership_probs)))
      roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
      summary.append(
          ' thresholding on membership probability achieved an AUC of %.2f' %
          (roc_curve.get_auc()))
      summary.append(
          ' thresholding on membership probability achieved an advantage of %.2f'
          % (roc_curve.get_attacker_advantage()))
    return summary


@dataclass
class MembershipProbabilityResults:
  """Membership probability results from multiple data slices."""

  membership_prob_results: Iterable[SingleMembershipProbabilityResult]

  def summary(self, threshold_list):
    """Returns the summary of membership probability analyses on all slices."""
    summary = []
    for single_result in self.membership_prob_results:
      single_summary = single_result.collect_results(threshold_list)
      summary.extend(single_summary)
    return '\n'.join(summary)


@dataclass
class PrivacyReportMetadata:
  """Metadata about the evaluated model.
@@ -28,6 +28,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@@ -218,6 +219,25 @@ class SingleAttackResultTest(absltest.TestCase):
    self.assertEqual(result.get_attacker_advantage(), 0.0)


class SingleMembershipProbabilityResultTest(absltest.TestCase):

  # Only a basic test to check the attack by setting a threshold on
  # membership probability.
  def test_attack_with_varied_thresholds(self):

    result = SingleMembershipProbabilityResult(
        slice_spec=SingleSliceSpec(None),
        train_membership_probs=np.array([0.91, 1, 0.92, 0.82, 0.75]),
        test_membership_probs=np.array([0.81, 0.7, 0.75, 0.25, 0.3]))

    self.assertEqual(
        result.attack_with_varied_thresholds(
            threshold_list=np.array([0.8, 0.7]))[1].tolist(), [0.8, 0.625])
    self.assertEqual(
        result.attack_with_varied_thresholds(
            threshold_list=np.array([0.8, 0.7]))[2].tolist(), [0.8, 1])


class AttackResultsCollectionTest(absltest.TestCase):

  def __init__(self, *args, **kwargs):
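For readers verifying the expected numbers in test_attack_with_varied_thresholds: the implementation goes through sklearn's ROC curve, but for this particular data a direct count of samples at or above each threshold yields the same precision/recall pairs. A small illustrative check (not the production code path):

import numpy as np

train = np.array([0.91, 1, 0.92, 0.82, 0.75])
test = np.array([0.81, 0.7, 0.75, 0.25, 0.3])

for t in [0.8, 0.7]:
  true_pos = (train >= t).sum()   # members flagged as members
  false_pos = (test >= t).sum()   # non-members flagged as members
  print(t, true_pos / (true_pos + false_pos), true_pos / len(train))
# threshold 0.8 -> precision 4 / (4 + 1) = 0.8,   recall 4 / 5 = 0.8
# threshold 0.7 -> precision 5 / (5 + 3) = 0.625, recall 5 / 5 = 1.0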
@@ -27,10 +27,12 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
    PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs
@@ -182,6 +184,96 @@ def run_attacks(attack_input: AttackInputData,
      privacy_report_metadata=privacy_report_metadata)


def _compute_membership_probability(
    attack_input: AttackInputData,
    num_bins: int = 15) -> SingleMembershipProbabilityResult:
  """Computes each individual point's likelihood of being a member (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).

  For an individual sample, its privacy risk score is computed as the posterior
  probability of being in the training set after observing its prediction
  output by the target machine learning model.

  Args:
    attack_input: input data for computing membership probabilities.
    num_bins: the number of bins used to compute the training/test histogram.

  Returns:
    membership probability results.
  """

  # Uses the provided loss or entropy. Otherwise computes the loss.
  if attack_input.loss_train is not None and attack_input.loss_test is not None:
    train_values = attack_input.loss_train
    test_values = attack_input.loss_test
  elif attack_input.entropy_train is not None and attack_input.entropy_test is not None:
    train_values = attack_input.entropy_train
    test_values = attack_input.entropy_test
  else:
    train_values = attack_input.get_loss_train()
    test_values = attack_input.get_loss_test()

  # Compute the histogram in the log scale.
  small_value = 1e-10
  train_values = np.maximum(train_values, small_value)
  test_values = np.maximum(test_values, small_value)

  min_value = min(train_values.min(), test_values.min())
  max_value = max(train_values.max(), test_values.max())
  bins_hist = np.logspace(
      np.log10(min_value), np.log10(max_value), num_bins + 1)

  train_hist, _ = np.histogram(train_values, bins=bins_hist)
  train_hist = train_hist / (len(train_values) + 0.0)
  train_hist_indices = np.fmin(
      np.digitize(train_values, bins=bins_hist), num_bins) - 1

  test_hist, _ = np.histogram(test_values, bins=bins_hist)
  test_hist = test_hist / (len(test_values) + 0.0)
  test_hist_indices = np.fmin(
      np.digitize(test_values, bins=bins_hist), num_bins) - 1

  combined_hist = train_hist + test_hist
  combined_hist[combined_hist == 0] = small_value
  membership_prob_list = train_hist / (combined_hist + 0.0)
  train_membership_probs = membership_prob_list[train_hist_indices]
  test_membership_probs = membership_prob_list[test_hist_indices]

  return SingleMembershipProbabilityResult(
      slice_spec=_get_slice_spec(attack_input),
      train_membership_probs=train_membership_probs,
      test_membership_probs=test_membership_probs)


def run_membership_probability_analysis(
    attack_input: AttackInputData,
    slicing_spec: SlicingSpec = None) -> MembershipProbabilityResults:
  """Performs membership probability analysis on all given slice types.

  Args:
    attack_input: input data for computing membership probabilities.
    slicing_spec: specifies attack_input slices.

  Returns:
    the membership probability results.
  """
  attack_input.validate()
  membership_prob_results = []

  if slicing_spec is None:
    slicing_spec = SlicingSpec(entire_dataset=True)
  num_classes = None
  if slicing_spec.by_class:
    num_classes = attack_input.num_classes
  input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
  for single_slice_spec in input_slice_specs:
    attack_input_slice = get_slice(attack_input, single_slice_spec)
    membership_prob_results.append(
        _compute_membership_probability(attack_input_slice))

  return MembershipProbabilityResults(
      membership_prob_results=membership_prob_results)


def _compute_missing_privacy_report_metadata(
    metadata: PrivacyReportMetadata,
    attack_input: AttackInputData) -> PrivacyReportMetadata:
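A rough usage sketch of the new run_membership_probability_analysis entry point (not part of the commit; the loss arrays are invented, and the snippet assumes loss-only AttackInputData is accepted by validate()):

import numpy as np

# Invented per-example losses from a hypothetical target model.
attack_input = AttackInputData(
    loss_train=np.array([0.05, 0.1, 0.2, 0.4, 1.5]),
    loss_test=np.array([0.3, 0.9, 1.2, 2.5, 3.0]))

membership_results = run_membership_probability_analysis(
    attack_input, slicing_spec=SlicingSpec(entire_dataset=True))

# Thresholds from 0.5 (uncertain) to 1.0 (certain member), as suggested by
# the collect_results docstring.
print(membership_results.summary(
    threshold_list=np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])))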
@@ -112,6 +112,17 @@ class RunAttacksTest(absltest.TestCase):
    # If accuracy is already present, simply return it.
    self.assertIsNone(mia._get_accuracy(None, labels))

  def test_run_compute_membership_probability_correct_probs(self):
    result = mia._compute_membership_probability(
        AttackInputData(
            loss_train=np.array([1, 1, 1, 10, 100]),
            loss_test=np.array([10, 100, 100, 1000, 10000])))

    np.testing.assert_almost_equal(
        result.train_membership_probs, [1, 1, 1, 0.5, 0.33], decimal=2)
    np.testing.assert_almost_equal(
        result.test_membership_probs, [0.5, 0.33, 0.33, 0, 0], decimal=2)


if __name__ == '__main__':
  absltest.main()
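For reference, the expected probabilities in test_run_compute_membership_probability_correct_probs follow from the 15-bin log-scale histogram in _compute_membership_probability. A sketch that redoes the same arithmetic outside the library (illustration only):

import numpy as np

train = np.array([1, 1, 1, 10, 100])
test = np.array([10, 100, 100, 1000, 10000])

# 15 log-spaced bins between the smallest (1) and largest (10000) value.
bins = np.logspace(np.log10(1), np.log10(10000), 15 + 1)
train_hist = np.histogram(train, bins=bins)[0] / len(train)
test_hist = np.histogram(test, bins=bins)[0] / len(test)

combined = train_hist + test_hist
combined[combined == 0] = 1e-10      # avoid division by zero
bin_probs = train_hist / combined    # per-bin membership probability

# Bin containing loss 1:    train 3/5, test 0    -> probability 1
# Bin containing loss 10:   train 1/5, test 1/5  -> probability 0.5
# Bin containing loss 100:  train 1/5, test 2/5  -> probability 1/3 ~ 0.33
# Bins containing 1000 and 10000: train 0        -> probability 0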