Add Positive Predictive Value as a metric for membership attack models performance on imbalanced data.

PiperOrigin-RevId: 461390184
This commit is contained in:
A. Unique TensorFlower 2022-07-16 16:30:17 -07:00
parent 328795aa36
commit 2b5d5b6ef5
7 changed files with 338 additions and 48 deletions

View file

@ -26,7 +26,10 @@ import numpy as np
import pandas as pd import pandas as pd
from scipy import special from scipy import special
from sklearn import metrics from sklearn import metrics
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils as utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
# The minimum TPR or FPR below which they are considered equal.
_ABSOLUTE_TOLERANCE = 1e-3
ENTIRE_DATASET_SLICE_STR = 'Entire dataset' ENTIRE_DATASET_SLICE_STR = 'Entire dataset'
@ -116,7 +119,6 @@ class AttackType(enum.Enum):
K_NEAREST_NEIGHBORS = 'knn' K_NEAREST_NEIGHBORS = 'knn'
THRESHOLD_ATTACK = 'threshold' THRESHOLD_ATTACK = 'threshold'
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy' THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
TF_LOGISTIC_REGRESSION = 'tf_lr'
@property @property
def is_trained_attack(self): def is_trained_attack(self):
@ -133,6 +135,7 @@ class PrivacyMetric(enum.Enum):
"""An enum for the supported privacy risk metrics.""" """An enum for the supported privacy risk metrics."""
AUC = 'AUC' AUC = 'AUC'
ATTACKER_ADVANTAGE = 'Attacker advantage' ATTACKER_ADVANTAGE = 'Attacker advantage'
PPV = 'Positive predictive value'
def __str__(self): def __str__(self):
"""Returns 'AUC' instead of PrivacyMetric.AUC.""" """Returns 'AUC' instead of PrivacyMetric.AUC."""
@ -627,6 +630,11 @@ class RocCurve:
# False positive rates based on thresholds # False positive rates based on thresholds
fpr: np.ndarray fpr: np.ndarray
# Ratio of test to train set size.
# In Jayaraman et al. (https://arxiv.org/pdf/2005.10881.pdf) it is referred to
# as 'gamma' (see Table 1 for the definition).
test_train_ratio: np.float64
def get_auc(self): def get_auc(self):
"""Calculates area under curve (aka AUC).""" """Calculates area under curve (aka AUC)."""
return metrics.auc(self.fpr, self.tpr) return metrics.auc(self.fpr, self.tpr)
@ -643,12 +651,69 @@ class RocCurve:
""" """
return max(np.abs(self.tpr - self.fpr)) return max(np.abs(self.tpr - self.fpr))
def get_ppv(self) -> float:
  """Calculates Positive Predictive Value of the membership attacker.

  The Positive Predictive Value (PPV) is the proportion of positive
  predictions that are true positives. It can be expressed as
  PPV=TP/(TP+FP), which in terms of rates is computed here as
  PPV = TPR / (TPR + test_train_ratio * FPR).
  It was suggested by Jayaraman et al. (https://arxiv.org/pdf/2005.10881.pdf)
  that this would be a suitable metric for membership attacks on datasets
  where the number of samples from the training set and the number of samples
  from the test set are very different. These are referred to as imbalanced
  datasets.

  Returns:
    A single float number for the Positive Predictive Value: the maximum of
    the PPV values computed at every point of the ROC curve.
  """
  # Work in floating point regardless of the dtype of the stored rates.
  # Integer-valued rates combined with an integer ratio would otherwise make
  # the `out` buffer of `np.divide` an integer array, which cannot hold the
  # result of a true division.
  tpr = np.asarray(self.tpr, dtype=np.float64)
  fpr = np.asarray(self.fpr, dtype=np.float64)
  num = tpr
  den = tpr + fpr * self.test_train_ratio
  # There is a special case when both `num` and `den` are 0. Both would be 0
  # when TPR and FPR are both 0, since test_train_ratio is strictly positive
  # (exclude the case when there is no test set). Then TPR = 0 means that all
  # positive (train set) examples are misclassified and FPR = 0 means that all
  # negatives (test set) were correctly classified.
  # Consider that when TPR and FPR are close to 0, TPR ~ FPR. Call this value
  # 'R'. So the expression for PPV can be rewritten as:
  # PPV = R / (R + test_train_ratio * R) = 1 / (1 + test_train_ratio).
  # We can check this expression when test_train_ratio = 0, i.e. there is no
  # test set: then PPV = 1 (perfect classification). When
  # test_train_ratio >> 0, i.e. the test set size >> train set size, PPV
  # approaches 0 (perfect mis-classification).
  # When test_train_ratio = 1, test and train sets are of the same size, and
  # PPV = 0.5 (random guessing). This is because TPR = 0 means all positives
  # are misclassified (i.e. classified as negatives) and FPR = 0 means all
  # negatives are correctly classified (i.e. classified as negatives).
  # The normal case is when both `num` and `den` are not 0, and PPV is just
  # the ratio of `num` to `den`.

  # Find when `tpr` and `fpr` are 0 (within the module-level tolerance, which
  # is looked up at call time).
  tpr_is_0 = np.isclose(tpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
  fpr_is_0 = np.isclose(fpr, 0.0, atol=_ABSOLUTE_TOLERANCE)
  tpr_and_fpr_both_0 = np.logical_and(tpr_is_0, fpr_is_0)
  # PPV when both are zero is given by the expression below.
  ppv_when_tpr_fpr_both_0 = 1. / (1. + self.test_train_ratio)
  # PPV when one is not zero is given by the expression below. Entries with a
  # zero denominator are left at 0 instead of dividing by zero.
  ppv_when_one_of_tpr_fpr_not_0 = np.divide(
      num, den, out=np.zeros_like(den), where=den != 0)
  # Report the best PPV over all thresholds as a plain Python float.
  return float(
      np.max(
          np.where(tpr_and_fpr_both_0, ppv_when_tpr_fpr_both_0,
                   ppv_when_one_of_tpr_fpr_not_0)))
def __str__(self): def __str__(self):
"""Returns AUC and advantage metrics.""" """Returns AUC, advantage and PPV metrics."""
return '\n'.join([ return '\n'.join([
'RocCurve(', 'RocCurve(',
' AUC: %.2f' % self.get_auc(), ' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')' ' Attacker advantage: %.2f' % self.get_attacker_advantage(),
' Positive predictive value: %.2f' % self.get_ppv(), ')'
]) ])
@ -695,6 +760,11 @@ class SingleAttackResult:
def get_attacker_advantage(self): def get_attacker_advantage(self):
return self.roc_curve.get_attacker_advantage() return self.roc_curve.get_attacker_advantage()
def get_ppv(self) -> float:
  """Returns the positive predictive value of this result's ROC curve.

  Raises:
    ValueError: If the recorded training data size is zero.
  """
  if self.data_size.ntrain != 0:
    return self.roc_curve.get_ppv()
  raise ValueError('Size of the training data cannot be zero.')
