Returns attack results as a Pandas data frame.
PiperOrigin-RevId: 327675978
This commit is contained in:
parent
52c1f8fdfe
commit
d23772e163
3 changed files with 48 additions and 2 deletions
|
@ -20,8 +20,10 @@ from typing import Any, Iterable, Union
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
|
|
||||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
||||||
|
|
||||||
|
|
||||||
|
@ -334,8 +336,31 @@ class AttackResults:
|
||||||
single_attack_results: Iterable[SingleAttackResult]
|
single_attack_results: Iterable[SingleAttackResult]
|
||||||
|
|
||||||
def calculate_pd_dataframe(self):
|
def calculate_pd_dataframe(self):
|
||||||
# returns all metrics as a Pandas DataFrame
|
"""Returns all metrics as a Pandas DataFrame."""
|
||||||
return
|
slice_features = []
|
||||||
|
slice_values = []
|
||||||
|
attack_types = []
|
||||||
|
advantages = []
|
||||||
|
aucs = []
|
||||||
|
|
||||||
|
for attack_result in self.single_attack_results:
|
||||||
|
slice_spec = attack_result.slice_spec
|
||||||
|
if slice_spec.entire_dataset:
|
||||||
|
slice_feature, slice_value = 'entire_dataset', ''
|
||||||
|
else:
|
||||||
|
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
|
||||||
|
slice_features.append(str(slice_feature))
|
||||||
|
slice_values.append(str(slice_value))
|
||||||
|
attack_types.append(str(attack_result.attack_type.value))
|
||||||
|
advantages.append(float(attack_result.get_attacker_advantage()))
|
||||||
|
aucs.append(float(attack_result.get_auc()))
|
||||||
|
|
||||||
|
df = pd.DataFrame({'slice feature': slice_features,
|
||||||
|
'slice value': slice_values,
|
||||||
|
'attack type': attack_types,
|
||||||
|
'attack advantage': advantages,
|
||||||
|
'roc auc': aucs})
|
||||||
|
return df
|
||||||
|
|
||||||
def summary(self, by_slices=False) -> str:
|
def summary(self, by_slices=False) -> str:
|
||||||
"""Provides a summary of the metrics.
|
"""Provides a summary of the metrics.
|
||||||
|
|
|
@ -19,6 +19,7 @@ import tempfile
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
@ -235,6 +236,20 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(repr(results), repr(loaded_results))
|
self.assertEqual(repr(results), repr(loaded_results))
|
||||||
|
|
||||||
|
def test_calculate_pd_dataframe(self):
|
||||||
|
single_results = [self.perfect_classifier_result,
|
||||||
|
self.random_classifier_result]
|
||||||
|
results = AttackResults(single_results)
|
||||||
|
df = results.calculate_pd_dataframe()
|
||||||
|
df_expected = pd.DataFrame({
|
||||||
|
'slice feature': ['correctly_classfied', 'entire_dataset'],
|
||||||
|
'slice value': ['True', ''],
|
||||||
|
'attack type': ['threshold', 'threshold'],
|
||||||
|
'attack advantage': [1.0, 0.0],
|
||||||
|
'roc auc': [1.0, 0.5]
|
||||||
|
})
|
||||||
|
self.assertTrue(df.equals(df_expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -22,6 +22,7 @@ import tempfile
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
from tensorflow.keras.utils import to_categorical
|
from tensorflow.keras.utils import to_categorical
|
||||||
|
@ -153,6 +154,11 @@ print(attack_results.summary(by_slices=False))
|
||||||
print("Summary by slices: \n")
|
print("Summary by slices: \n")
|
||||||
print(attack_results.summary(by_slices=True))
|
print(attack_results.summary(by_slices=True))
|
||||||
|
|
||||||
|
# Print pandas data frame
|
||||||
|
print("Pandas frame: \n")
|
||||||
|
pd.set_option("display.max_rows", None, "display.max_columns", None)
|
||||||
|
print(attack_results.calculate_pd_dataframe())
|
||||||
|
|
||||||
# Example of ROC curve plotting.
|
# Example of ROC curve plotting.
|
||||||
figure = plotting.plot_roc_curve(
|
figure = plotting.plot_roc_curve(
|
||||||
attack_results.single_attack_results[0].roc_curve)
|
attack_results.single_attack_results[0].roc_curve)
|
||||||
|
|
Loading…
Reference in a new issue