Returns attack results as a Pandas data frame.

PiperOrigin-RevId: 327675978
This commit is contained in:
Shuang Song 2020-08-20 12:27:20 -07:00 committed by A. Unique TensorFlower
parent 52c1f8fdfe
commit d23772e163
3 changed files with 48 additions and 2 deletions

View file

@ -20,8 +20,10 @@ from typing import Any, Iterable, Union
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
import pandas as pd
from sklearn import metrics from sklearn import metrics
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
@ -334,8 +336,31 @@ class AttackResults:
single_attack_results: Iterable[SingleAttackResult] single_attack_results: Iterable[SingleAttackResult]
def calculate_pd_dataframe(self): def calculate_pd_dataframe(self):
# returns all metrics as a Pandas DataFrame """Returns all metrics as a Pandas DataFrame."""
return slice_features = []
slice_values = []
attack_types = []
advantages = []
aucs = []
for attack_result in self.single_attack_results:
slice_spec = attack_result.slice_spec
if slice_spec.entire_dataset:
slice_feature, slice_value = 'entire_dataset', ''
else:
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
slice_features.append(str(slice_feature))
slice_values.append(str(slice_value))
attack_types.append(str(attack_result.attack_type.value))
advantages.append(float(attack_result.get_attacker_advantage()))
aucs.append(float(attack_result.get_auc()))
df = pd.DataFrame({'slice feature': slice_features,
'slice value': slice_values,
'attack type': attack_types,
'attack advantage': advantages,
'roc auc': aucs})
return df
def summary(self, by_slices=False) -> str: def summary(self, by_slices=False) -> str:
"""Provides a summary of the metrics. """Provides a summary of the metrics.

View file

@ -19,6 +19,7 @@ import tempfile
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
@ -235,6 +236,20 @@ class AttackResultsTest(absltest.TestCase):
self.assertEqual(repr(results), repr(loaded_results)) self.assertEqual(repr(results), repr(loaded_results))
def test_calculate_pd_dataframe(self):
single_results = [self.perfect_classifier_result,
self.random_classifier_result]
results = AttackResults(single_results)
df = results.calculate_pd_dataframe()
df_expected = pd.DataFrame({
'slice feature': ['correctly_classfied', 'entire_dataset'],
'slice value': ['True', ''],
'attack type': ['threshold', 'threshold'],
'attack advantage': [1.0, 0.0],
'roc auc': [1.0, 0.5]
})
self.assertTrue(df.equals(df_expected))
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View file

@ -22,6 +22,7 @@ import tempfile
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd
from tensorflow import keras from tensorflow import keras
from tensorflow.keras import layers from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical from tensorflow.keras.utils import to_categorical
@ -153,6 +154,11 @@ print(attack_results.summary(by_slices=False))
print("Summary by slices: \n") print("Summary by slices: \n")
print(attack_results.summary(by_slices=True)) print(attack_results.summary(by_slices=True))
# Print pandas data frame
print("Pandas frame: \n")
pd.set_option("display.max_rows", None, "display.max_columns", None)
print(attack_results.calculate_pd_dataframe())
# Example of ROC curve plotting. # Example of ROC curve plotting.
figure = plotting.plot_roc_curve( figure = plotting.plot_roc_curve(
attack_results.single_attack_results[0].roc_curve) attack_results.single_attack_results[0].roc_curve)