Merge pull request #146 from lwsong:master

PiperOrigin-RevId: 348448249
This commit is contained in:
A. Unique TensorFlower 2020-12-21 04:36:33 -08:00
commit a3b64fd8f5
5 changed files with 1506 additions and 0 deletions

File diff suppressed because one or more lines are too long

View file

@ -468,6 +468,111 @@ class SingleAttackResult:
])
@dataclass
class SingleMembershipProbabilityResult:
  """Membership probabilities of the samples of a single data slice.

  Membership probability is denoted as privacy risk score in
  https://arxiv.org/abs/2003.10595. Besides holding the per-sample
  probabilities, this class shows how to leverage them to perform attacks by
  thresholding on them.
  """

  # Data slice this result was calculated for.
  slice_spec: SingleSliceSpec

  # Membership probabilities of the training samples.
  train_membership_probs: np.ndarray

  # Membership probabilities of the test samples.
  test_membership_probs: np.ndarray

  def _membership_labels_and_scores(self):
    """Returns (labels, scores) for ROC computation; label 1 marks training."""
    labels = np.concatenate((np.ones(len(self.train_membership_probs)),
                             np.zeros(len(self.test_membership_probs))))
    scores = np.concatenate(
        (self.train_membership_probs, self.test_membership_probs))
    return labels, scores

  def attack_with_varied_thresholds(self, threshold_list):
    """Performs an attack with the specified thresholds.

    For each threshold value, the ROC point of the smallest ROC threshold that
    is still >= the requested value is looked up, and precision and recall are
    derived from it. A threshold is skipped if it is larger than every
    sample's membership probability.

    NOTE: precision is computed as tpr / (tpr + fpr). This equals the
    count-based precision TP / (TP + FP) only when the training and test sets
    have the same number of samples.

    Args:
      threshold_list: List of provided thresholds

    Returns:
      A tuple of three arrays of equal length: the thresholds that were not
      skipped, the precision for each of them and the recall for each of them.
    """
    labels, scores = self._membership_labels_and_scores()
    # Keep every ROC point so that each requested threshold maps to the exact
    # operating point instead of a thinned-out curve.
    fpr, tpr, thresholds = metrics.roc_curve(
        labels, scores, drop_intermediate=False)

    max_prob = max(self.train_membership_probs.max(),
                   self.test_membership_probs.max())

    meaningful_threshold_list = []
    precision_list = []
    recall_list = []
    for threshold in threshold_list:
      if threshold <= max_prob:
        # Last index whose ROC threshold is still >= the requested threshold.
        idx = np.argwhere(thresholds >= threshold)[-1][0]
        meaningful_threshold_list.append(threshold)
        precision_list.append(tpr[idx] / (tpr[idx] + fpr[idx]))
        recall_list.append(tpr[idx])

    return np.array(meaningful_threshold_list), np.array(
        precision_list), np.array(recall_list)

  def collect_results(self, threshold_list, return_roc_results=True):
    """Summarizes threshold attacks on the membership probabilities.

    The membership probability (from 0 to 1) represents each sample's
    probability of being in the training set. Usually, we choose a list of
    threshold values from 0.5 (uncertain of training or test) to 1 (100%
    certain of training) to compute corresponding attack precision and recall.

    Args:
      threshold_list: List of provided thresholds
      return_roc_results: Whether to append the AUC and attacker advantage of
        the ROC curve to the summary

    Returns:
      A list of summary lines (strings), not a single joined string.
    """
    thresholds_used, precisions, recalls = self.attack_with_varied_thresholds(
        threshold_list)

    summary = []
    summary.append('\nMembership probability analysis over slice: \"%s\"' %
                   str(self.slice_spec))
    for threshold, precision, recall in zip(thresholds_used, precisions,
                                            recalls):
      summary.append(
          ' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)'
          % (threshold, precision, recall))

    if return_roc_results:
      labels, scores = self._membership_labels_and_scores()
      fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
      roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
      summary.append(
          ' thresholding on membership probability achieved an AUC of %.2f' %
          (roc_curve.get_auc()))
      summary.append(
          ' thresholding on membership probability achieved an advantage of %.2f'
          % (roc_curve.get_attacker_advantage()))
    return summary
@dataclass
class MembershipProbabilityResults:
  """Container for the membership probability results of several slices."""

  # One result per analyzed data slice.
  membership_prob_results: Iterable[SingleMembershipProbabilityResult]

  def summary(self, threshold_list):
    """Builds a single report string covering all slices.

    Args:
      threshold_list: thresholds forwarded to every slice's collect_results.

    Returns:
      The summary lines of all slices, joined with newlines.
    """
    lines = []
    for slice_result in self.membership_prob_results:
      lines += slice_result.collect_results(threshold_list)
    return '\n'.join(lines)
@dataclass
class PrivacyReportMetadata:
"""Metadata about the evaluated model.

View file

@ -28,6 +28,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@ -218,6 +219,25 @@ class SingleAttackResultTest(absltest.TestCase):
self.assertEqual(result.get_attacker_advantage(), 0.0)
class SingleMembershipProbabilityResultTest(absltest.TestCase):
  """Tests for SingleMembershipProbabilityResult."""

  # Only a basic test to check the attack by setting a threshold on
  # membership probability.
  def test_attack_with_varied_thresholds(self):
    result = SingleMembershipProbabilityResult(
        slice_spec=SingleSliceSpec(None),
        train_membership_probs=np.array([0.91, 1, 0.92, 0.82, 0.75]),
        test_membership_probs=np.array([0.81, 0.7, 0.75, 0.25, 0.3]))

    # Run the attack once and check every returned array.
    thresholds, precisions, recalls = result.attack_with_varied_thresholds(
        threshold_list=np.array([0.8, 0.7]))

    # Both thresholds are below the largest probability, so none is skipped.
    np.testing.assert_allclose(thresholds, [0.8, 0.7])
    # Compare with a tolerance: the values come from float divisions, so exact
    # list equality would be brittle.
    np.testing.assert_allclose(precisions, [0.8, 0.625])
    np.testing.assert_allclose(recalls, [0.8, 1])
class AttackResultsCollectionTest(absltest.TestCase):
def __init__(self, *args, **kwargs):

View file

@ -27,10 +27,12 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_single_slice_specs
@ -182,6 +184,96 @@ def run_attacks(attack_input: AttackInputData,
privacy_report_metadata=privacy_report_metadata)
def _bin_frequencies_and_indices(values: np.ndarray, bin_edges: np.ndarray,
                                 num_bins: int):
  """Returns the per-bin sample frequencies and each sample's bin index."""
  # np.digitize is 1-based and yields 0 / num_bins + 1 for values outside the
  # edges. The edges are built as 10 ** log10(x), which does not always
  # round-trip exactly, so the extreme samples can land just outside. Clipping
  # keeps them in the first / last bin instead of producing index -1 (which
  # would silently read the last bin) and also folds the right-most edge into
  # the last bin.
  indices = np.clip(np.digitize(values, bins=bin_edges) - 1, 0, num_bins - 1)
  # Count from the very same indices so that the histogram and the per-sample
  # lookups can never disagree and no sample is dropped.
  frequencies = np.bincount(indices.ravel(), minlength=num_bins) / len(values)
  return frequencies, indices