def get_auc(self): def get_auc(self):
return self.roc_curve.get_auc() return self.roc_curve.get_auc()
@ -707,7 +777,8 @@ class SingleAttackResult:
(self.data_size.ntrain, self.data_size.ntest), (self.data_size.ntrain, self.data_size.ntest),
' AttackType: %s' % str(self.attack_type), ' AttackType: %s' % str(self.attack_type),
' AUC: %.2f' % self.get_auc(), ' AUC: %.2f' % self.get_auc(),
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')' ' Attacker advantage: %.2f' % self.get_attacker_advantage(),
' Positive Predictive Value: %.2f' % self.get_ppv(), ')'
]) ])
@ -791,7 +862,14 @@ class SingleMembershipProbabilityResult:
np.zeros(len(self.test_membership_probs)))), np.zeros(len(self.test_membership_probs)))),
np.concatenate( np.concatenate(
(self.train_membership_probs, self.test_membership_probs))) (self.train_membership_probs, self.test_membership_probs)))
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) ntrain = np.shape(self.train_membership_probs)[0]
ntest = np.shape(self.test_membership_probs)[0]
test_train_ratio = ntest / ntrain
roc_curve = RocCurve(
tpr=tpr,
fpr=fpr,
thresholds=thresholds,
test_train_ratio=test_train_ratio)
summary.append( summary.append(
' thresholding on membership probability achieved an AUC of %.2f' % ' thresholding on membership probability achieved an AUC of %.2f' %
(roc_curve.get_auc())) (roc_curve.get_auc()))
@ -860,6 +938,7 @@ class AttackResults:
data_size_test = [] data_size_test = []
attack_types = [] attack_types = []
advantages = [] advantages = []
ppvs = []
aucs = [] aucs = []
for attack_result in self.single_attack_results: for attack_result in self.single_attack_results:
@ -874,6 +953,7 @@ class AttackResults:
data_size_test.append(attack_result.data_size.ntest) data_size_test.append(attack_result.data_size.ntest)
attack_types.append(str(attack_result.attack_type)) attack_types.append(str(attack_result.attack_type))
advantages.append(float(attack_result.get_attacker_advantage())) advantages.append(float(attack_result.get_attacker_advantage()))
ppvs.append(float(attack_result.get_ppv()))
aucs.append(float(attack_result.get_auc())) aucs.append(float(attack_result.get_auc()))
df = pd.DataFrame({ df = pd.DataFrame({
@ -883,6 +963,7 @@ class AttackResults:
str(AttackResultsDFColumns.DATA_SIZE_TEST): data_size_test, str(AttackResultsDFColumns.DATA_SIZE_TEST): data_size_test,
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types, str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages, str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
str(PrivacyMetric.PPV): ppvs,
str(PrivacyMetric.AUC): aucs str(PrivacyMetric.AUC): aucs
}) })
return df return df
@ -918,6 +999,14 @@ class AttackResults:
max_advantage_result_all.get_attacker_advantage(), max_advantage_result_all.get_attacker_advantage(),
max_advantage_result_all.slice_spec)) max_advantage_result_all.slice_spec))
max_ppv_result_all = self.get_result_with_max_ppv()
summary.append(
' %s (with %d training and %d test examples) achieved a positive '
'predictive value of %.2f on slice %s' %
(max_ppv_result_all.attack_type, max_ppv_result_all.data_size.ntrain,
max_ppv_result_all.data_size.ntest, max_ppv_result_all.get_ppv(),
max_ppv_result_all.slice_spec))
slice_dict = self._group_results_by_slice() slice_dict = self._group_results_by_slice()
if by_slices and len(slice_dict.keys()) > 1: if by_slices and len(slice_dict.keys()) > 1:
@ -937,6 +1026,12 @@ class AttackResults:
max_advantage_result.data_size.ntrain, max_advantage_result.data_size.ntrain,
max_auc_result.data_size.ntest, max_auc_result.data_size.ntest,
max_advantage_result.get_attacker_advantage())) max_advantage_result.get_attacker_advantage()))
max_ppv_result = results.get_result_with_max_ppv()
summary.append(
' %s (with %d training and %d test examples) achieved a positive '
'predictive value of %.2f' %
(max_ppv_result.attack_type, max_ppv_result.data_size.ntrain,
max_ppv_result.data_size.ntest, max_ppv_result.get_ppv()))
return '\n'.join(summary) return '\n'.join(summary)
@ -966,6 +1061,11 @@ class AttackResults:
result.get_attacker_advantage() for result in self.single_attack_results result.get_attacker_advantage() for result in self.single_attack_results
])] ])]
def get_result_with_max_ppv(self) -> SingleAttackResult:
  """Returns the attack result with the highest positive predictive value.

  All entries of `single_attack_results` (every attack and slice) are
  considered; the entry at the index chosen by `np.argmax` is returned.
  """
  ppv_per_result = [
      candidate.get_ppv() for candidate in self.single_attack_results
  ]
  best_index = np.argmax(ppv_per_result)
  return self.single_attack_results[best_index]
