Fix a bug in get_flattened_attack_metrics that types, slices, metrics do not

correspond to values because of PPV.

PiperOrigin-RevId: 509274994
This commit is contained in:
Shuang Song 2023-02-13 10:52:50 -08:00 committed by A. Unique TensorFlower
parent 9ed34da715
commit 6ee988885a
3 changed files with 88 additions and 57 deletions

View file

@ -542,7 +542,8 @@ class AttackInputData:
self.logits_train is None and self.entropy_train is None and
self.probs_train is None):
raise ValueError(
'At least one of labels, logits, losses, probabilities or entropy should be set'
'At least one of labels, logits, losses, probabilities '
'or entropy should be set.'
)
if self.labels_train is not None and not _is_integer_type_array(
@ -848,15 +849,18 @@ class SingleMembershipProbabilityResult:
Returns:
Summary string.
"""
meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds(
threshold_list)
meaningful_threshold_list, precision_list, recall_list = (
self.attack_with_varied_thresholds(threshold_list)
)
summary = []
summary.append('\nMembership probability analysis over slice: \"%s\"' %
str(self.slice_spec))
for i in range(len(meaningful_threshold_list)):
summary.append(
' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)'
% (meaningful_threshold_list[i], precision_list[i], recall_list[i]))
' with %.4f as the threshold on membership probability, the'
' precision-recall pair is (%.4f, %.4f)'
% (meaningful_threshold_list[i], precision_list[i], recall_list[i])
)
if return_roc_results:
fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.ones(len(self.train_membership_probs)),
@ -875,8 +879,9 @@ class SingleMembershipProbabilityResult:
' thresholding on membership probability achieved an AUC of %.2f' %
(roc_curve.get_auc()))
summary.append(
' thresholding on membership probability achieved an advantage of %.2f'
% (roc_curve.get_attacker_advantage()))
' thresholding on membership probability achieved an advantage of'
' %.2f' % (roc_curve.get_attacker_advantage())
)
return summary
@ -986,19 +991,29 @@ class AttackResults:
max_auc_result_all = self.get_result_with_max_auc()
summary.append('Best-performing attacks over all slices')
summary.append(
' %s (with %d training and %d test examples) achieved an AUC of %.2f on slice %s'
% (max_auc_result_all.attack_type, max_auc_result_all.data_size.ntrain,
max_auc_result_all.data_size.ntest, max_auc_result_all.get_auc(),
max_auc_result_all.slice_spec))
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
' on slice %s'
% (
max_auc_result_all.attack_type,
max_auc_result_all.data_size.ntrain,
max_auc_result_all.data_size.ntest,
max_auc_result_all.get_auc(),
max_auc_result_all.slice_spec,
)
)
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
summary.append(
' %s (with %d training and %d test examples) achieved an advantage of %.2f on slice %s'
% (max_advantage_result_all.attack_type,
' %s (with %d training and %d test examples) achieved an advantage of'
' %.2f on slice %s'
% (
max_advantage_result_all.attack_type,
max_advantage_result_all.data_size.ntrain,
max_advantage_result_all.data_size.ntest,
max_advantage_result_all.get_attacker_advantage(),
max_advantage_result_all.slice_spec))
max_advantage_result_all.slice_spec,
)
)
max_ppv_result_all = self.get_result_with_max_ppv()
summary.append(
@ -1017,16 +1032,26 @@ class AttackResults:
slice_str)
max_auc_result = results.get_result_with_max_auc()
summary.append(
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
% (max_auc_result.attack_type, max_auc_result.data_size.ntrain,
max_auc_result.data_size.ntest, max_auc_result.get_auc()))
' %s (with %d training and %d test examples) achieved an AUC of'
' %.2f'
% (
max_auc_result.attack_type,
max_auc_result.data_size.ntrain,
max_auc_result.data_size.ntest,
max_auc_result.get_auc(),
)
)
max_advantage_result = results.get_result_with_max_attacker_advantage()
summary.append(
' %s (with %d training and %d test examples) achieved an advantage of %.2f'
% (max_advantage_result.attack_type,
' %s (with %d training and %d test examples) achieved an advantage'
' of %.2f'
% (
max_advantage_result.attack_type,
max_advantage_result.data_size.ntrain,
max_auc_result.data_size.ntest,
max_advantage_result.get_attacker_advantage()))
max_advantage_result.get_attacker_advantage(),
)
)
max_ppv_result = results.get_result_with_max_ppv()
summary.append(
' %s (with %d training and %d test examples) achieved a positive '
@ -1124,16 +1149,17 @@ def get_flattened_attack_metrics(results: AttackResults):
types: a list of attack types
slices: a list of slices
attack_metrics: a list of metric names
values: a list of metric values, i-th element correspond to properties[i]
values: a list of metric values, i-th element corresponds to
attack_metrics[i]
"""
types = []
slices = []
attack_metrics = []
values = []
for attack_result in results.single_attack_results:
types += [str(attack_result.attack_type)] * 2
slices += [str(attack_result.slice_spec)] * 2
attack_metrics += ['adv', 'auc']
types += [str(attack_result.attack_type)] * 3
slices += [str(attack_result.slice_spec)] * 3
attack_metrics += ['adv', 'auc', 'ppv']
values += [
float(attack_result.get_attacker_advantage()),
float(attack_result.get_auc()),

View file

@ -61,11 +61,12 @@ class UtilsTest(absltest.TestCase):
(self.test_data, self.test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults)
att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics(
results)
self.assertLen(att_types, 2)
self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2)
att_types, att_slices, att_metrics, att_values = (
get_flattened_attack_metrics(results)
)
self.assertLen(att_types, 3)
self.assertLen(att_slices, 3)
self.assertLen(att_metrics, 3)
self.assertLen(att_values, 3)

View file

@ -97,11 +97,12 @@ class UtilsTest(absltest.TestCase):
self.sample_weight_test,
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, data_structures.AttackResults)
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
results)
self.assertLen(att_types, 2)
self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2)
att_types, att_slices, att_metrics, att_values = (
data_structures.get_flattened_attack_metrics(results)
)
self.assertLen(att_types, 3)
self.assertLen(att_slices, 3)
self.assertLen(att_metrics, 3)
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
def test_run_attack_helper_with_sample_weights(self):
@ -116,11 +117,12 @@ class UtilsTest(absltest.TestCase):
out_train_sample_weight=self.sample_weight_test,
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, data_structures.AttackResults)
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
results)
self.assertLen(att_types, 2)
self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2)
att_types, att_slices, att_metrics, att_values = (
data_structures.get_flattened_attack_metrics(results)
)
self.assertLen(att_types, 3)
self.assertLen(att_slices, 3)
self.assertLen(att_metrics, 3)
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
def test_run_attack_on_tf_estimator_model(self):
@ -136,11 +138,12 @@ class UtilsTest(absltest.TestCase):
input_fn_constructor,
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, data_structures.AttackResults)
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
results)
self.assertLen(att_types, 2)
self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2)
att_types, att_slices, att_metrics, att_values = (
data_structures.get_flattened_attack_metrics(results)
)
self.assertLen(att_types, 3)
self.assertLen(att_slices, 3)
self.assertLen(att_metrics, 3)
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV
def test_run_attack_on_tf_estimator_model_with_sample_weights(self):
@ -157,11 +160,12 @@ class UtilsTest(absltest.TestCase):
input_fn_constructor,
attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, data_structures.AttackResults)
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
results)
self.assertLen(att_types, 2)
self.assertLen(att_slices, 2)
self.assertLen(att_metrics, 2)
att_types, att_slices, att_metrics, att_values = (
data_structures.get_flattened_attack_metrics(results)
)
self.assertLen(att_types, 3)
self.assertLen(att_slices, 3)
self.assertLen(att_metrics, 3)
self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV