diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 804a997..eb32c6d 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -19,6 +19,8 @@ This is using a toy model based on classifying four spacial clusters of data. """ import os import tempfile + +import matplotlib.pyplot as plt import numpy as np from tensorflow import keras from tensorflow.keras import layers @@ -28,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting def generate_random_cluster(center, scale, num_points): @@ -147,3 +150,11 @@ print(attack_results.summary(by_slices=False)) print("Summary by slices: \n") print(attack_results.summary(by_slices=True)) + +# Example of ROC curve plotting. +figure = plotting.plot_roc_curve( + attack_results.single_attack_results[0].roc_curve) +plt.show() + +# For saving a figure into a file: +# plotting.save_plot(figure, ) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/plotting.py b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py index ad415c2..f3eab8f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/plotting.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py @@ -78,3 +78,9 @@ def plot_histograms(train: Iterable[float], plt.ylabel('normalized counts (density)') plt.legend() return fig + + +def plot_roc_curve(roc_curve) -> plt.Figure: + """Plot the ROC curve and the area under the curve.""" + return plot_curve_with_area( + roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')