def save(self, filepath): def save(self, filepath):
"""Saves self to a pickle file.""" """Saves self to a pickle file."""
with open(filepath, 'wb') as out: with open(filepath, 'wb') as out:
@ -1035,6 +1135,7 @@ def get_flattened_attack_metrics(results: AttackResults):
attack_metrics += ['adv', 'auc'] attack_metrics += ['adv', 'auc']
values += [ values += [
float(attack_result.get_attacker_advantage()), float(attack_result.get_attacker_advantage()),
float(attack_result.get_auc()) float(attack_result.get_auc()),
float(attack_result.get_ppv()),
] ]
return types, slices, attack_metrics, values return types, slices, attack_metrics, values

View file

@ -14,6 +14,7 @@
import os import os
import tempfile import tempfile
from unittest import mock
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
@ -394,13 +395,14 @@ class AttackInputDataTest(parameterized.TestCase):
'"force_multilabel_data" is True but "multilabel_data" is False.') '"force_multilabel_data" is True but "multilabel_data" is False.')
class RocCurveTest(absltest.TestCase): class RocCurveTest(parameterized.TestCase):
def test_auc_random_classifier(self): def test_auc_random_classifier(self):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
self.assertEqual(roc.get_auc(), 0.5) self.assertEqual(roc.get_auc(), 0.5)
@ -408,7 +410,8 @@ class RocCurveTest(absltest.TestCase):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
self.assertEqual(roc.get_auc(), 1.0) self.assertEqual(roc.get_auc(), 1.0)
@ -416,7 +419,8 @@ class RocCurveTest(absltest.TestCase):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
self.assertEqual(roc.get_attacker_advantage(), 0.0) self.assertEqual(roc.get_attacker_advantage(), 0.0)
@ -424,10 +428,59 @@ class RocCurveTest(absltest.TestCase):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
self.assertEqual(roc.get_auc(), 1.0) self.assertEqual(roc.get_auc(), 1.0)
def test_ppv_random_classifier(self):
  """A diagonal ROC curve with a test-train ratio of 1 has a PPV of 0.5."""
  diagonal_rates = [0.0, 0.5, 1.0]
  curve = RocCurve(
      tpr=np.array(diagonal_rates),
      fpr=np.array(diagonal_rates),
      thresholds=np.array([0, 1, 2]),
      test_train_ratio=1.0)

  self.assertEqual(curve.get_ppv(), 0.5)
