Simple ROC curve plotting for membership inference attack results.

PiperOrigin-RevId: 325982344
This commit is contained in:
A. Unique TensorFlower 2020-08-11 02:28:48 -07:00
parent 99afaed68e
commit 680aaa4499
2 changed files with 17 additions and 0 deletions

View file

@ -19,6 +19,8 @@ This is using a toy model based on classifying four spacial clusters of data.
""" """
import os import os
import tempfile import tempfile
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from tensorflow import keras from tensorflow import keras
from tensorflow.keras import layers from tensorflow.keras import layers
@ -28,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
def generate_random_cluster(center, scale, num_points): def generate_random_cluster(center, scale, num_points):
@ -147,3 +150,11 @@ print(attack_results.summary(by_slices=False))
print("Summary by slices: \n") print("Summary by slices: \n")
print(attack_results.summary(by_slices=True)) print(attack_results.summary(by_slices=True))
# Example of ROC curve plotting.
figure = plotting.plot_roc_curve(
attack_results.single_attack_results[0].roc_curve)
plt.show()
# For saving a figure into a file:
# plotting.save_plot(figure, <file_path>)

View file

@ -78,3 +78,9 @@ def plot_histograms(train: Iterable[float],
plt.ylabel('normalized counts (density)') plt.ylabel('normalized counts (density)')
plt.legend() plt.legend()
return fig return fig
def plot_roc_curve(roc_curve) -> plt.Figure:
"""Plot the ROC curve and the area under the curve."""
return plot_curve_with_area(
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')