forked from 626_privacy/tensorflow_privacy
Multiple small changes to the TF Privacy Report:
- Fix the legend to the bottom right - Manually set the size of the plot figure. - Fix a typo in the subplot title. PiperOrigin-RevId: 337064528
This commit is contained in:
parent
a8aa0d5d96
commit
d1a8a6cfda
6 changed files with 23 additions and 19 deletions
|
@ -20,7 +20,7 @@ the model are used (e.g., losses, logits, predictions). Neither model internals
|
||||||
|
|
||||||
### Codelab
|
### Codelab
|
||||||
|
|
||||||
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb).
|
The easiest way to get started is to go through [the introductory codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelabs/codelab.ipynb).
|
||||||
This trains a simple image classification model and tests it against a series
|
This trains a simple image classification model and tests it against a series
|
||||||
of membership inference attacks.
|
of membership inference attacks.
|
||||||
|
|
||||||
|
|
|
@ -609,8 +609,8 @@ def _get_attack_results_filename(attack_results: AttackResults, index: int):
|
||||||
"""Creates a filename for a specific set of AttackResults."""
|
"""Creates a filename for a specific set of AttackResults."""
|
||||||
metadata = attack_results.privacy_report_metadata
|
metadata = attack_results.privacy_report_metadata
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
|
return '%s_%s_epoch_%s.pickle' % (metadata.model_variant_label, index,
|
||||||
metadata.epoch_num, index)
|
metadata.epoch_num)
|
||||||
return '%s.pickle' % index
|
return '%s.pickle' % index
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -169,7 +169,7 @@ def main(unused_argv):
|
||||||
epoch_figure = privacy_report.plot_by_epochs(
|
epoch_figure = privacy_report.plot_by_epochs(
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
epoch_figure.show()
|
epoch_figure.show()
|
||||||
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
|
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
privacy_utility_figure.show()
|
privacy_utility_figure.show()
|
||||||
|
|
||||||
|
|
|
@ -32,10 +32,11 @@ TRAIN_ACCURACY_STR = 'Train accuracy'
|
||||||
|
|
||||||
def plot_by_epochs(results: AttackResultsCollection,
|
def plot_by_epochs(results: AttackResultsCollection,
|
||||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||||
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
"""Plots privacy vulnerabilities vs epoch numbers.
|
||||||
|
|
||||||
In case multiple privacy metrics are specified, the plot will feature
|
In case multiple privacy metrics are specified, the plot will feature
|
||||||
multiple subplots (one subplot per metrics).
|
multiple subplots (one subplot per metrics). Multiple model variants
|
||||||
|
are supported.
|
||||||
Args:
|
Args:
|
||||||
results: AttackResults for the plot
|
results: AttackResults for the plot
|
||||||
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||||
|
@ -54,12 +55,13 @@ def plot_by_epochs(results: AttackResultsCollection,
|
||||||
privacy_metrics=privacy_metrics)
|
privacy_metrics=privacy_metrics)
|
||||||
|
|
||||||
|
|
||||||
def plot_privacy_vs_accuracy_single_model(
|
def plot_privacy_vs_accuracy(results: AttackResultsCollection,
|
||||||
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
|
privacy_metrics: Iterable[PrivacyMetric]):
|
||||||
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
"""Plots privacy vulnerabilities vs accuracy plots.
|
||||||
|
|
||||||
In case multiple privacy metrics are specified, the plot will feature
|
In case multiple privacy metrics are specified, the plot will feature
|
||||||
multiple subplots (one subplot per metrics).
|
multiple subplots (one subplot per metrics). Multiple model variants
|
||||||
|
are supported.
|
||||||
Args:
|
Args:
|
||||||
results: AttackResults for the plot
|
results: AttackResults for the plot
|
||||||
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||||
|
@ -106,7 +108,8 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
||||||
figure_title: str,
|
figure_title: str,
|
||||||
privacy_metrics: Iterable[PrivacyMetric]):
|
privacy_metrics: Iterable[PrivacyMetric]):
|
||||||
"""Create one subplot per privacy metric for a specified x_axis_metric."""
|
"""Create one subplot per privacy metric for a specified x_axis_metric."""
|
||||||
fig, axes = plt.subplots(1, len(privacy_metrics))
|
fig, axes = plt.subplots(
|
||||||
|
1, len(privacy_metrics), figsize=(5 * len(privacy_metrics), 5))
|
||||||
# Set a title for the entire group of subplots.
|
# Set a title for the entire group of subplots.
|
||||||
fig.suptitle(figure_title)
|
fig.suptitle(figure_title)
|
||||||
if len(privacy_metrics) == 1:
|
if len(privacy_metrics) == 1:
|
||||||
|
@ -116,11 +119,12 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
||||||
for legend_label in legend_labels:
|
for legend_label in legend_labels:
|
||||||
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
|
single_label_results = all_results_df.loc[all_results_df[LEGEND_LABEL_STR]
|
||||||
== legend_label]
|
== legend_label]
|
||||||
axes[i].plot(single_label_results[x_axis_metric],
|
sorted_label_results = single_label_results.sort_values(x_axis_metric)
|
||||||
single_label_results[str(privacy_metric)])
|
axes[i].plot(sorted_label_results[x_axis_metric],
|
||||||
axes[i].legend(legend_labels)
|
sorted_label_results[str(privacy_metric)])
|
||||||
|
axes[i].legend(legend_labels, loc='lower right')
|
||||||
axes[i].set_xlabel(x_axis_metric)
|
axes[i].set_xlabel(x_axis_metric)
|
||||||
axes[i].set_title('%s for Entire dataset' % ENTIRE_DATASET_SLICE_STR)
|
axes[i].set_title('%s for %s' % (privacy_metric, ENTIRE_DATASET_SLICE_STR))
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
|
@ -141,11 +141,11 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
||||||
# Raise error if metadata is missing
|
# Raise error if metadata is missing
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
|
ValueError, privacy_report.plot_privacy_vs_accuracy,
|
||||||
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||||
|
|
||||||
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
['AUC'])
|
['AUC'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
|
@ -158,7 +158,7 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||||
|
|
||||||
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
['AUC', 'Attacker advantage'])
|
['AUC', 'Attacker advantage'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
|
@ -174,7 +174,7 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||||
|
|
||||||
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy(
|
||||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||||
self.results_epoch_15_model_2)),
|
self.results_epoch_15_model_2)),
|
||||||
['AUC', 'Attacker advantage'])
|
['AUC', 'Attacker advantage'])
|
||||||
|
|
Loading…
Reference in a new issue