diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py index dbdc49d..ed801ca 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/plotting.py @@ -57,6 +57,7 @@ def plot_curve_with_area(x: Iterable[float], plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}') plt.xlabel(xlabel) plt.ylabel(ylabel) + plt.gca().set_aspect('equal', adjustable='box') plt.legend() return fig @@ -80,7 +81,7 @@ def plot_histograms(train: Iterable[float], return fig -def plot_roc_curve(roc_curve) -> plt.Figure: +def plot_roc_curve(roc_curve, plot_func=plot_curve_with_area) -> plt.Figure: """Plot the ROC curve and the area under the curve.""" - return plot_curve_with_area( + return plot_func( roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')