Multiple small changes to the TF Privacy Report:

- Fix the legend to the bottom right
     - Manually set the size of the plot figure.
     - Fix a typo in the subplot title.

PiperOrigin-RevId: 337064528
This commit is contained in:
David Marn 2020-10-14 04:41:09 -07:00 committed by A. Unique TensorFlower
parent a8aa0d5d96
commit d1a8a6cfda
6 changed files with 23 additions and 19 deletions

View file

@ -20,7 +20,7 @@ the model are used (e.g., losses, logits, predictions). Neither model internals
### Codelab ### Codelab
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb). The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb).
This trains a simple image classification model and tests it against a series This trains a simple image classification model and tests it against a series
of membership inference attacks. of membership inference attacks.

View file

@ -609,8 +609,8 @@ def _get_attack_results_filename(attack_results: AttackResults, index: int):
"""Creates a filename for a specific set of AttackResults.""" """Creates a filename for a specific set of AttackResults."""
metadata = attack_results.privacy_report_metadata metadata = attack_results.privacy_report_metadata
if metadata is not None: if metadata is not None:
return '%s_%s_%s.pickle' % (metadata.model_variant_label, return '%s_%s_epoch_%s.pickle' % (metadata.model_variant_label, index,
metadata.epoch_num, index) metadata.epoch_num)
return '%s.pickle' % index return '%s.pickle' % index

View file

@ -169,7 +169,7 @@ def main(unused_argv):
epoch_figure = privacy_report.plot_by_epochs( epoch_figure = privacy_report.plot_by_epochs(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_figure.show() epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model( privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
privacy_utility_figure.show() privacy_utility_figure.show()

View file

@ -32,10 +32,11 @@ TRAIN_ACCURACY_STR = 'Train accuracy'
def plot_by_epochs(results: AttackResultsCollection, def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure: privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant. """Plots privacy vulnerabilities vs epoch numbers.
In case multiple privacy metrics are specified, the plot will feature In case multiple privacy metrics are specified, the plot will feature
multiple subplots (one subplot per metrics). multiple subplots (one subplot per metrics). Multiple model variants
are supported.
Args: Args:
results: AttackResults for the plot results: AttackResults for the plot
privacy_metrics: List of enumerated privacy metrics that should be plotted. privacy_metrics: List of enumerated privacy metrics that should be plotted.
@ -54,12 +55,13 @@ def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics=privacy_metrics) privacy_metrics=privacy_metrics)
def plot_privacy_vs_accuracy_single_model( def plot_privacy_vs_accuracy(results: AttackResultsCollection,
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]): privacy_metrics: Iterable[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant. """Plots privacy vulnerabilities vs accuracy plots.
In case multiple privacy metrics are specified, the plot will feature In case multiple privacy metrics are specified, the plot will feature
multiple subplots (one subplot per metrics). multiple subplots (one subplot per metrics). Multiple model variants
are supported.
Args: Args:
results: AttackResults for the plot results: AttackResults for the plot
privacy_metrics: List of enumerated privacy metrics that should be plotted. privacy_metrics: List of enumerated privacy metrics that should be plotted.
@ -106,7 +108,8 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
figure_title: str, figure_title: str,
privacy_metrics: Iterable[PrivacyMetric]): privacy_metrics: Iterable[PrivacyMetric]):
"""Create one subplot per privacy metric for a specified x_axis_metric.""" """Create one subplot per privacy metric for a specified x_axis_metric."""
fig, axes = plt.subplots(1, len(privacy_metrics)) fig, axes = plt.subplots(
1, len(privacy_metrics), figsize=(5 * len(privacy_metrics), 5))
# Set a title for the entire group of subplots. # Set a title for the entire group of subplots.
fig.suptitle(figure_title) fig.suptitle(figure_title)
if len(privacy_metrics) == 1: if len(privacy_metrics) == 1:
@ -116,11 +119,12 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
for legend_label in legend_labels: for legend_label in legend_labels:
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR] single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
== legend_label] == legend_label]
axes[i].plot(single_label_results[x_axis_metric], sorted_label_results = single_label_results.sort_values(x_axis_metric)
single_label_results[str(privacy_metric)]) axes[i].plot(sorted_label_results[x_axis_metric],
axes[i].legend(legend_labels) sorted_label_results[str(privacy_metric)])
axes[i].legend(legend_labels, loc='lower right')
axes[i].set_xlabel(x_axis_metric) axes[i].set_xlabel(x_axis_metric)
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR) axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR))
return fig return fig

View file

@ -141,11 +141,11 @@ class PrivacyReportTest(absltest.TestCase):
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self): def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
# Raise error if metadata is missing # Raise error if metadata is missing
self.assertRaises( self.assertRaises(
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model, ValueError, privacy_report.plot_privacy_vs_accuracy,
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC']) AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
def test_single_metric_plot_privacy_vs_accuracy_single_model(self): def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC']) ['AUC'])
# extract data from figure. # extract data from figure.
@ -158,7 +158,7 @@ class PrivacyReportTest(absltest.TestCase):
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis') self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self): def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage']) ['AUC', 'Attacker advantage'])
# extract data from figure. # extract data from figure.
@ -174,7 +174,7 @@ class PrivacyReportTest(absltest.TestCase):
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis') self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self): def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15, AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2)), self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage']) ['AUC', 'Attacker advantage'])