Fix a bug in get_flattened_attack_metrics that types, slices, metrics do not
correspond to values because of PPV. PiperOrigin-RevId: 509274994
This commit is contained in:
parent
9ed34da715
commit
6ee988885a
3 changed files with 88 additions and 57 deletions
|
@ -542,7 +542,8 @@ class AttackInputData:
|
|||
self.logits_train is None and self.entropy_train is None and
|
||||
self.probs_train is None):
|
||||
raise ValueError(
|
||||
'At least one of labels, logits, losses, probabilities or entropy should be set'
|
||||
'At least one of labels, logits, losses, probabilities '
|
||||
'or entropy should be set.'
|
||||
)
|
||||
|
||||
if self.labels_train is not None and not _is_integer_type_array(
|
||||
|
@ -848,15 +849,18 @@ class SingleMembershipProbabilityResult:
|
|||
Returns:
|
||||
Summary string.
|
||||
"""
|
||||
meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(
|
||||
threshold_list)
|
||||
meaningful_threshold_list, precision_list, recall_list = (
|
||||
self.attack_with_varied_thresholds(threshold_list)
|
||||
)
|
||||
summary = []
|
||||
summary.append('\nMembership probability analysis over slice: \"%s\"' %
|
||||
str(self.slice_spec))
|
||||
for i in range(len(meaningful_threshold_list)):
|
||||
summary.append(
|
||||
' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)'
|
||||
% (meaningful_threshold_list[i], precision_list[i], recall_list[i]))
|
||||
' with %.4f as the threshold on membership probability, the'
|
||||
' precision-recall pair is (%.4f, %.4f)'
|
||||
% (meaningful_threshold_list[i], precision_list[i], recall_list[i])
|
||||
)
|
||||
if return_roc_results:
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.ones(len(self.train_membership_probs)),
|
||||
|
@ -875,8 +879,9 @@ class SingleMembershipProbabilityResult:
|
|||
' thresholding on membership probability achieved an AUC of %.2f' %
|
||||
(roc_curve.get_auc()))
|
||||
summary.append(
|
||||
' thresholding on membership probability achieved an advantage of %.2f'
|
||||
% (roc_curve.get_attacker_advantage()))
|
||||
' thresholding on membership probability achieved an advantage of'
|
||||
' %.2f' % (roc_curve.get_attacker_advantage())
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
|
@ -986,19 +991,29 @@ class AttackResults:
|
|||
max_auc_result_all = self.get_result_with_max_auc()
|
||||
summary.append('Best-performing attacks over all slices')
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f on slice %s'
|
||||
% (max_auc_result_all.attack_type, max_auc_result_all.data_size.ntrain,
|
||||
max_auc_result_all.data_size.ntest, max_auc_result_all.get_auc(),
|
||||
max_auc_result_all.slice_spec))
|
||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
|
||||
' on slice %s'
|
||||
% (
|
||||
max_auc_result_all.attack_type,
|
||||
max_auc_result_all.data_size.ntrain,
|
||||
max_auc_result_all.data_size.ntest,
|
||||
max_auc_result_all.get_auc(),
|
||||
max_auc_result_all.slice_spec,
|
||||
)
|
||||
)
|
||||
|
||||
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f on slice %s'
|
||||
% (max_advantage_result_all.attack_type,
|
||||
' %s (with %d training and %d test examples) achieved an advantage of'
|
||||
' %.2f on slice %s'
|
||||
% (
|
||||
max_advantage_result_all.attack_type,
|
||||
max_advantage_result_all.data_size.ntrain,
|
||||
max_advantage_result_all.data_size.ntest,
|
||||
max_advantage_result_all.get_attacker_advantage(),
|
||||
max_advantage_result_all.slice_spec))
|
||||
max_advantage_result_all.slice_spec,
|
||||
)
|
||||
)
|
||||
|
||||
max_ppv_result_all = self.get_result_with_max_ppv()
|
||||
summary.append(
|
||||
|
@ -1017,16 +1032,26 @@ class AttackResults:
|
|||
slice_str)
|
||||
max_auc_result = results.get_result_with_max_auc()
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
|
||||
% (max_auc_result.attack_type, max_auc_result.data_size.ntrain,
|
||||
max_auc_result.data_size.ntest, max_auc_result.get_auc()))
|
||||
' %s (with %d training and %d test examples) achieved an AUC of'
|
||||
' %.2f'
|
||||
% (
|
||||
max_auc_result.attack_type,
|
||||
max_auc_result.data_size.ntrain,
|
||||
max_auc_result.data_size.ntest,
|
||||
max_auc_result.get_auc(),
|
||||
)
|
||||
)
|
||||
max_advantage_result = results.get_result_with_max_attacker_advantage()
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f'
|
||||
% (max_advantage_result.attack_type,
|
||||
' %s (with %d training and %d test examples) achieved an advantage'
|
||||
' of %.2f'
|
||||
% (
|
||||
max_advantage_result.attack_type,
|
||||
max_advantage_result.data_size.ntrain,
|
||||
max_auc_result.data_size.ntest,
|
||||
max_advantage_result.get_attacker_advantage()))
|
||||
max_advantage_result.get_attacker_advantage(),
|
||||
)
|
||||
)
|
||||
max_ppv_result = results.get_result_with_max_ppv()
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved a positive '
|
||||
|
@ -1124,16 +1149,17 @@ def get_flattened_attack_metrics(results: AttackResults):
|
|||
types: a list of attack types
|
||||
slices: a list of slices
|
||||
attack_metrics: a list of metric names
|
||||
values: a list of metric values, i-th element correspond to properties[i]
|
||||
values: a list of metric values, i-th element corresponds to
|
||||
attack_metrics[i]
|
||||
"""
|
||||
types = []
|
||||
slices = []
|
||||
attack_metrics = []
|
||||
values = []
|
||||
for attack_result in results.single_attack_results:
|
||||
types += [str(attack_result.attack_type)] * 2
|
||||
slices += [str(attack_result.slice_spec)] * 2
|
||||
attack_metrics += ['adv', 'auc']
|
||||
types += [str(attack_result.attack_type)] * 3
|
||||
slices += [str(attack_result.slice_spec)] * 3
|
||||
attack_metrics += ['adv', 'auc', 'ppv']
|
||||
values += [
|
||||
float(attack_result.get_attacker_advantage()),
|
||||
float(attack_result.get_auc()),
|
||||
|
|
|
@ -61,11 +61,12 @@ class UtilsTest(absltest.TestCase):
|
|||
(self.test_data, self.test_labels),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
|
||||
results)
|
||||
self.assertLen(att_types, 2)
|
||||
self.assertLen(att_slices, 2)
|
||||
self.assertLen(att_metrics, 2)
|
||||
att_types, att_slices, att_metrics, att_values = (
|
||||
get_flattened_attack_metrics(results)
|
||||
)
|
||||
self.assertLen(att_types, 3)
|
||||
self.assertLen(att_slices, 3)
|
||||
self.assertLen(att_metrics, 3)
|
||||
self.assertLen(att_values, 3)
|
||||
|
||||
|
||||
|
|
|
@ -97,11 +97,12 @@ class UtilsTest(absltest.TestCase):
|
|||
self.sample_weight_test,
|
||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, data_structures.AttackResults)
|
||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||
results)
|
||||
self.assertLen(att_types, 2)
|
||||
self.assertLen(att_slices, 2)
|
||||
self.assertLen(att_metrics, 2)
|
||||
att_types, att_slices, att_metrics, att_values = (
|
||||
data_structures.get_flattened_attack_metrics(results)
|
||||
)
|
||||
self.assertLen(att_types, 3)
|
||||
self.assertLen(att_slices, 3)
|
||||
self.assertLen(att_metrics, 3)
|
||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||
|
||||
def test_run_attack_helper_with_sample_weights(self):
|
||||
|
@ -116,11 +117,12 @@ class UtilsTest(absltest.TestCase):
|
|||
out_train_sample_weight=self.sample_weight_test,
|
||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, data_structures.AttackResults)
|
||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||
results)
|
||||
self.assertLen(att_types, 2)
|
||||
self.assertLen(att_slices, 2)
|
||||
self.assertLen(att_metrics, 2)
|
||||
att_types, att_slices, att_metrics, att_values = (
|
||||
data_structures.get_flattened_attack_metrics(results)
|
||||
)
|
||||
self.assertLen(att_types, 3)
|
||||
self.assertLen(att_slices, 3)
|
||||
self.assertLen(att_metrics, 3)
|
||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||
|
||||
def test_run_attack_on_tf_estimator_model(self):
|
||||
|
@ -136,11 +138,12 @@ class UtilsTest(absltest.TestCase):
|
|||
input_fn_constructor,
|
||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, data_structures.AttackResults)
|
||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||
results)
|
||||
self.assertLen(att_types, 2)
|
||||
self.assertLen(att_slices, 2)
|
||||
self.assertLen(att_metrics, 2)
|
||||
att_types, att_slices, att_metrics, att_values = (
|
||||
data_structures.get_flattened_attack_metrics(results)
|
||||
)
|
||||
self.assertLen(att_types, 3)
|
||||
self.assertLen(att_slices, 3)
|
||||
self.assertLen(att_metrics, 3)
|
||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||
|
||||
def test_run_attack_on_tf_estimator_model_with_sample_weights(self):
|
||||
|
@ -157,11 +160,12 @@ class UtilsTest(absltest.TestCase):
|
|||
input_fn_constructor,
|
||||
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, data_structures.AttackResults)
|
||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||
results)
|
||||
self.assertLen(att_types, 2)
|
||||
self.assertLen(att_slices, 2)
|
||||
self.assertLen(att_metrics, 2)
|
||||
att_types, att_slices, att_metrics, att_values = (
|
||||
data_structures.get_flattened_attack_metrics(results)
|
||||
)
|
||||
self.assertLen(att_types, 3)
|
||||
self.assertLen(att_slices, 3)
|
||||
self.assertLen(att_metrics, 3)
|
||||
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue