Refactors the pd_dataframe calculation to avoid hard-coded strings.
PiperOrigin-RevId: 334334080
This commit is contained in:
parent
c30c3fcb7a
commit
78d30a0424
3 changed files with 38 additions and 20 deletions
|
@ -25,14 +25,14 @@ from scipy import special
|
|||
from sklearn import metrics
|
||||
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
|
||||
|
||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
||||
ENTIRE_DATASET_SLICE_STR = 'Entire dataset'
|
||||
|
||||
|
||||
class SlicingFeature(enum.Enum):
|
||||
"""Enum with features by which slicing is available."""
|
||||
CLASS = 'class'
|
||||
PERCENTILE = 'percentile'
|
||||
CORRECTLY_CLASSIFIED = 'correctly_classfied'
|
||||
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -54,7 +54,7 @@ class SingleSliceSpec:
|
|||
|
||||
def __str__(self):
|
||||
if self.entire_dataset:
|
||||
return 'Entire dataset'
|
||||
return ENTIRE_DATASET_SLICE_STR
|
||||
|
||||
if self.feature == SlicingFeature.PERCENTILE:
|
||||
return 'Loss percentiles: %d-%d' % self.value
|
||||
|
@ -448,6 +448,17 @@ class PrivacyReportMetadata:
|
|||
epoch_num: int = None
|
||||
|
||||
|
||||
class AttackResultsDFColumns(enum.Enum):
|
||||
"""Columns for the Pandas DataFrame that stores AttackResults metrics."""
|
||||
SLICE_FEATURE = 'slice feature'
|
||||
SLICE_VALUE = 'slice value'
|
||||
ATTACK_TYPE = 'attack type'
|
||||
|
||||
def __str__(self):
|
||||
"""Returns 'slice value' instead of AttackResultsDFColumns.SLICE_VALUE."""
|
||||
return '%s' % self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackResults:
|
||||
"""Results from running multiple attacks."""
|
||||
|
@ -466,19 +477,19 @@ class AttackResults:
|
|||
for attack_result in self.single_attack_results:
|
||||
slice_spec = attack_result.slice_spec
|
||||
if slice_spec.entire_dataset:
|
||||
slice_feature, slice_value = 'entire_dataset', ''
|
||||
slice_feature, slice_value = str(slice_spec), ''
|
||||
else:
|
||||
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
|
||||
slice_features.append(str(slice_feature))
|
||||
slice_values.append(str(slice_value))
|
||||
attack_types.append(str(attack_result.attack_type.value))
|
||||
attack_types.append(str(attack_result.attack_type))
|
||||
advantages.append(float(attack_result.get_attacker_advantage()))
|
||||
aucs.append(float(attack_result.get_auc()))
|
||||
|
||||
df = pd.DataFrame({
|
||||
'slice feature': slice_features,
|
||||
'slice value': slice_values,
|
||||
'attack type': attack_types,
|
||||
str(AttackResultsDFColumns.SLICE_FEATURE): slice_features,
|
||||
str(AttackResultsDFColumns.SLICE_VALUE): slice_values,
|
||||
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
|
||||
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
|
||||
str(PrivacyMetric.AUC): aucs
|
||||
})
|
||||
|
|
|
@ -302,13 +302,13 @@ class AttackResultsTest(absltest.TestCase):
|
|||
results = AttackResults(single_results)
|
||||
df = results.calculate_pd_dataframe()
|
||||
df_expected = pd.DataFrame({
|
||||
'slice feature': ['correctly_classfied', 'entire_dataset'],
|
||||
'slice feature': ['correctly_classified', 'Entire dataset'],
|
||||
'slice value': ['True', ''],
|
||||
'attack type': ['threshold', 'threshold'],
|
||||
'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
|
||||
'Attacker advantage': [1.0, 0.0],
|
||||
'AUC': [1.0, 0.5]
|
||||
})
|
||||
self.assertTrue(df.equals(df_expected))
|
||||
pd.testing.assert_frame_equal(df, df_expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -19,8 +19,15 @@ import matplotlib.pyplot as plt
|
|||
import pandas as pd
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||
|
||||
# Helper constants for DataFrame keys.
|
||||
LEGEND_LABEL_STR = 'legend label'
|
||||
EPOCH_STR = 'Epoch'
|
||||
TRAIN_ACCURACY_STR = 'Train accuracy'
|
||||
|
||||
|
||||
def plot_by_epochs(results: Iterable[AttackResults],
|
||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||
|
@ -73,17 +80,17 @@ def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
|
|||
all_results_df = None
|
||||
for attack_results in results:
|
||||
attack_results_df = attack_results.calculate_pd_dataframe()
|
||||
attack_results_df = attack_results_df.loc[attack_results_df['slice feature']
|
||||
== 'entire_dataset']
|
||||
attack_results_df.insert(0, 'Epoch',
|
||||
attack_results_df = attack_results_df.loc[attack_results_df[str(
|
||||
AttackResultsDFColumns.SLICE_FEATURE)] == ENTIRE_DATASET_SLICE_STR]
|
||||
attack_results_df.insert(0, EPOCH_STR,
|
||||
attack_results.privacy_report_metadata.epoch_num)
|
||||
attack_results_df.insert(
|
||||
0, 'Train accuracy',
|
||||
0, TRAIN_ACCURACY_STR,
|
||||
attack_results.privacy_report_metadata.accuracy_train)
|
||||
attack_results_df.insert(
|
||||
0, 'legend label',
|
||||
0, LEGEND_LABEL_STR,
|
||||
attack_results.privacy_report_metadata.model_variant_label + ' - ' +
|
||||
attack_results_df['attack type'])
|
||||
attack_results_df[str(AttackResultsDFColumns.ATTACK_TYPE)])
|
||||
if all_results_df is None:
|
||||
all_results_df = attack_results_df
|
||||
else:
|
||||
|
@ -102,15 +109,15 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
|||
if len(privacy_metrics) == 1:
|
||||
axes = (axes,)
|
||||
for i, privacy_metric in enumerate(privacy_metrics):
|
||||
legend_labels = all_results_df['legend label'].unique()
|
||||
legend_labels = all_results_df[LEGEND_LABEL_STR].unique()
|
||||
for legend_label in legend_labels:
|
||||
single_label_results = all_results_df.loc[all_results_df['legend label']
|
||||
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
|
||||
== legend_label]
|
||||
axes[i].plot(single_label_results[x_axis_metric],
|
||||
single_label_results[str(privacy_metric)])
|
||||
axes[i].legend(legend_labels)
|
||||
axes[i].set_xlabel(x_axis_metric)
|
||||
axes[i].set_title('%s for Entire dataset' % str(privacy_metric))
|
||||
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR)
|
||||
|
||||
return fig
|
||||
|
||||
|
|
Loading…
Reference in a new issue