Refactors the pd_dataframe calculation to avoid hard-coded strings.

PiperOrigin-RevId: 334334080
This commit is contained in:
David Marn 2020-09-29 02:15:13 -07:00 committed by A. Unique TensorFlower
parent c30c3fcb7a
commit 78d30a0424
3 changed files with 38 additions and 20 deletions

View file

@ -25,14 +25,14 @@ from scipy import special
from sklearn import metrics
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
ENTIRE_DATASET_SLICE_STR = 'Entire dataset'
class SlicingFeature(enum.Enum):
"""Enum with features by which slicing is available."""
CLASS = 'class'
PERCENTILE = 'percentile'
CORRECTLY_CLASSIFIED = 'correctly_classfied'
CORRECTLY_CLASSIFIED = 'correctly_classified'
@dataclass
@ -54,7 +54,7 @@ class SingleSliceSpec:
def __str__(self):
if self.entire_dataset:
return 'Entire dataset'
return ENTIRE_DATASET_SLICE_STR
if self.feature == SlicingFeature.PERCENTILE:
return 'Loss percentiles: %d-%d' % self.value
@ -448,6 +448,17 @@ class PrivacyReportMetadata:
epoch_num: int = None
class AttackResultsDFColumns(enum.Enum):
    """Columns for the Pandas DataFrame that stores AttackResults metrics."""
    SLICE_FEATURE = 'slice feature'
    SLICE_VALUE = 'slice value'
    ATTACK_TYPE = 'attack type'

    def __str__(self):
        """Returns the column name itself rather than the qualified member name.

        For example, str(AttackResultsDFColumns.SLICE_VALUE) is 'slice value'
        instead of 'AttackResultsDFColumns.SLICE_VALUE', so a member can be
        passed through str() and used directly as a DataFrame key.
        """
        # str() rather than '%s' % value: %-formatting would unpack a tuple
        # value as multiple format arguments instead of stringifying it.
        return str(self.value)
@dataclass
class AttackResults:
"""Results from running multiple attacks."""
@ -466,19 +477,19 @@ class AttackResults:
for attack_result in self.single_attack_results:
slice_spec = attack_result.slice_spec
if slice_spec.entire_dataset:
slice_feature, slice_value = 'entire_dataset', ''
slice_feature, slice_value = str(slice_spec), ''
else:
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
slice_features.append(str(slice_feature))
slice_values.append(str(slice_value))
attack_types.append(str(attack_result.attack_type.value))
attack_types.append(str(attack_result.attack_type))
advantages.append(float(attack_result.get_attacker_advantage()))
aucs.append(float(attack_result.get_auc()))
df = pd.DataFrame({
'slice feature': slice_features,
'slice value': slice_values,
'attack type': attack_types,
str(AttackResultsDFColumns.SLICE_FEATURE): slice_features,
str(AttackResultsDFColumns.SLICE_VALUE): slice_values,
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
str(PrivacyMetric.AUC): aucs
})

View file

@ -302,13 +302,13 @@ class AttackResultsTest(absltest.TestCase):
results = AttackResults(single_results)
df = results.calculate_pd_dataframe()
df_expected = pd.DataFrame({
'slice feature': ['correctly_classfied', 'entire_dataset'],
'slice feature': ['correctly_classified', 'Entire dataset'],
'slice value': ['True', ''],
'attack type': ['threshold', 'threshold'],
'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
'Attacker advantage': [1.0, 0.0],
'AUC': [1.0, 0.5]
})
self.assertTrue(df.equals(df_expected))
pd.testing.assert_frame_equal(df, df_expected)
if __name__ == '__main__':

View file

@ -19,8 +19,15 @@ import matplotlib.pyplot as plt
import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
# Helper constants for DataFrame keys.
LEGEND_LABEL_STR = 'legend label'
EPOCH_STR = 'Epoch'
TRAIN_ACCURACY_STR = 'Train accuracy'
def plot_by_epochs(results: Iterable[AttackResults],
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
@ -73,17 +80,17 @@ def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
all_results_df = None
for attack_results in results:
attack_results_df = attack_results.calculate_pd_dataframe()
attack_results_df = attack_results_df.loc[attack_results_df['slice feature']
== 'entire_dataset']
attack_results_df.insert(0, 'Epoch',
attack_results_df = attack_results_df.loc[attack_results_df[str(
AttackResultsDFColumns.SLICE_FEATURE)] == ENTIRE_DATASET_SLICE_STR]
attack_results_df.insert(0, EPOCH_STR,
attack_results.privacy_report_metadata.epoch_num)
attack_results_df.insert(
0, 'Train accuracy',
0, TRAIN_ACCURACY_STR,
attack_results.privacy_report_metadata.accuracy_train)
attack_results_df.insert(
0, 'legend label',
0, LEGEND_LABEL_STR,
attack_results.privacy_report_metadata.model_variant_label + ' - ' +
attack_results_df['attack type'])
attack_results_df[str(AttackResultsDFColumns.ATTACK_TYPE)])
if all_results_df is None:
all_results_df = attack_results_df
else:
@ -102,15 +109,15 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
if len(privacy_metrics) == 1:
axes = (axes,)
for i, privacy_metric in enumerate(privacy_metrics):
legend_labels = all_results_df['legend label'].unique()
legend_labels = all_results_df[LEGEND_LABEL_STR].unique()
for legend_label in legend_labels:
single_label_results = all_results_df.loc[all_results_df['legend label']
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
== legend_label]
axes[i].plot(single_label_results[x_axis_metric],
single_label_results[str(privacy_metric)])
axes[i].legend(legend_labels)
axes[i].set_xlabel(x_axis_metric)
axes[i].set_title('%s for Entire dataset' % str(privacy_metric))
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR)
return fig