Skip to content

Commit

Permalink
fix: raise warnings for invalid number of splits in test train data s…
Browse files Browse the repository at this point in the history
…plitting
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent ad14dc0 commit bcef273
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/invert4geom/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,18 @@ def split_test_train(
if method == "LeaveOneOut":
kfold = sklearn.model_selection.LeaveOneOut()
elif method == "KFold":
if n_splits > len(df):
msg = (
"n_splits must be less than or equal to the number of data points, "
"decreasing n_splits"
)
log.warning(msg)
n_splits = len(df)

if n_splits == 1:
msg = "n_splits must be greater than 1"
raise ValueError(msg)

if spacing or shape is None:
kfold = sklearn.model_selection.KFold(
n_splits=n_splits,
Expand Down

0 comments on commit bcef273

Please sign in to comment.