forked from 626_privacy/tensorflow_privacy
Fix a bug in get_flattened_attack_metrics that types, slices, metrics do not
correspond to values because of PPV. PiperOrigin-RevId: 509274994
This commit is contained in:
parent
9ed34da715
commit
6ee988885a
3 changed files with 88 additions and 57 deletions
|
@ -542,7 +542,8 @@ class AttackInputData:
|
||||||
self.logits_train is None and self.entropy_train is None and
|
self.logits_train is None and self.entropy_train is None and
|
||||||
self.probs_train is None):
|
self.probs_train is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'At least one of labels, logits, losses, probabilities or entropy should be set'
|
'At least one of labels, logits, losses, probabilities '
|
||||||
|
'or entropy should be set.'
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.labels_train is not None and not _is_integer_type_array(
|
if self.labels_train is not None and not _is_integer_type_array(
|
||||||
|
@ -848,15 +849,18 @@ class SingleMembershipProbabilityResult:
|
||||||
Returns:
|
Returns:
|
||||||
Summary string.
|
Summary string.
|
||||||
"""
|
"""
|
||||||
meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(
|
meaningful_threshold_list, precision_list, recall_list = (
|
||||||
threshold_list)
|
self.attack_with_varied_thresholds(threshold_list)
|
||||||
|
)
|
||||||
summary = []
|
summary = []
|
||||||
summary.append('\nMembership probability analysis over slice: \"%s\"' %
|
summary.append('\nMembership probability analysis over slice: \"%s\"' %
|
||||||
str(self.slice_spec))
|
str(self.slice_spec))
|
||||||
for i in range(len(meaningful_threshold_list)):
|
for i in range(len(meaningful_threshold_list)):
|
||||||
summary.append(
|
summary.append(
|
||||||
' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)'
|
' with %.4f as the threshold on membership probability, the'
|
||||||
% (meaningful_threshold_list[i], precision_list[i], recall_list[i]))
|
' precision-recall pair is (%.4f, %.4f)'
|
||||||
|
% (meaningful_threshold_list[i], precision_list[i], recall_list[i])
|
||||||
|
)
|
||||||
if return_roc_results:
|
if return_roc_results:
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(
|
fpr, tpr, thresholds = metrics.roc_curve(
|
||||||
np.concatenate((np.ones(len(self.train_membership_probs)),
|
np.concatenate((np.ones(len(self.train_membership_probs)),
|
||||||
|
@ -875,8 +879,9 @@ class SingleMembershipProbabilityResult:
|
||||||
' thresholding on membership probability achieved an AUC of %.2f' %
|
' thresholding on membership probability achieved an AUC of %.2f' %
|
||||||
(roc_curve.get_auc()))
|
(roc_curve.get_auc()))
|
||||||
summary.append(
|
summary.append(
|
||||||
' thresholding on membership probability achieved an advantage of %.2f'
|
' thresholding on membership probability achieved an advantage of'
|
||||||
% (roc_curve.get_attacker_advantage()))
|
' %.2f' % (roc_curve.get_attacker_advantage())
|
||||||
|
)
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
@ -986,19 +991,29 @@ class AttackResults:
|
||||||
max_auc_result_all = self.get_result_with_max_auc()
|
max_auc_result_all = self.get_result_with_max_auc()
|
||||||
summary.append('Best-performing attacks over all slices')
|
summary.append('Best-performing attacks over all slices')
|
||||||
summary.append(
|
summary.append(
|
||||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f on slice %s'
|
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
|
||||||
% (max_auc_result_all.attack_type, max_auc_result_all.data_size.ntrain,
|
' on slice %s'
|
||||||
max_auc_result_all.data_size.ntest, max_auc_result_all.get_auc(),
|
% (
|
||||||
max_auc_result_all.slice_spec))
|
max_auc_result_all.attack_type,
|
||||||
|
max_auc_result_all.data_size.ntrain,
|
||||||
|
max_auc_result_all.data_size.ntest,
|
||||||
|
max_auc_result_all.get_auc(),
|
||||||
|
max_auc_result_all.slice_spec,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
|
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
|
||||||
summary.append(
|
summary.append(
|
||||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f on slice %s'
|
' %s (with %d training and %d test examples) achieved an advantage of'
|
||||||
% (max_advantage_result_all.attack_type,
|
' %.2f on slice %s'
|
||||||
|
% (
|
||||||
|
max_advantage_result_all.attack_type,
|
||||||
max_advantage_result_all.data_size.ntrain,
|
max_advantage_result_all.data_size.ntrain,
|
||||||
max_advantage_result_all.data_size.ntest,
|
max_advantage_result_all.data_size.ntest,
|
||||||
max_advantage_result_all.get_attacker_advantage(),
|
max_advantage_result_all.get_attacker_advantage(),
|
||||||
max_advantage_result_all.slice_spec))
|
max_advantage_result_all.slice_spec,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
max_ppv_result_all = self.get_result_with_max_ppv()
|
max_ppv_result_all = self.get_result_with_max_ppv()
|
||||||
summary.append(
|
summary.append(
|
||||||
|
@ -1017,16 +1032,26 @@ class AttackResults:
|
||||||
slice_str)
|
slice_str)
|
||||||
max_auc_result = results.get_result_with_max_auc()
|
max_auc_result = results.get_result_with_max_auc()
|
||||||
summary.append(
|
summary.append(
|
||||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
|
' %s (with %d training and %d test examples) achieved an AUC of'
|
||||||
% (max_auc_result.attack_type, max_auc_result.data_size.ntrain,
|
' %.2f'
|
||||||
max_auc_result.data_size.ntest, max_auc_result.get_auc()))
|
% (
|
||||||
|
max_auc_result.attack_type,
|
||||||
|
max_auc_result.data_size.ntrain,
|
||||||
|
max_auc_result.data_size.ntest,
|
||||||
|
max_auc_result.get_auc(),
|
||||||
|
)
|
||||||
|
)
|
||||||
max_advantage_result = results.get_result_with_max_attacker_advantage()
|
max_advantage_result = results.get_result_with_max_attacker_advantage()
|
||||||
summary.append(
|
summary.append(
|
||||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f'
|
' %s (with %d training and %d test examples) achieved an advantage'
|
||||||
% (max_advantage_result.attack_type,
|
' of %.2f'
|
||||||
|
% (
|
||||||
|
max_advantage_result.attack_type,
|
||||||
max_advantage_result.data_size.ntrain,
|
max_advantage_result.data_size.ntrain,
|
||||||
max_auc_result.data_size.ntest,
|
max_auc_result.data_size.ntest,
|
||||||
max_advantage_result.get_attacker_advantage()))
|
max_advantage_result.get_attacker_advantage(),
|
||||||
|
)
|
||||||
|
)
|
||||||
max_ppv_result = results.get_result_with_max_ppv()
|
max_ppv_result = results.get_result_with_max_ppv()
|
||||||
summary.append(
|
summary.append(
|
||||||
' %s (with %d training and %d test examples) achieved a positive '
|
' %s (with %d training and %d test examples) achieved a positive '
|
||||||
|
@ -1124,16 +1149,17 @@ def get_flattened_attack_metrics(results: AttackResults):
|
||||||
types: a list of attack types
|
types: a list of attack types
|
||||||
slices: a list of slices
|
slices: a list of slices
|
||||||
attack_metrics: a list of metric names
|
attack_metrics: a list of metric names
|
||||||
values: a list of metric values, i-th element correspond to properties[i]
|
values: a list of metric values, i-th element corresponds to
|
||||||
|
attack_metrics[i]
|
||||||
"""
|
"""
|
||||||
types = []
|
types = []
|
||||||
slices = []
|
slices = []
|
||||||
attack_metrics = []
|
attack_metrics = []
|
||||||
values = []
|
values = []
|
||||||
for attack_result in results.single_attack_results:
|
for attack_result in results.single_attack_results:
|
||||||
types += [str(attack_result.attack_type)] * 2
|
types += [str(attack_result.attack_type)] * 3
|
||||||
slices += [str(attack_result.slice_spec)] * 2
|
slices += [str(attack_result.slice_spec)] * 3
|
||||||
attack_metrics += ['adv', 'auc']
|
attack_metrics += ['adv', 'auc', 'ppv']
|
||||||
values += [
|
values += [
|
||||||
float(attack_result.get_attacker_advantage()),
|
float(attack_result.get_attacker_advantage()),
|
||||||
float(attack_result.get_auc()),
|
float(attack_result.get_auc()),
|
||||||
|
|
|
@ -61,11 +61,12 @@ class UtilsTest(absltest.TestCase):
|
||||||
(self.test_data, self.test_labels),
|
(self.test_data, self.test_labels),
|
||||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, AttackResults)
|
self.assertIsInstance(results, AttackResults)
|
||||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = (
|
||||||
results)
|
get_flattened_attack_metrics(results)
|
||||||
self.assertLen(att_types, 2)
|
)
|
||||||
self.assertLen(att_slices, 2)
|
self.assertLen(att_types, 3)
|
||||||
self.assertLen(att_metrics, 2)
|
self.assertLen(att_slices, 3)
|
||||||
|
self.assertLen(att_metrics, 3)
|
||||||
self.assertLen(att_values, 3)
|
self.assertLen(att_values, 3)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -97,11 +97,12 @@ class UtilsTest(absltest.TestCase):
|
||||||
self.sample_weight_test,
|
self.sample_weight_test,
|
||||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, data_structures.AttackResults)
|
self.assertIsInstance(results, data_structures.AttackResults)
|
||||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = (
|
||||||
results)
|
data_structures.get_flattened_attack_metrics(results)
|
||||||
self.assertLen(att_types, 2)
|
)
|
||||||
self.assertLen(att_slices, 2)
|
self.assertLen(att_types, 3)
|
||||||
self.assertLen(att_metrics, 2)
|
self.assertLen(att_slices, 3)
|
||||||
|
self.assertLen(att_metrics, 3)
|
||||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||||
|
|
||||||
def test_run_attack_helper_with_sample_weights(self):
|
def test_run_attack_helper_with_sample_weights(self):
|
||||||
|
@ -116,11 +117,12 @@ class UtilsTest(absltest.TestCase):
|
||||||
out_train_sample_weight=self.sample_weight_test,
|
out_train_sample_weight=self.sample_weight_test,
|
||||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, data_structures.AttackResults)
|
self.assertIsInstance(results, data_structures.AttackResults)
|
||||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = (
|
||||||
results)
|
data_structures.get_flattened_attack_metrics(results)
|
||||||
self.assertLen(att_types, 2)
|
)
|
||||||
self.assertLen(att_slices, 2)
|
self.assertLen(att_types, 3)
|
||||||
self.assertLen(att_metrics, 2)
|
self.assertLen(att_slices, 3)
|
||||||
|
self.assertLen(att_metrics, 3)
|
||||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||||
|
|
||||||
def test_run_attack_on_tf_estimator_model(self):
|
def test_run_attack_on_tf_estimator_model(self):
|
||||||
|
@ -136,11 +138,12 @@ class UtilsTest(absltest.TestCase):
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, data_structures.AttackResults)
|
self.assertIsInstance(results, data_structures.AttackResults)
|
||||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = (
|
||||||
results)
|
data_structures.get_flattened_attack_metrics(results)
|
||||||
self.assertLen(att_types, 2)
|
)
|
||||||
self.assertLen(att_slices, 2)
|
self.assertLen(att_types, 3)
|
||||||
self.assertLen(att_metrics, 2)
|
self.assertLen(att_slices, 3)
|
||||||
|
self.assertLen(att_metrics, 3)
|
||||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||||
|
|
||||||
def test_run_attack_on_tf_estimator_model_with_sample_weights(self):
|
def test_run_attack_on_tf_estimator_model_with_sample_weights(self):
|
||||||
|
@ -157,11 +160,12 @@ class UtilsTest(absltest.TestCase):
|
||||||
input_fn_constructor,
|
input_fn_constructor,
|
||||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||||
self.assertIsInstance(results, data_structures.AttackResults)
|
self.assertIsInstance(results, data_structures.AttackResults)
|
||||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = (
|
||||||
results)
|
data_structures.get_flattened_attack_metrics(results)
|
||||||
self.assertLen(att_types, 2)
|
)
|
||||||
self.assertLen(att_slices, 2)
|
self.assertLen(att_types, 3)
|
||||||
self.assertLen(att_metrics, 2)
|
self.assertLen(att_slices, 3)
|
||||||
|
self.assertLen(att_metrics, 3)
|
||||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue