diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 32a22c7..b78a63f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -20,8 +20,10 @@ from typing import Any, Iterable, Union from dataclasses import dataclass import numpy as np +import pandas as pd from sklearn import metrics + ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' @@ -334,8 +336,31 @@ class AttackResults: single_attack_results: Iterable[SingleAttackResult] def calculate_pd_dataframe(self): - # returns all metrics as a Pandas DataFrame - return + """Returns all metrics as a Pandas DataFrame.""" + slice_features = [] + slice_values = [] + attack_types = [] + advantages = [] + aucs = [] + + for attack_result in self.single_attack_results: + slice_spec = attack_result.slice_spec + if slice_spec.entire_dataset: + slice_feature, slice_value = 'entire_dataset', '' + else: + slice_feature, slice_value = slice_spec.feature.value, slice_spec.value + slice_features.append(str(slice_feature)) + slice_values.append(str(slice_value)) + attack_types.append(str(attack_result.attack_type.value)) + advantages.append(float(attack_result.get_attacker_advantage())) + aucs.append(float(attack_result.get_auc())) + + df = pd.DataFrame({'slice feature': slice_features, + 'slice value': slice_values, + 'attack type': attack_types, + 'attack advantage': advantages, + 'roc auc': aucs}) + return df def summary(self, by_slices=False) -> str: """Provides a summary of the metrics. diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 0ab4c3a..8314d95 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -19,6 +19,7 @@ import tempfile from absl.testing import absltest from absl.testing import parameterized import numpy as np +import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType @@ -235,6 +236,20 @@ class AttackResultsTest(absltest.TestCase): self.assertEqual(repr(results), repr(loaded_results)) + def test_calculate_pd_dataframe(self): + single_results = [self.perfect_classifier_result, + self.random_classifier_result] + results = AttackResults(single_results) + df = results.calculate_pd_dataframe() + df_expected = pd.DataFrame({ + 'slice feature': ['correctly_classfied', 'entire_dataset'], + 'slice value': ['True', ''], + 'attack type': ['threshold', 'threshold'], + 'attack advantage': [1.0, 0.0], + 'roc auc': [1.0, 0.5] + }) + self.assertTrue(df.equals(df_expected)) + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index f13383f..090fd56 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -22,6 +22,7 @@ import tempfile import matplotlib.pyplot as plt import numpy as np +import pandas as pd from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.utils import to_categorical @@ -153,6 +154,11 @@ print(attack_results.summary(by_slices=False)) print("Summary by slices: \n") print(attack_results.summary(by_slices=True)) +# Print pandas data frame +print("Pandas frame: \n") +pd.set_option("display.max_rows", None, "display.max_columns", None) +print(attack_results.calculate_pd_dataframe()) + # Example of ROC curve plotting. figure = plotting.plot_roc_curve( attack_results.single_attack_results[0].roc_curve)