From d1a8a6cfda9dde33528bd526f7139ad9e9185155 Mon Sep 17 00:00:00 2001 From: David Marn Date: Wed, 14 Oct 2020 04:41:09 -0700 Subject: [PATCH] Multiple small changes to the TF Privacy Report: - Fix the legend to the bottom right - Manually set the size of the plot figure. - Fix a typo in the subplot title. PiperOrigin-RevId: 337064528 --- .../membership_inference_attack/README.md | 2 +- .../{ => codelabs}/codelab.ipynb | 0 .../data_structures.py | 4 +-- .../membership_inference_attack/example.py | 2 +- .../privacy_report.py | 26 +++++++++++-------- .../privacy_report_test.py | 8 +++--- 6 files changed, 23 insertions(+), 19 deletions(-) rename tensorflow_privacy/privacy/membership_inference_attack/{ => codelabs}/codelab.ipynb (100%) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/README.md b/tensorflow_privacy/privacy/membership_inference_attack/README.md index 27c95e7..88e2281 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/README.md +++ b/tensorflow_privacy/privacy/membership_inference_attack/README.md @@ -20,7 +20,7 @@ the model are used (e.g., losses, logits, predictions). Neither model internals ### Codelab -The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb). +The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb). This trains a simple image classification model and tests it against a series of membership inference attacks. diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb similarity index 100% rename from tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb rename to tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index efc6859..aea8e78 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -609,8 +609,8 @@ def _get_attack_results_filename(attack_results: AttackResults, index: int): """Creates a filename for a specific set of AttackResults.""" metadata = attack_results.privacy_report_metadata if metadata is not None: - return '%s_%s_%s.pickle' % (metadata.model_variant_label, - metadata.epoch_num, index) + return '%s_%s_epoch_%s.pickle' % (metadata.model_variant_label, index, + metadata.epoch_num) return '%s.pickle' % index diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 1f2d2af..2c8ff70 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -169,7 +169,7 @@ def main(unused_argv): epoch_figure = privacy_report.plot_by_epochs( epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) epoch_figure.show() - privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model( + privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy( epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) privacy_utility_figure.show() diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py index b12baf1..c81f714 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py @@ -32,10 +32,11 @@ TRAIN_ACCURACY_STR = 'Train accuracy' def plot_by_epochs(results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure: - """Plots privacy vulnerabilities vs epoch numbers for a single model variant. + """Plots privacy vulnerabilities vs epoch numbers. In case multiple privacy metrics are specified, the plot will feature - multiple subplots (one subplot per metrics). + multiple subplots (one subplot per metrics). Multiple model variants + are supported. Args: results: AttackResults for the plot privacy_metrics: List of enumerated privacy metrics that should be plotted. @@ -54,12 +55,13 @@ def plot_by_epochs(results: AttackResultsCollection, privacy_metrics=privacy_metrics) -def plot_privacy_vs_accuracy_single_model( - results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]): - """Plots privacy vulnerabilities vs accuracy plots for a single model variant. +def plot_privacy_vs_accuracy(results: AttackResultsCollection, + privacy_metrics: Iterable[PrivacyMetric]): + """Plots privacy vulnerabilities vs accuracy plots. In case multiple privacy metrics are specified, the plot will feature - multiple subplots (one subplot per metrics). + multiple subplots (one subplot per metrics). Multiple model variants + are supported. Args: results: AttackResults for the plot privacy_metrics: List of enumerated privacy metrics that should be plotted. @@ -106,7 +108,8 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str, figure_title: str, privacy_metrics: Iterable[PrivacyMetric]): """Create one subplot per privacy metric for a specified x_axis_metric.""" - fig, axes = plt.subplots(1, len(privacy_metrics)) + fig, axes = plt.subplots( + 1, len(privacy_metrics), figsize=(5 * len(privacy_metrics), 5)) # Set a title for the entire group of subplots. fig.suptitle(figure_title) if len(privacy_metrics) == 1: @@ -116,11 +119,12 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str, for legend_label in legend_labels: single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR] == legend_label] - axes[i].plot(single_label_results[x_axis_metric], - single_label_results[str(privacy_metric)]) - axes[i].legend(legend_labels) + sorted_label_results = single_label_results.sort_values(x_axis_metric) + axes[i].plot(sorted_label_results[x_axis_metric], + sorted_label_results[str(privacy_metric)]) + axes[i].legend(legend_labels, loc='lower right') axes[i].set_xlabel(x_axis_metric) - axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR) + axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR)) return fig diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py index a38383c..8b3bdf2 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py @@ -141,11 +141,11 @@ class PrivacyReportTest(absltest.TestCase): def test_plot_privacy_vs_accuracy_single_model_no_metadata(self): # Raise error if metadata is missing self.assertRaises( - ValueError, privacy_report.plot_privacy_vs_accuracy_single_model, + ValueError, privacy_report.plot_privacy_vs_accuracy, AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC']) def test_single_metric_plot_privacy_vs_accuracy_single_model(self): - fig = privacy_report.plot_privacy_vs_accuracy_single_model( + fig = privacy_report.plot_privacy_vs_accuracy( AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), ['AUC']) # extract data from figure. @@ -158,7 +158,7 @@ class PrivacyReportTest(absltest.TestCase): self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis') def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self): - fig = privacy_report.plot_privacy_vs_accuracy_single_model( + fig = privacy_report.plot_privacy_vs_accuracy( AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), ['AUC', 'Attacker advantage']) # extract data from figure. @@ -174,7 +174,7 @@ class PrivacyReportTest(absltest.TestCase): self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis') def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self): - fig = privacy_report.plot_privacy_vs_accuracy_single_model( + fig = privacy_report.plot_privacy_vs_accuracy( AttackResultsCollection((self.results_epoch_10, self.results_epoch_15, self.results_epoch_15_model_2)), ['AUC', 'Attacker advantage'])