forked from 626_privacy/tensorflow_privacy
Multiple small changes to the TF Privacy Report:
- Fix the legend to the bottom right - Manually set the size of the plot figure. - Fix a typo in the subplot title. PiperOrigin-RevId: 337064528
This commit is contained in:
parent
a8aa0d5d96
commit
d1a8a6cfda
6 changed files with 23 additions and 19 deletions
|
@ -20,7 +20,7 @@ the model are used (e.g., losses, logits, predictions). Neither model internals
|
|||
|
||||
### Codelab
|
||||
|
||||
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb).
|
||||
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb).
|
||||
This trains a simple image classification model and tests it against a series
|
||||
of membership inference attacks.
|
||||
|
||||
|
|
|
@ -609,8 +609,8 @@ def _get_attack_results_filename(attack_results: AttackResults, index: int):
|
|||
"""Creates a filename for a specific set of AttackResults."""
|
||||
metadata = attack_results.privacy_report_metadata
|
||||
if metadata is not None:
|
||||
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
|
||||
metadata.epoch_num, index)
|
||||
return '%s_%s_epoch_%s.pickle' % (metadata.model_variant_label, index,
|
||||
metadata.epoch_num)
|
||||
return '%s.pickle' % index
|
||||
|
||||
|
||||
|
|
|
@ -169,7 +169,7 @@ def main(unused_argv):
|
|||
epoch_figure = privacy_report.plot_by_epochs(
|
||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||
epoch_figure.show()
|
||||
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
|
||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||
privacy_utility_figure.show()
|
||||
|
||||
|
|
|
@ -32,10 +32,11 @@ TRAIN_ACCURACY_STR = 'Train accuracy'
|
|||
|
||||
def plot_by_epochs(results: AttackResultsCollection,
|
||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
||||
"""Plots privacy vulnerabilities vs epoch numbers.
|
||||
|
||||
In case multiple privacy metrics are specified, the plot will feature
|
||||
multiple subplots (one subplot per metrics).
|
||||
multiple subplots (one subplot per metrics). Multiple model variants
|
||||
are supported.
|
||||
Args:
|
||||
results: AttackResults for the plot
|
||||
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||
|
@ -54,12 +55,13 @@ def plot_by_epochs(results: AttackResultsCollection,
|
|||
privacy_metrics=privacy_metrics)
|
||||
|
||||
|
||||
def plot_privacy_vs_accuracy_single_model(
|
||||
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
|
||||
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
||||
def plot_privacy_vs_accuracy(results: AttackResultsCollection,
|
||||
privacy_metrics: Iterable[PrivacyMetric]):
|
||||
"""Plots privacy vulnerabilities vs accuracy plots.
|
||||
|
||||
In case multiple privacy metrics are specified, the plot will feature
|
||||
multiple subplots (one subplot per metrics).
|
||||
multiple subplots (one subplot per metrics). Multiple model variants
|
||||
are supported.
|
||||
Args:
|
||||
results: AttackResults for the plot
|
||||
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||
|
@ -106,7 +108,8 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
|||
figure_title: str,
|
||||
privacy_metrics: Iterable[PrivacyMetric]):
|
||||
"""Create one subplot per privacy metric for a specified x_axis_metric."""
|
||||
fig, axes = plt.subplots(1, len(privacy_metrics))
|
||||
fig, axes = plt.subplots(
|
||||
1, len(privacy_metrics), figsize=(5 * len(privacy_metrics), 5))
|
||||
# Set a title for the entire group of subplots.
|
||||
fig.suptitle(figure_title)
|
||||
if len(privacy_metrics) == 1:
|
||||
|
@ -116,11 +119,12 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
|||
for legend_label in legend_labels:
|
||||
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
|
||||
== legend_label]
|
||||
axes[i].plot(single_label_results[x_axis_metric],
|
||||
single_label_results[str(privacy_metric)])
|
||||
axes[i].legend(legend_labels)
|
||||
sorted_label_results = single_label_results.sort_values(x_axis_metric)
|
||||
axes[i].plot(sorted_label_results[x_axis_metric],
|
||||
sorted_label_results[str(privacy_metric)])
|
||||
axes[i].legend(legend_labels, loc='lower right')
|
||||
axes[i].set_xlabel(x_axis_metric)
|
||||
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR)
|
||||
axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR))
|
||||
|
||||
return fig
|
||||
|
||||
|
|
|
@ -141,11 +141,11 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
||||
# Raise error if metadata is missing
|
||||
self.assertRaises(
|
||||
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
|
||||
ValueError, privacy_report.plot_privacy_vs_accuracy,
|
||||
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||
|
||||
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC'])
|
||||
# extract data from figure.
|
||||
|
@ -158,7 +158,7 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||
|
||||
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
# extract data from figure.
|
||||
|
@ -174,7 +174,7 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||
|
||||
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||
self.results_epoch_15_model_2)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
|
|
Loading…
Reference in a new issue