For MIA plotting, allow customized plotting function and set equal x and y aspects.

PiperOrigin-RevId: 417852309
This commit is contained in:
Shuang Song 2021-12-22 11:08:17 -08:00 committed by A. Unique TensorFlower
parent c6576f60c4
commit 8d147bc9d7

View file

@ -57,6 +57,7 @@ def plot_curve_with_area(x: Iterable[float],
plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}')
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.gca().set_aspect('equal', adjustable='box')
plt.legend()
return fig
@ -80,7 +81,7 @@ def plot_histograms(train: Iterable[float],
return fig
def plot_roc_curve(roc_curve) -> plt.Figure:
def plot_roc_curve(roc_curve, plot_func=plot_curve_with_area) -> plt.Figure:
"""Plot the ROC curve and the area under the curve."""
return plot_curve_with_area(
return plot_func(
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')