diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index cf2c71e..efc6859 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -15,6 +15,8 @@ # Lint as: python3 """Data structures representing attack inputs, configuration, outputs.""" import enum +import glob +import os import pickle from typing import Any, Iterable, Union @@ -577,6 +579,41 @@ class AttackResults: return pickle.load(inp) +@dataclass +class AttackResultsCollection: + """A collection of AttackResults.""" + attack_results_list: Iterable[AttackResults] + + def append(self, attack_results: AttackResults): + self.attack_results_list.append(attack_results) + + def save(self, dirname): + """Saves self to a pickle file.""" + for i, attack_results in enumerate(self.attack_results_list): + filepath = os.path.join(dirname, + _get_attack_results_filename(attack_results, i)) + + attack_results.save(filepath) + + @classmethod + def load(cls, dirname): + """Loads AttackResultsCollection from all files in a directory.""" + loaded_collection = AttackResultsCollection([]) + for filepath in sorted(glob.glob('%s/*' % dirname)): + with open(filepath, 'rb') as inp: + loaded_collection.attack_results_list.append(pickle.load(inp)) + return loaded_collection + + +def _get_attack_results_filename(attack_results: AttackResults, index: int): + """Creates a filename for a specific set of AttackResults.""" + metadata = attack_results.privacy_report_metadata + if metadata is not None: + return '%s_%s_%s.pickle' % (metadata.model_variant_label, + metadata.epoch_num, index) + return '%s.pickle' % index + + def get_flattened_attack_metrics(results: AttackResults): """Get flattened attack metrics. diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 2ec626e..d5c9d1d 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -23,7 +23,9 @@ import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec @@ -206,6 +208,62 @@ class SingleAttackResultTest(absltest.TestCase): self.assertEqual(result.get_attacker_advantage(), 0.0) +class AttackResultsCollectionTest(absltest.TestCase): + + def __init__(self, *args, **kwargs): + super(AttackResultsCollectionTest, self).__init__(*args, **kwargs) + + self.some_attack_result = SingleAttackResult( + slice_spec=SingleSliceSpec(None), + attack_type=AttackType.THRESHOLD_ATTACK, + roc_curve=RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([0.0, 0.5, 1.0]), + thresholds=np.array([0, 1, 2]))) + + self.results_epoch_10 = AttackResults( + single_attack_results=[self.some_attack_result], + privacy_report_metadata=PrivacyReportMetadata( + accuracy_train=0.4, + accuracy_test=0.3, + epoch_num=10, + model_variant_label='default')) + + self.results_epoch_15 = AttackResults( + single_attack_results=[self.some_attack_result], + privacy_report_metadata=PrivacyReportMetadata( + accuracy_train=0.5, + accuracy_test=0.4, + epoch_num=15, + model_variant_label='default')) + + self.attack_results_no_metadata = AttackResults( + single_attack_results=[self.some_attack_result]) + + self.collection_with_metadata = AttackResultsCollection( + [self.results_epoch_10, self.results_epoch_15]) + + self.collection_no_metadata = AttackResultsCollection( + [self.attack_results_no_metadata, self.attack_results_no_metadata]) + + def test_save_with_metadata(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.collection_with_metadata.save(tmpdirname) + loaded_collection = AttackResultsCollection.load(tmpdirname) + + self.assertEqual( + repr(self.collection_with_metadata), repr(loaded_collection)) + self.assertLen(loaded_collection.attack_results_list, 2) + + def test_save_no_metadata(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.collection_no_metadata.save(tmpdirname) + loaded_collection = AttackResultsCollection.load(tmpdirname) + + self.assertEqual(repr(self.collection_no_metadata), repr(loaded_collection)) + self.assertLen(loaded_collection.attack_results_list, 2) + + class AttackResultsTest(absltest.TestCase): perfect_classifier_result: SingleAttackResult diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index cdea468..1f2d2af 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -32,6 +32,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ @@ -119,7 +120,7 @@ def crossentropy(true_labels, predictions): def main(unused_argv): - epoch_results = [] + epoch_results = AttackResultsCollection([]) num_epochs = 2 models = { diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py index cff25d3..b12baf1 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py @@ -19,6 +19,7 @@ import matplotlib.pyplot as plt import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric @@ -29,7 +30,7 @@ EPOCH_STR = 'Epoch' TRAIN_ACCURACY_STR = 'Train accuracy' -def plot_by_epochs(results: Iterable[AttackResults], +def plot_by_epochs(results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure: """Plots privacy vulnerabilities vs epoch numbers for a single model variant. @@ -43,8 +44,9 @@ def plot_by_epochs(results: Iterable[AttackResults], A pyplot figure with privacy vs accuracy plots. """ - _validate_results(results) - all_results_df = _calculate_combined_df_with_metadata(results) + _validate_results(results.attack_results_list) + all_results_df = _calculate_combined_df_with_metadata( + results.attack_results_list) return _generate_subplots( all_results_df=all_results_df, x_axis_metric='Epoch', @@ -53,7 +55,7 @@ def plot_by_epochs(results: Iterable[AttackResults], def plot_privacy_vs_accuracy_single_model( - results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]): + results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]): """Plots privacy vulnerabilities vs accuracy plots for a single model variant. In case multiple privacy metrics are specified, the plot will feature @@ -66,8 +68,9 @@ def plot_privacy_vs_accuracy_single_model( A pyplot figure with privacy vs accuracy plots. """ - _validate_results(results) - all_results_df = _calculate_combined_df_with_metadata(results) + _validate_results(results.attack_results_list) + all_results_df = _calculate_combined_df_with_metadata( + results.attack_results_list) return _generate_subplots( all_results_df=all_results_df, x_axis_metric='Train accuracy', diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py index 3a03f75..a38383c 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py @@ -20,6 +20,7 @@ import numpy as np from tensorflow_privacy.privacy.membership_inference_attack import privacy_report from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ PrivacyReportMetadata @@ -80,12 +81,14 @@ class PrivacyReportTest(absltest.TestCase): def test_plot_by_epochs_no_metadata(self): # Raise error if metadata is missing - self.assertRaises(ValueError, privacy_report.plot_by_epochs, - (self.attack_results_no_metadata,), ['AUC']) + self.assertRaises( + ValueError, privacy_report.plot_by_epochs, + AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC']) def test_single_metric_plot_by_epochs(self): fig = privacy_report.plot_by_epochs( - (self.results_epoch_10, self.results_epoch_15), ['AUC']) + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), + ['AUC']) # extract data from figure. auc_data = fig.gca().lines[0].get_data() # X axis lists epoch values @@ -97,7 +100,7 @@ class PrivacyReportTest(absltest.TestCase): def test_multiple_metrics_plot_by_epochs(self): fig = privacy_report.plot_by_epochs( - (self.results_epoch_10, self.results_epoch_15), + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), ['AUC', 'Attacker advantage']) # extract data from figure. auc_data = fig.axes[0].lines[0].get_data() @@ -113,8 +116,9 @@ class PrivacyReportTest(absltest.TestCase): def test_multiple_metrics_plot_by_epochs_multiple_models(self): fig = privacy_report.plot_by_epochs( - (self.results_epoch_10, self.results_epoch_15, - self.results_epoch_15_model_2), ['AUC', 'Attacker advantage']) + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15, + self.results_epoch_15_model_2)), + ['AUC', 'Attacker advantage']) # extract data from figure. # extract data from figure. auc_data_model_1 = fig.axes[0].lines[0].get_data() @@ -136,13 +140,14 @@ class PrivacyReportTest(absltest.TestCase): def test_plot_privacy_vs_accuracy_single_model_no_metadata(self): # Raise error if metadata is missing - self.assertRaises(ValueError, - privacy_report.plot_privacy_vs_accuracy_single_model, - (self.attack_results_no_metadata,), ['AUC']) + self.assertRaises( + ValueError, privacy_report.plot_privacy_vs_accuracy_single_model, + AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC']) def test_single_metric_plot_privacy_vs_accuracy_single_model(self): fig = privacy_report.plot_privacy_vs_accuracy_single_model( - (self.results_epoch_10, self.results_epoch_15), ['AUC']) + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), + ['AUC']) # extract data from figure. auc_data = fig.gca().lines[0].get_data() # X axis lists epoch values @@ -154,7 +159,7 @@ class PrivacyReportTest(absltest.TestCase): def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self): fig = privacy_report.plot_privacy_vs_accuracy_single_model( - (self.results_epoch_10, self.results_epoch_15), + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), ['AUC', 'Attacker advantage']) # extract data from figure. auc_data = fig.axes[0].lines[0].get_data() @@ -170,8 +175,9 @@ class PrivacyReportTest(absltest.TestCase): def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self): fig = privacy_report.plot_privacy_vs_accuracy_single_model( - (self.results_epoch_10, self.results_epoch_15, - self.results_epoch_15_model_2), ['AUC', 'Attacker advantage']) + AttackResultsCollection((self.results_epoch_10, self.results_epoch_15, + self.results_epoch_15_model_2)), + ['AUC', 'Attacker advantage']) # extract data from figure. auc_data_model_1 = fig.axes[0].lines[0].get_data() auc_data_model_2 = fig.axes[0].lines[1].get_data()