def test_ppv_perfect_classifier(self):
  """A curve containing the point (FPR=0, TPR=1) has a PPV of 1.0."""
  true_positive_rates = np.array([0.0, 1.0, 1.0])
  false_positive_rates = np.array([1.0, 1.0, 0.0])
  curve = RocCurve(
      tpr=true_positive_rates,
      fpr=false_positive_rates,
      thresholds=np.array([0, 1, 2]),
      test_train_ratio=1.0)

  self.assertEqual(curve.get_ppv(), 1.0)
# Each case supplies: test-train ratio, expected PPV.
@parameterized.named_parameters(
    ('test_train_ratio_small', 0.001, 1.0),
    ('test_train_ratio_large', 1000.0, 0.0),
)
@mock.patch(
    'tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures._ABSOLUTE_TOLERANCE',
    1e-4)
def test_ppv_perfect_classifier_when_tpr_fpr_small(self, test_train_ratio,
                                                   expected_ppv):
  """Checks PPV on a curve with tiny TPR/FPR for extreme test-train ratios."""
  tiny_tpr = np.array([0.00001, 0.0001, 0.002])
  tiny_fpr = np.array([0.00002, 0.0002, 0.002])
  curve = RocCurve(
      tpr=tiny_tpr,
      fpr=tiny_fpr,
      thresholds=np.array([0, 1, 2]),
      test_train_ratio=test_train_ratio)

  actual_ppv = curve.get_ppv()
  np.testing.assert_allclose(actual_ppv, expected_ppv, atol=1e-3)
@mock.patch(
    'tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures._ABSOLUTE_TOLERANCE',
    1e-4)
def test_ppv_random_classifier_when_tpr_fpr_small_and_test_train_is_1(self):
  """Checks that tiny TPR/FPR with a test-train ratio of 1 give PPV ~ 0.5."""
  tiny_tpr = np.array([0.00001, 0.0001, 0.002])
  tiny_fpr = np.array([0.00002, 0.0002, 0.002])
  curve = RocCurve(
      tpr=tiny_tpr,
      fpr=tiny_fpr,
      thresholds=np.array([0, 1, 2]),
      test_train_ratio=1.0)

  actual_ppv = curve.get_ppv()
  np.testing.assert_allclose(actual_ppv, 0.5, atol=1e-3)
class SingleAttackResultTest(absltest.TestCase): class SingleAttackResultTest(absltest.TestCase):
@ -436,7 +489,8 @@ class SingleAttackResultTest(absltest.TestCase):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
result = SingleAttackResult( result = SingleAttackResult(
roc_curve=roc, roc_curve=roc,
@ -451,7 +505,8 @@ class SingleAttackResultTest(absltest.TestCase):
roc = RocCurve( roc = RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])) thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0)
result = SingleAttackResult( result = SingleAttackResult(
roc_curve=roc, roc_curve=roc,
@ -461,6 +516,22 @@ class SingleAttackResultTest(absltest.TestCase):
self.assertEqual(result.get_attacker_advantage(), 0.0) self.assertEqual(result.get_attacker_advantage(), 0.0)
# A single basic check suffices here: the method under test delegates to
# RocCurve, which has its own tests.
def test_ppv_random_classifier(self):
  """SingleAttackResult reports the PPV of its diagonal ROC curve (0.5)."""
  diagonal_rates = [0.0, 0.5, 1.0]
  curve = RocCurve(
      tpr=np.array(diagonal_rates),
      fpr=np.array(diagonal_rates),
      thresholds=np.array([0, 1, 2]),
      test_train_ratio=1.0)

  attack_result = SingleAttackResult(
      roc_curve=curve,
      slice_spec=SingleSliceSpec(None),
      attack_type=AttackType.THRESHOLD_ATTACK,
      data_size=DataSize(ntrain=1, ntest=1))

  self.assertEqual(attack_result.get_ppv(), 0.5)
class SingleMembershipProbabilityResultTest(absltest.TestCase): class SingleMembershipProbabilityResultTest(absltest.TestCase):
@ -492,7 +563,8 @@ class AttackResultsCollectionTest(absltest.TestCase):
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])), thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0),
data_size=DataSize(ntrain=1, ntest=1)) data_size=DataSize(ntrain=1, ntest=1))
self.results_epoch_10 = AttackResults( self.results_epoch_10 = AttackResults(
@ -552,7 +624,8 @@ class AttackResultsTest(absltest.TestCase):
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2])), thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0),
data_size=DataSize(ntrain=1, ntest=1)) data_size=DataSize(ntrain=1, ntest=1))
# ROC curve of a random classifier # ROC curve of a random classifier
@ -562,7 +635,8 @@ class AttackResultsTest(absltest.TestCase):
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])), thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0),
data_size=DataSize(ntrain=1, ntest=1)) data_size=DataSize(ntrain=1, ntest=1))
def test_get_result_with_max_auc_first(self): def test_get_result_with_max_auc_first(self):
@ -589,37 +663,58 @@ class AttackResultsTest(absltest.TestCase):
self.assertEqual(results.get_result_with_max_attacker_advantage(), self.assertEqual(results.get_result_with_max_attacker_advantage(),
self.perfect_classifier_result) self.perfect_classifier_result)
def test_get_result_with_max_positive_predictive_value_first(self):
  """The perfect classifier is selected when it is listed first."""
  ordered_results = [
      self.perfect_classifier_result, self.random_classifier_result
  ]
  collection = AttackResults(ordered_results)

  best = collection.get_result_with_max_ppv()
  self.assertEqual(best, self.perfect_classifier_result)
def test_get_result_with_max_positive_predictive_value_second(self):
  """The perfect classifier is selected when it is listed second."""
  ordered_results = [
      self.random_classifier_result, self.perfect_classifier_result
  ]
  collection = AttackResults(ordered_results)

  best = collection.get_result_with_max_ppv()
  self.assertEqual(best, self.perfect_classifier_result)
