Skip to content

Commit

Permalink
feat: return fig for plot_cv_scores
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent 155f645 commit 70e4a0d
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def plot_cv_scores(
figsize: tuple[float, float] = (5, 3.5),
plot_title: str | None = None,
fname: str | None = None,
) -> None:
) -> typing.Any:
"""
plot a graph of cross-validation scores vs hyperparameter values
Expand All @@ -229,6 +229,10 @@ def plot_cv_scores(
title of figure, by default None
fname : str | None, optional
filename to save figure, by default None
Returns
-------
a matplotlib figure instance
"""

sns.set_theme()
Expand All @@ -238,33 +242,35 @@ def plot_cv_scores(

best = df.scores.argmin()

plt.figure(figsize=figsize)
fig, ax = plt.subplots(figsize=figsize)
if plot_title is not None:
plt.title(plot_title)
ax.set_title(plot_title)
else:
plt.title(f"{param_name} Cross-validation")
plt.plot(df.parameters, df.scores, marker="o")
plt.plot(
ax.set_title(f"{param_name} Cross-validation")
ax.plot(df.parameters, df.scores, marker="o")
ax.plot(
df.parameters.iloc[best],
df.scores.iloc[best],
"s",
markersize=10,
color=sns.color_palette()[3],
label="Minimum",
)
plt.legend(loc="best")
ax.legend(loc="best")
if logx:
plt.xscale("log")
ax.set_xscale("log")
if logy:
plt.yscale("log")
plt.xlabel(f"{param_name} value")
plt.ylabel("Root Mean Square Error")
ax.set_yscale("log")
ax.set_xlabel(f"{param_name} value")
ax.set_ylabel("Root Mean Square Error")

plt.tight_layout()

if fname is not None:
plt.savefig(fname)

return fig


def plot_convergence(
results: pd.DataFrame,
Expand Down

0 comments on commit 70e4a0d

Please sign in to comment.