From 680aaa44990e9f78834e733f1c6f3db8b1dee696 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Aug 2020 02:28:48 -0700 Subject: [PATCH] Simple ROC curve plotting for membership inference attack results. PiperOrigin-RevId: 325982344 --- .../privacy/membership_inference_attack/example.py | 11 +++++++++++ .../privacy/membership_inference_attack/plotting.py | 6 ++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 804a997..eb32c6d 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -19,6 +19,8 @@ This is using a toy model based on classifying four spacial clusters of data. """ import os import tempfile + +import matplotlib.pyplot as plt import numpy as np from tensorflow import keras from tensorflow.keras import layers @@ -28,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec +import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting def generate_random_cluster(center, scale, num_points): @@ -147,3 +150,11 @@ print(attack_results.summary(by_slices=False)) print("Summary by slices: \n") print(attack_results.summary(by_slices=True)) + +# Example of ROC curve plotting. +figure = plotting.plot_roc_curve( + attack_results.single_attack_results[0].roc_curve) +plt.show() + +# For saving a figure into a file: +# plotting.save_plot(figure, ) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/plotting.py b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py index ad415c2..f3eab8f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/plotting.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/plotting.py @@ -78,3 +78,9 @@ def plot_histograms(train: Iterable[float], plt.ylabel('normalized counts (density)') plt.legend() return fig + + +def plot_roc_curve(roc_curve) -> plt.Figure: + """Plot the ROC curve and the area under the curve.""" + return plot_curve_with_area( + roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')