Refactors the pd_dataframe calculation to avoid hard-coded strings.
PiperOrigin-RevId: 334334080
This commit is contained in:
parent
c30c3fcb7a
commit
78d30a0424
3 changed files with 38 additions and 20 deletions
|
@ -25,14 +25,14 @@ from scipy import special
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
|
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
|
||||||
|
|
||||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
ENTIRE_DATASET_SLICE_STR = 'Entire dataset'
|
||||||
|
|
||||||
|
|
||||||
class SlicingFeature(enum.Enum):
|
class SlicingFeature(enum.Enum):
|
||||||
"""Enum with features by which slicing is available."""
|
"""Enum with features by which slicing is available."""
|
||||||
CLASS = 'class'
|
CLASS = 'class'
|
||||||
PERCENTILE = 'percentile'
|
PERCENTILE = 'percentile'
|
||||||
CORRECTLY_CLASSIFIED = 'correctly_classfied'
|
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -54,7 +54,7 @@ class SingleSliceSpec:
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.entire_dataset:
|
if self.entire_dataset:
|
||||||
return 'Entire dataset'
|
return ENTIRE_DATASET_SLICE_STR
|
||||||
|
|
||||||
if self.feature == SlicingFeature.PERCENTILE:
|
if self.feature == SlicingFeature.PERCENTILE:
|
||||||
return 'Loss percentiles: %d-%d' % self.value
|
return 'Loss percentiles: %d-%d' % self.value
|
||||||
|
@ -448,6 +448,17 @@ class PrivacyReportMetadata:
|
||||||
epoch_num: int = None
|
epoch_num: int = None
|
||||||
|
|
||||||
|
|
||||||
|
class AttackResultsDFColumns(enum.Enum):
|
||||||
|
"""Columns for the Pandas DataFrame that stores AttackResults metrics."""
|
||||||
|
SLICE_FEATURE = 'slice feature'
|
||||||
|
SLICE_VALUE = 'slice value'
|
||||||
|
ATTACK_TYPE = 'attack type'
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Returns 'slice value' instead of AttackResultsDFColumns.SLICE_VALUE."""
|
||||||
|
return '%s' % self.value
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttackResults:
|
class AttackResults:
|
||||||
"""Results from running multiple attacks."""
|
"""Results from running multiple attacks."""
|
||||||
|
@ -466,19 +477,19 @@ class AttackResults:
|
||||||
for attack_result in self.single_attack_results:
|
for attack_result in self.single_attack_results:
|
||||||
slice_spec = attack_result.slice_spec
|
slice_spec = attack_result.slice_spec
|
||||||
if slice_spec.entire_dataset:
|
if slice_spec.entire_dataset:
|
||||||
slice_feature, slice_value = 'entire_dataset', ''
|
slice_feature, slice_value = str(slice_spec), ''
|
||||||
else:
|
else:
|
||||||
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
|
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
|
||||||
slice_features.append(str(slice_feature))
|
slice_features.append(str(slice_feature))
|
||||||
slice_values.append(str(slice_value))
|
slice_values.append(str(slice_value))
|
||||||
attack_types.append(str(attack_result.attack_type.value))
|
attack_types.append(str(attack_result.attack_type))
|
||||||
advantages.append(float(attack_result.get_attacker_advantage()))
|
advantages.append(float(attack_result.get_attacker_advantage()))
|
||||||
aucs.append(float(attack_result.get_auc()))
|
aucs.append(float(attack_result.get_auc()))
|
||||||
|
|
||||||
df = pd.DataFrame({
|
df = pd.DataFrame({
|
||||||
'slice feature': slice_features,
|
str(AttackResultsDFColumns.SLICE_FEATURE): slice_features,
|
||||||
'slice value': slice_values,
|
str(AttackResultsDFColumns.SLICE_VALUE): slice_values,
|
||||||
'attack type': attack_types,
|
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
|
||||||
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
|
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
|
||||||
str(PrivacyMetric.AUC): aucs
|
str(PrivacyMetric.AUC): aucs
|
||||||
})
|
})
|
||||||
|
|
|
@ -302,13 +302,13 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
results = AttackResults(single_results)
|
results = AttackResults(single_results)
|
||||||
df = results.calculate_pd_dataframe()
|
df = results.calculate_pd_dataframe()
|
||||||
df_expected = pd.DataFrame({
|
df_expected = pd.DataFrame({
|
||||||
'slice feature': ['correctly_classfied', 'entire_dataset'],
|
'slice feature': ['correctly_classified', 'Entire dataset'],
|
||||||
'slice value': ['True', ''],
|
'slice value': ['True', ''],
|
||||||
'attack type': ['threshold', 'threshold'],
|
'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
|
||||||
'Attacker advantage': [1.0, 0.0],
|
'Attacker advantage': [1.0, 0.0],
|
||||||
'AUC': [1.0, 0.5]
|
'AUC': [1.0, 0.5]
|
||||||
})
|
})
|
||||||
self.assertTrue(df.equals(df_expected))
|
pd.testing.assert_frame_equal(df, df_expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -19,8 +19,15 @@ import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||||
|
|
||||||
|
# Helper constants for DataFrame keys.
|
||||||
|
LEGEND_LABEL_STR = 'legend label'
|
||||||
|
EPOCH_STR = 'Epoch'
|
||||||
|
TRAIN_ACCURACY_STR = 'Train accuracy'
|
||||||
|
|
||||||
|
|
||||||
def plot_by_epochs(results: Iterable[AttackResults],
|
def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||||
|
@ -73,17 +80,17 @@ def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
|
||||||
all_results_df = None
|
all_results_df = None
|
||||||
for attack_results in results:
|
for attack_results in results:
|
||||||
attack_results_df = attack_results.calculate_pd_dataframe()
|
attack_results_df = attack_results.calculate_pd_dataframe()
|
||||||
attack_results_df = attack_results_df.loc[attack_results_df['slice feature']
|
attack_results_df = attack_results_df.loc[attack_results_df[str(
|
||||||
== 'entire_dataset']
|
AttackResultsDFColumns.SLICE_FEATURE)] == ENTIRE_DATASET_SLICE_STR]
|
||||||
attack_results_df.insert(0, 'Epoch',
|
attack_results_df.insert(0, EPOCH_STR,
|
||||||
attack_results.privacy_report_metadata.epoch_num)
|
attack_results.privacy_report_metadata.epoch_num)
|
||||||
attack_results_df.insert(
|
attack_results_df.insert(
|
||||||
0, 'Train accuracy',
|
0, TRAIN_ACCURACY_STR,
|
||||||
attack_results.privacy_report_metadata.accuracy_train)
|
attack_results.privacy_report_metadata.accuracy_train)
|
||||||
attack_results_df.insert(
|
attack_results_df.insert(
|
||||||
0, 'legend label',
|
0, LEGEND_LABEL_STR,
|
||||||
attack_results.privacy_report_metadata.model_variant_label + ' - ' +
|
attack_results.privacy_report_metadata.model_variant_label + ' - ' +
|
||||||
attack_results_df['attack type'])
|
attack_results_df[str(AttackResultsDFColumns.ATTACK_TYPE)])
|
||||||
if all_results_df is None:
|
if all_results_df is None:
|
||||||
all_results_df = attack_results_df
|
all_results_df = attack_results_df
|
||||||
else:
|
else:
|
||||||
|
@ -102,15 +109,15 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
||||||
if len(privacy_metrics) == 1:
|
if len(privacy_metrics) == 1:
|
||||||
axes = (axes,)
|
axes = (axes,)
|
||||||
for i, privacy_metric in enumerate(privacy_metrics):
|
for i, privacy_metric in enumerate(privacy_metrics):
|
||||||
legend_labels = all_results_df['legend label'].unique()
|
legend_labels = all_results_df[LEGEND_LABEL_STR].unique()
|
||||||
for legend_label in legend_labels:
|
for legend_label in legend_labels:
|
||||||
single_label_results = all_results_df.loc[all_results_df['legend label']
|
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
|
||||||
== legend_label]
|
== legend_label]
|
||||||
axes[i].plot(single_label_results[x_axis_metric],
|
axes[i].plot(single_label_results[x_axis_metric],
|
||||||
single_label_results[str(privacy_metric)])
|
single_label_results[str(privacy_metric)])
|
||||||
axes[i].legend(legend_labels)
|
axes[i].legend(legend_labels)
|
||||||
axes[i].set_xlabel(x_axis_metric)
|
axes[i].set_xlabel(x_axis_metric)
|
||||||
axes[i].set_title('%s for Entire dataset' % str(privacy_metric))
|
# The title must name the metric being plotted; only the slice label comes from the shared constant.
axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR))
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue