forked from 626_privacy/tensorflow_privacy
Simple ROC curve plotting for membership inference attack results.
PiperOrigin-RevId: 325982344
This commit is contained in:
parent
99afaed68e
commit
680aaa4499
2 changed files with 17 additions and 0 deletions
|
@ -19,6 +19,8 @@ This is using a toy model based on classifying four spacial clusters of data.
|
|||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
|
@ -28,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
|
||||
|
||||
|
||||
def generate_random_cluster(center, scale, num_points):
|
||||
|
@ -147,3 +150,11 @@ print(attack_results.summary(by_slices=False))
|
|||
|
||||
print("Summary by slices: \n")
|
||||
print(attack_results.summary(by_slices=True))
|
||||
|
||||
# Example of ROC curve plotting.
|
||||
figure = plotting.plot_roc_curve(
|
||||
attack_results.single_attack_results[0].roc_curve)
|
||||
plt.show()
|
||||
|
||||
# For saving a figure into a file:
|
||||
# plotting.save_plot(figure, <file_path>)
|
||||
|
|
|
@ -78,3 +78,9 @@ def plot_histograms(train: Iterable[float],
|
|||
plt.ylabel('normalized counts (density)')
|
||||
plt.legend()
|
||||
return fig
|
||||
|
||||
|
||||
def plot_roc_curve(roc_curve) -> plt.Figure:
|
||||
"""Plot the ROC curve and the area under the curve."""
|
||||
return plot_curve_with_area(
|
||||
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')
|
||||
|
|
Loading…
Reference in a new issue