diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index e75cccd..3e47aad 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -542,7 +542,8 @@ class AttackInputData: self.logits_train is None and self.entropy_train is None and self.probs_train is None): raise ValueError( - 'At least one of labels, logits, losses, probabilities or entropy should be set' + 'At least one of labels, logits, losses, probabilities ' + 'or entropy should be set.' ) if self.labels_train is not None and not _is_integer_type_array( @@ -848,15 +849,18 @@ class SingleMembershipProbabilityResult: Returns: Summary string. """ - meaningful_threshold_list, precision_list, recall_list = self.attack_with_varied_thresholds( - threshold_list) + meaningful_threshold_list, precision_list, recall_list = ( + self.attack_with_varied_thresholds(threshold_list) + ) summary = [] summary.append('\nMembership probability analysis over slice: \"%s\"' % str(self.slice_spec)) for i in range(len(meaningful_threshold_list)): summary.append( - ' with %.4f as the threshold on membership probability, the precision-recall pair is (%.4f, %.4f)' - % (meaningful_threshold_list[i], precision_list[i], recall_list[i])) + ' with %.4f as the threshold on membership probability, the' + ' precision-recall pair is (%.4f, %.4f)' + % (meaningful_threshold_list[i], precision_list[i], recall_list[i]) + ) if return_roc_results: fpr, tpr, thresholds = metrics.roc_curve( np.concatenate((np.ones(len(self.train_membership_probs)), @@ -875,8 +879,9 @@ class SingleMembershipProbabilityResult: ' thresholding on membership probability achieved an AUC of %.2f' % (roc_curve.get_auc())) summary.append( - ' thresholding on membership probability achieved an advantage of %.2f' - % (roc_curve.get_attacker_advantage())) + ' thresholding on membership probability achieved an advantage of' + ' %.2f' % (roc_curve.get_attacker_advantage()) + ) return summary @@ -986,19 +991,29 @@ class AttackResults: max_auc_result_all = self.get_result_with_max_auc() summary.append('Best-performing attacks over all slices') summary.append( - ' %s (with %d training and %d test examples) achieved an AUC of %.2f on slice %s' - % (max_auc_result_all.attack_type, max_auc_result_all.data_size.ntrain, - max_auc_result_all.data_size.ntest, max_auc_result_all.get_auc(), - max_auc_result_all.slice_spec)) + ' %s (with %d training and %d test examples) achieved an AUC of %.2f' + ' on slice %s' + % ( + max_auc_result_all.attack_type, + max_auc_result_all.data_size.ntrain, + max_auc_result_all.data_size.ntest, + max_auc_result_all.get_auc(), + max_auc_result_all.slice_spec, + ) + ) max_advantage_result_all = self.get_result_with_max_attacker_advantage() summary.append( - ' %s (with %d training and %d test examples) achieved an advantage of %.2f on slice %s' - % (max_advantage_result_all.attack_type, - max_advantage_result_all.data_size.ntrain, - max_advantage_result_all.data_size.ntest, - max_advantage_result_all.get_attacker_advantage(), - max_advantage_result_all.slice_spec)) + ' %s (with %d training and %d test examples) achieved an advantage of' + ' %.2f on slice %s' + % ( + max_advantage_result_all.attack_type, + max_advantage_result_all.data_size.ntrain, + max_advantage_result_all.data_size.ntest, + max_advantage_result_all.get_attacker_advantage(), + max_advantage_result_all.slice_spec, + ) + ) max_ppv_result_all = self.get_result_with_max_ppv() summary.append( @@ -1017,16 +1032,26 @@ class AttackResults: slice_str) max_auc_result = results.get_result_with_max_auc() summary.append( - ' %s (with %d training and %d test examples) achieved an AUC of %.2f' - % (max_auc_result.attack_type, max_auc_result.data_size.ntrain, - max_auc_result.data_size.ntest, max_auc_result.get_auc())) + ' %s (with %d training and %d test examples) achieved an AUC of' + ' %.2f' + % ( + max_auc_result.attack_type, + max_auc_result.data_size.ntrain, + max_auc_result.data_size.ntest, + max_auc_result.get_auc(), + ) + ) max_advantage_result = results.get_result_with_max_attacker_advantage() summary.append( - ' %s (with %d training and %d test examples) achieved an advantage of %.2f' - % (max_advantage_result.attack_type, - max_advantage_result.data_size.ntrain, - max_auc_result.data_size.ntest, - max_advantage_result.get_attacker_advantage())) + ' %s (with %d training and %d test examples) achieved an advantage' + ' of %.2f' + % ( + max_advantage_result.attack_type, + max_advantage_result.data_size.ntrain, + max_auc_result.data_size.ntest, + max_advantage_result.get_attacker_advantage(), + ) + ) max_ppv_result = results.get_result_with_max_ppv() summary.append( ' %s (with %d training and %d test examples) achieved a positive ' @@ -1121,19 +1146,20 @@ def get_flattened_attack_metrics(results: AttackResults): results: membership inference attack results. Returns: - types: a list of attack types - slices: a list of slices - attack_metrics: a list of metric names - values: a list of metric values, i-th element correspond to properties[i] + types: a list of attack types + slices: a list of slices + attack_metrics: a list of metric names + values: a list of metric values, i-th element corresponds to + attack_metrics[i] """ types = [] slices = [] attack_metrics = [] values = [] for attack_result in results.single_attack_results: - types += [str(attack_result.attack_type)] * 2 - slices += [str(attack_result.slice_spec)] * 2 - attack_metrics += ['adv', 'auc'] + types += [str(attack_result.attack_type)] * 3 + slices += [str(attack_result.slice_spec)] * 3 + attack_metrics += ['adv', 'auc', 'ppv'] values += [ float(attack_result.get_attacker_advantage()), float(attack_result.get_auc()), diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py index 2dace60..e5459cf 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation_test.py @@ -61,11 +61,12 @@ class UtilsTest(absltest.TestCase): (self.test_data, self.test_labels), attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - att_types, att_slices, att_metrics, att_values = get_flattened_attack_metrics( - results) - self.assertLen(att_types, 2) - self.assertLen(att_slices, 2) - self.assertLen(att_metrics, 2) + att_types, att_slices, att_metrics, att_values = ( + get_flattened_attack_metrics(results) + ) + self.assertLen(att_types, 3) + self.assertLen(att_slices, 3) + self.assertLen(att_metrics, 3) self.assertLen(att_values, 3) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py index e26fa9f..cdc230d 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py @@ -97,11 +97,12 @@ class UtilsTest(absltest.TestCase): self.sample_weight_test, attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, data_structures.AttackResults) - att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( - results) - self.assertLen(att_types, 2) - self.assertLen(att_slices, 2) - self.assertLen(att_metrics, 2) + att_types, att_slices, att_metrics, att_values = ( + data_structures.get_flattened_attack_metrics(results) + ) + self.assertLen(att_types, 3) + self.assertLen(att_slices, 3) + self.assertLen(att_metrics, 3) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV def test_run_attack_helper_with_sample_weights(self): @@ -116,11 +117,12 @@ class UtilsTest(absltest.TestCase): out_train_sample_weight=self.sample_weight_test, attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, data_structures.AttackResults) - att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( - results) - self.assertLen(att_types, 2) - self.assertLen(att_slices, 2) - self.assertLen(att_metrics, 2) + att_types, att_slices, att_metrics, att_values = ( + data_structures.get_flattened_attack_metrics(results) + ) + self.assertLen(att_types, 3) + self.assertLen(att_slices, 3) + self.assertLen(att_metrics, 3) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV def test_run_attack_on_tf_estimator_model(self): @@ -136,11 +138,12 @@ class UtilsTest(absltest.TestCase): input_fn_constructor, attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, data_structures.AttackResults) - att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( - results) - self.assertLen(att_types, 2) - self.assertLen(att_slices, 2) - self.assertLen(att_metrics, 2) + att_types, att_slices, att_metrics, att_values = ( + data_structures.get_flattened_attack_metrics(results) + ) + self.assertLen(att_types, 3) + self.assertLen(att_slices, 3) + self.assertLen(att_metrics, 3) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV def test_run_attack_on_tf_estimator_model_with_sample_weights(self): @@ -157,11 +160,12 @@ class UtilsTest(absltest.TestCase): input_fn_constructor, attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, data_structures.AttackResults) - att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( - results) - self.assertLen(att_types, 2) - self.assertLen(att_slices, 2) - self.assertLen(att_metrics, 2) + att_types, att_slices, att_metrics, att_values = ( + data_structures.get_flattened_attack_metrics(results) + ) + self.assertLen(att_types, 3) + self.assertLen(att_slices, 3) + self.assertLen(att_metrics, 3) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV