For MIA plotting, allow customized plotting function and set equal x and y aspects.
PiperOrigin-RevId: 417852309
This commit is contained in:
parent
c6576f60c4
commit
8d147bc9d7
1 changed files with 3 additions and 2 deletions
|
@ -57,6 +57,7 @@ def plot_curve_with_area(x: Iterable[float],
|
||||||
plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}')
|
plt.plot(x, y, lw=2, label=f'AUC: {metrics.auc(x, y):.3f}')
|
||||||
plt.xlabel(xlabel)
|
plt.xlabel(xlabel)
|
||||||
plt.ylabel(ylabel)
|
plt.ylabel(ylabel)
|
||||||
|
plt.gca().set_aspect('equal', adjustable='box')
|
||||||
plt.legend()
|
plt.legend()
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
@ -80,7 +81,7 @@ def plot_histograms(train: Iterable[float],
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def plot_roc_curve(roc_curve) -> plt.Figure:
|
def plot_roc_curve(roc_curve, plot_func=plot_curve_with_area) -> plt.Figure:
|
||||||
"""Plot the ROC curve and the area under the curve."""
|
"""Plot the ROC curve and the area under the curve."""
|
||||||
return plot_curve_with_area(
|
return plot_func(
|
||||||
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')
|
roc_curve.fpr, roc_curve.tpr, xlabel='FPR', ylabel='TPR')
|
||||||
|
|
Loading…
Reference in a new issue