diff --git a/pyproject.toml b/pyproject.toml index 3e90468..727493b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.2.2" +version = "5.2.3" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 2f61cd2..3685b4f 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -222,9 +222,11 @@ def __init__( ): if dataframe.empty: raise ValueError("Training data is inadequate.") + dataframe["seq_len"] = dataframe["tensor"].map(len) + dataframe = dataframe[dataframe["seq_len"] <= max_seq_len] if sort_by_length: - dataframe = dataframe.sort_values(by=["i"]) - dataframe = dataframe[dataframe["tensor"].map(len) <= max_seq_len] + dataframe = dataframe.sort_values(by="seq_len") + del dataframe["seq_len"] self.x_train = pad_sequence( dataframe["tensor"].to_list(), batch_first=True, padding_value=0 )