From 8d147bc9d7193dd45aecd5f434f0005088b0fcc7 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Wed, 22 Dec 2021 11:08:17 -0800 Subject: [PATCH] For MIA plotting, allow customized plotting function and set equal x and y aspects. PiperOrigin-RevId: 417852309 --- .../privacy_tests/membership_inference_attack/plotting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py index dbdc49d..ed801ca 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py @@ -57,6 +57,7 @@ def plot_curve_with_area(x: Iterable[float], plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}') plt.xlabel(xlabel) plt.ylabel(ylabel) + plt.gca().set_aspect('equal', adjustable='box') plt.legend() return fig @@ -80,7 +81,7 @@ def plot_histograms(train: Iterable[float], return fig -def plot_roc_curve(roc_curve) -> plt.Figure: +def plot_roc_curve(roc_curve, plot_func=plot_curve_with_area) -> plt.Figure: """Plot the ROC curve and the area under the curve.""" - return plot_curve_with_area( + return plot_func( roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')