def _compute_membership_probability(
    attack_input: AttackInputData,
    num_bins: int = 15) -> SingleMembershipProbabilityResult:
  """Computes each individual point's likelihood of being a member.

  The likelihood is denoted as privacy risk score in
  https://arxiv.org/abs/2003.10595. For an individual sample, its privacy risk
  score is computed as the posterior probability of being in the training set
  after observing its prediction output by the target machine learning model.

  Concretely, the train and test values (loss or entropy) are binned into a
  shared log-scale histogram, and the membership probability of a bin is
  train_frequency / (train_frequency + test_frequency). Every sample receives
  the probability of the bin it falls into.

  Args:
    attack_input: input data for compute membership probability
    num_bins: the number of bins used to compute the training/test histogram

  Returns:
    membership probability results
  """
  # Uses the provided loss or entropy. Otherwise computes the loss.
  if attack_input.loss_train is not None and attack_input.loss_test is not None:
    train_values = attack_input.loss_train
    test_values = attack_input.loss_test
  elif (attack_input.entropy_train is not None and
        attack_input.entropy_test is not None):
    train_values = attack_input.entropy_train
    test_values = attack_input.entropy_test
  else:
    train_values = attack_input.get_loss_train()
    test_values = attack_input.get_loss_test()

  # Compute the histogram in the log scale. Values are clamped from below so
  # that their logarithm is finite.
  small_value = 1e-10
  train_values = np.maximum(train_values, small_value)
  test_values = np.maximum(test_values, small_value)

  min_value = min(train_values.min(), test_values.min())
  max_value = max(train_values.max(), test_values.max())
  bins_hist = np.logspace(
      np.log10(min_value), np.log10(max_value), num_bins + 1)

  train_hist, train_hist_indices = _bin_frequencies_and_indices(
      train_values, bins_hist, num_bins)
  test_hist, test_hist_indices = _bin_frequencies_and_indices(
      test_values, bins_hist, num_bins)

  combined_hist = train_hist + test_hist
  # Bins without any sample get a tiny denominator to avoid dividing by zero;
  # their numerator is 0, so the resulting probability is 0.
  combined_hist[combined_hist == 0] = small_value
  membership_prob_list = train_hist / combined_hist

  train_membership_probs = membership_prob_list[train_hist_indices]
  test_membership_probs = membership_prob_list[test_hist_indices]

  return SingleMembershipProbabilityResult(
      slice_spec=_get_slice_spec(attack_input),
      train_membership_probs=train_membership_probs,
      test_membership_probs=test_membership_probs)
def run_membership_probability_analysis(
    attack_input: AttackInputData,
    slicing_spec: SlicingSpec = None) -> MembershipProbabilityResults:
  """Runs the membership probability analysis on every requested slice.

  Args:
    attack_input: input data used to compute membership probabilities; it is
      validated before anything else happens
    slicing_spec: describes how attack_input is sliced; when None, only the
      entire dataset is analyzed

  Returns:
    the membership probability results, one entry per slice.
  """
  attack_input.validate()

  if slicing_spec is None:
    slicing_spec = SlicingSpec(entire_dataset=True)

  # The number of classes is only looked up when slicing by class.
  num_classes = attack_input.num_classes if slicing_spec.by_class else None
  slice_specs = get_single_slice_specs(slicing_spec, num_classes)

  per_slice_results = [
      _compute_membership_probability(get_slice(attack_input, spec))
      for spec in slice_specs
  ]
  return MembershipProbabilityResults(
      membership_prob_results=per_slice_results)
def _compute_missing_privacy_report_metadata(
metadata: PrivacyReportMetadata,
attack_input: AttackInputData) -> PrivacyReportMetadata:

View file

@ -112,6 +112,17 @@ class RunAttacksTest(absltest.TestCase):
# If accuracy is already present, simply return it.
self.assertIsNone(mia._get_accuracy(None, labels))
def test_run_compute_membership_probability_correct_probs(self):
  """Checks the membership probabilities computed from hand-picked losses."""
  attack_input = AttackInputData(
      loss_train=np.array([1, 1, 1, 10, 100]),
      loss_test=np.array([10, 100, 100, 1000, 10000]))
  result = mia._compute_membership_probability(attack_input)

  expected_train_probs = [1, 1, 1, 0.5, 0.33]
  expected_test_probs = [0.5, 0.33, 0.33, 0, 0]
  np.testing.assert_almost_equal(
      result.train_membership_probs, expected_train_probs, decimal=2)
  np.testing.assert_almost_equal(
      result.test_membership_probs, expected_test_probs, decimal=2)
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()