forked from 626_privacy/tensorflow_privacy
Simple ROC curve plotting for membership inference attack results.
PiperOrigin-RevId: 325982344
This commit is contained in:
parent
99afaed68e
commit
680aaa4499
2 changed files with 17 additions and 0 deletions
|
@ -19,6 +19,8 @@ This is using a toy model based on classifying four spacial clusters of data.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
|
@ -28,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
|
||||||
|
|
||||||
|
|
||||||
def generate_random_cluster(center, scale, num_points):
|
def generate_random_cluster(center, scale, num_points):
|
||||||
|
@ -147,3 +150,11 @@ print(attack_results.summary(by_slices=False))
|
||||||
|
|
||||||
print("Summary by slices: \n")
|
print("Summary by slices: \n")
|
||||||
print(attack_results.summary(by_slices=True))
|
print(attack_results.summary(by_slices=True))
|
||||||
|
|
||||||
|
# Example of ROC curve plotting.
|
||||||
|
figure = plotting.plot_roc_curve(
|
||||||
|
attack_results.single_attack_results[0].roc_curve)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# For saving a figure into a file:
|
||||||
|
# plotting.save_plot(figure, <file_path>)
|
||||||
|
|
|
@ -78,3 +78,9 @@ def plot_histograms(train: Iterable[float],
|
||||||
plt.ylabel('normalized counts (density)')
|
plt.ylabel('normalized counts (density)')
|
||||||
plt.legend()
|
plt.legend()
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curve(roc_curve) -> plt.Figure:
|
||||||
|
"""Plot the ROC curve and the area under the curve."""
|
||||||
|
return plot_curve_with_area(
|
||||||
|
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')
|
||||||
|
|
Loading…
Reference in a new issue