def test_summary_by_slices(self): def test_summary_by_slices(self):
results = AttackResults( results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result]) [self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual( self.assertSequenceEqual(
results.summary(by_slices=True), results.summary(by_slices=True),
'Best-performing attacks over all slices\n' + 'Best-performing attacks over all slices\n' +
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' + ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' +
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n\n' + ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
'Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"\n' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' positive predictive value of 1.00 on slice CORRECTLY_CLASSIFIED='
' AUC of 1.00\n' + 'True\n\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + 'Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"\n'
' advantage of 1.00\n\n' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
'Best-performing attacks over slice: "Entire dataset"\n' + ' AUC of 1.00\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' AUC of 0.50\n' + ' advantage of 1.00\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
' advantage of 0.00') ' positive predictive value of 1.00\n\n'
'Best-performing attacks over slice: "Entire dataset"\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' AUC of 0.50\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' advantage of 0.00\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
' positive predictive value of 0.50')
def test_summary_without_slices(self): def test_summary_without_slices(self):
results = AttackResults( results = AttackResults(
[self.perfect_classifier_result, self.random_classifier_result]) [self.perfect_classifier_result, self.random_classifier_result])
self.assertEqual( self.assertSequenceEqual(
results.summary(by_slices=False), results.summary(by_slices=False),
'Best-performing attacks over all slices\n' + 'Best-performing attacks over all slices\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' + ' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' + ' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an'
' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True') ' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n'
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved a'
' positive predictive value of 1.00 on slice CORRECTLY_CLASSIFIED=True')
def test_save_load(self): def test_save_load(self):
results = AttackResults( results = AttackResults(
@ -645,6 +740,7 @@ class AttackResultsTest(absltest.TestCase):
'test size': [1, 1], 'test size': [1, 1],
'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'], 'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
'Attacker advantage': [1.0, 0.0], 'Attacker advantage': [1.0, 0.0],
'Positive predictive value': [1.0, 0.5],
'AUC': [1.0, 0.5] 'AUC': [1.0, 0.5]
}) })
pd.testing.assert_frame_equal(df, df_expected) pd.testing.assert_frame_equal(df, df_expected)

View file

@ -66,7 +66,7 @@ class UtilsTest(absltest.TestCase):
self.assertLen(att_types, 2) self.assertLen(att_types, 2)
self.assertLen(att_slices, 2) self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2) self.assertLen(att_metrics, 2)
self.assertLen(att_values, 2) self.assertLen(att_values, 3)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -98,7 +98,15 @@ def _run_trained_attack(attack_input: AttackInputData,
# Generate ROC curves with scores. # Generate ROC curves with scores.
fpr, tpr, thresholds = metrics.roc_curve(labels, scores) fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) # 'test_train_ratio' is the ratio of test data size to train data size. It is
# used to compute the Positive Predictive Value.
test_train_ratio = ((prepared_attacker_data.data_size.ntest) /
(prepared_attacker_data.data_size.ntrain))
roc_curve = RocCurve(
tpr=tpr,
fpr=fpr,
thresholds=thresholds,
test_train_ratio=test_train_ratio)
in_train_indices = (labels == 0) in_train_indices = (labels == 0)
return SingleAttackResult( return SingleAttackResult(
@ -125,8 +133,15 @@ def _run_threshold_attack(attack_input: AttackInputData):
fpr, tpr, thresholds = metrics.roc_curve( fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.zeros(ntrain), np.ones(ntest))), np.concatenate((np.zeros(ntrain), np.ones(ntest))),
np.concatenate((loss_train, loss_test))) np.concatenate((loss_train, loss_test)))
# 'test_train_ratio' is the ratio of test data size to train data size. It is
# used to compute the Positive Predictive Value.
test_train_ratio = ntest / ntrain
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) roc_curve = RocCurve(
tpr=tpr,
fpr=fpr,
thresholds=thresholds,
test_train_ratio=test_train_ratio)
return SingleAttackResult( return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=_get_slice_spec(attack_input),
@ -147,8 +162,15 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
np.concatenate((np.zeros(ntrain), np.ones(ntest))), np.concatenate((np.zeros(ntrain), np.ones(ntest))),
np.concatenate( np.concatenate(
(attack_input.get_entropy_train(), attack_input.get_entropy_test()))) (attack_input.get_entropy_train(), attack_input.get_entropy_test())))
# 'test_train_ratio' is the ratio of test data size to train data size. It is
# used to compute the Positive Predictive Value.
test_train_ratio = ntest / ntrain
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) roc_curve = RocCurve(
tpr=tpr,
fpr=fpr,
thresholds=thresholds,
test_train_ratio=test_train_ratio)
return SingleAttackResult( return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=_get_slice_spec(attack_input),
@ -250,9 +272,11 @@ def run_attacks(attack_input: AttackInputData,
balance_attacker_training, min_num_samples, balance_attacker_training, min_num_samples,
backend) backend)
if attack_result is not None: if attack_result is not None:
logging.info('%s attack had an AUC=%s and attacker advantage=%s', logging.info(
attack_type.name, attack_result.get_auc(), '%s attack had an AUC=%s, attacker advantage=%s and '
attack_result.get_attacker_advantage()) 'positive predictive value=%s', attack_type.name,
attack_result.get_auc(), attack_result.get_attacker_advantage(),
attack_result.get_ppv())
attack_results.append(attack_result) attack_results.append(attack_result)
privacy_report_metadata = _compute_missing_privacy_report_metadata( privacy_report_metadata = _compute_missing_privacy_report_metadata(

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest import mock
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
@ -87,6 +89,19 @@ class RunAttacksTest(parameterized.TestCase):
self.assertLen(result.single_attack_results, 2) self.assertLen(result.single_attack_results, 2)
@parameterized.named_parameters(
    ('low_ratio', 100, 10),
    ('ratio_1', 100, 100),
    ('high_ratio', 100, 1000),
)
def test_test_train_ratio(self, ntrain, ntest):
  """The sizes reported by the test input reproduce the ntest/ntrain ratio."""
  attack_input = get_test_input(ntrain, ntest)
  expected_test_train_ratio = ntest / ntrain

  reported_test_size = attack_input.get_test_size()
  reported_train_size = attack_input.get_train_size()
  calculated_test_train_ratio = reported_test_size / reported_train_size

  self.assertEqual(expected_test_train_ratio, calculated_test_train_ratio)
def test_run_attacks_parallel_backend(self): def test_run_attacks_parallel_backend(self):
result = mia.run_attacks( result = mia.run_attacks(
get_multilabel_test_input(100, 100), get_multilabel_test_input(100, 100),
@ -180,6 +195,26 @@ class RunAttacksTest(parameterized.TestCase):
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
@mock.patch('sklearn.metrics.roc_curve')
def test_run_attack_threshold_entropy_small_tpr_fpr_correct_ppv(
    self, patched_fn):
  """Checks the PPV of an entropy threshold attack on a stubbed ROC curve."""
  # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
  stub_fpr = [0.2, 0.04, 0.0003]
  stub_tpr = [0.1, 0.0001, 0.0002]
  stub_thresholds = [0.2, 0.4, 0.6]
  patched_fn.return_value = (stub_fpr, stub_tpr, stub_thresholds)

  attack_input = AttackInputData(
      entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
      entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6]),
      force_multilabel_data=False)
  result = mia._run_attack(attack_input, AttackType.THRESHOLD_ENTROPY_ATTACK)

  # PPV = TPR / (TPR + test_train_ratio * FPR), except when both TPR and FPR
  # are close to 0, in which case PPV = 1 / (1 + test_train_ratio).
  # For the stubbed values and 6 train / 6 test examples this gives:
  #   0.1 / (0.1 + (6/6) * 0.2) = 0.333,
  #   0.0001 / (0.0001 + (6/6) * 0.04) = 0.002493,
  #   1 / (1 + (6/6)) = 0.5.
  # The reported PPV is the max of these three values, namely 0.5.
  np.testing.assert_almost_equal(result.roc_curve.get_ppv(), 0.5, decimal=2)
def test_run_attack_by_slice(self): def test_run_attack_by_slice(self):
result = mia.run_attacks( result = mia.run_attacks(
get_test_input(100, 100), SlicingSpec(by_class=True), get_test_input(100, 100), SlicingSpec(by_class=True),
@ -281,6 +316,38 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6) mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6)
self.assertIsNone(mia._get_accuracy(None, labels)) self.assertIsNone(mia._get_accuracy(None, labels))
def test_run_multilabel_attack_threshold_calculates_correct_ppv(self):
  """Expects a PPV of 1.0 from a threshold attack on small multilabel data."""
  train_losses = np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6], [0.9, 0.6]])
  test_losses = np.array([[1.1, 1.2], [1.3, 0.4], [1.5, 1.6]])
  attack_input = AttackInputData(
      loss_train=train_losses,
      loss_test=test_losses,
      force_multilabel_data=True)

  result = mia._run_attack(attack_input, AttackType.THRESHOLD_ATTACK)

  np.testing.assert_almost_equal(result.roc_curve.get_ppv(), 1.0, decimal=2)
@mock.patch('sklearn.metrics.roc_curve')
def test_run_multilabel_attack_threshold_small_tpr_fpr_correct_ppv(
    self, patched_fn):
  """Checks the PPV of a multilabel threshold attack on a stubbed ROC curve."""
  # sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).
  stub_fpr = [0.2, 0.04, 0.0003]
  stub_tpr = [0.1, 0.0001, 0.0002]
  stub_thresholds = [0.2, 0.4, 0.6]
  patched_fn.return_value = (stub_fpr, stub_tpr, stub_thresholds)

  train_losses = np.array([[0.1, 0.2], [1.3, 0.4], [0.5, 0.6], [0.9, 0.6]])
  test_losses = np.array([[1.1, 1.2], [1.3, 0.4], [1.5, 1.6]])
  attack_input = AttackInputData(
      loss_train=train_losses,
      loss_test=test_losses,
      force_multilabel_data=True)
  result = mia._run_attack(attack_input, AttackType.THRESHOLD_ATTACK)

  # PPV = TPR / (TPR + test_train_ratio * FPR), except when both TPR and FPR
  # are close to 0, in which case PPV = 1 / (1 + test_train_ratio).
  # For the stubbed values and 4 train / 3 test examples this gives:
  #   0.1 / (0.1 + (3/4) * 0.2) = 0.4,
  #   0.0001 / (0.0001 + (3/4) * 0.04) = 0.003322,
  #   1 / (1 + 0.75) = 0.57142.
  # The reported PPV is the max of these three values, namely 0.57142.
  np.testing.assert_almost_equal(
      result.roc_curve.get_ppv(), 0.57142, decimal=2)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View file

@ -38,7 +38,8 @@ class PrivacyReportTest(absltest.TestCase):
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]), tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]), fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])), thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0),
data_size=DataSize(ntrain=1, ntest=1)) data_size=DataSize(ntrain=1, ntest=1))
# Classifier that achieves an AUC of 1.0. # Classifier that achieves an AUC of 1.0.
@ -48,7 +49,8 @@ class PrivacyReportTest(absltest.TestCase):
roc_curve=RocCurve( roc_curve=RocCurve(
tpr=np.array([0.0, 1.0, 1.0]), tpr=np.array([0.0, 1.0, 1.0]),
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2])), thresholds=np.array([0, 1, 2]),
test_train_ratio=1.0),
data_size=DataSize(ntrain=1, ntest=1)) data_size=DataSize(ntrain=1, ntest=1))
self.results_epoch_0 = AttackResults( self.results_epoch_0 = AttackResults(

View file

@ -90,7 +90,7 @@ class UtilsTest(absltest.TestCase):
self.assertLen(att_types, 2) self.assertLen(att_types, 2)
self.assertLen(att_slices, 2) self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2) self.assertLen(att_metrics, 2)
self.assertLen(att_values, 2) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
def test_run_attack_on_tf_estimator_model(self): def test_run_attack_on_tf_estimator_model(self):
"""Test the attack on the final models.""" """Test the attack on the final models."""
@ -110,7 +110,7 @@ class UtilsTest(absltest.TestCase):
self.assertLen(att_types, 2) self.assertLen(att_types, 2)
self.assertLen(att_slices, 2) self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2) self.assertLen(att_metrics, 2)
self.assertLen(att_values, 2) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
if __name__ == '__main__': if __name__ == '__main__':