Skip to content

Commit

Permalink
refactor: clean up use of __init__ in keyword suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 5, 2024
1 parent ea7ced4 commit 9207687
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
21 changes: 11 additions & 10 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def __init__(
FutureWarning,
)
else:
suggested_keywords = self._suggest_keywords(k)
suggested_keywords = self._suggest_keywords(PySRRegressor, k)
err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
if len(suggested_keywords) > 0:
err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
Expand Down Expand Up @@ -1995,15 +1995,6 @@ def fit(

return self

def _suggest_keywords(self, k: str) -> List[str]:
valid_keywords = [
param
for param in inspect.signature(self.__init__).parameters
if param not in ["self", "kwargs"]
]
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
return suggestions

def refresh(self, checkpoint_file=None) -> None:
"""
Update self.equations_ with any new options passed.
Expand Down Expand Up @@ -2455,6 +2446,16 @@ def latex_table(
return with_preamble(table_string)


def _suggest_keywords(cls, k: str) -> List[str]:
valid_keywords = [
param
for param in inspect.signature(cls.__init__).parameters
if param not in ["self", "kwargs"]
]
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
return suggestions


def idx_model_selection(equations: pd.DataFrame, model_selection: str):
"""Select an expression and return its index."""
if model_selection == "accuracy":
Expand Down
12 changes: 9 additions & 3 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from ..export_latex import sympy2latex
from ..feature_selection import _handle_feature_selection, run_feature_selection
from ..julia_helpers import init_julia
from ..sr import _check_assertions, _process_constraints, idx_model_selection
from ..sr import (
_check_assertions,
_process_constraints,
_suggest_keywords,
idx_model_selection,
)
from ..utils import _csv_filename_to_pkl_filename
from .params import (
DEFAULT_NCYCLES,
Expand Down Expand Up @@ -805,9 +810,10 @@ def test_bad_kwargs(self):
print("Failed", opt["kwargs"])

def test_suggest_keywords(self):
model = PySRRegressor()
# Easy
self.assertEqual(model._suggest_keywords("loss_function"), ["loss_function"])
self.assertEqual(
_suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
)

# More complex, and with error
with self.assertRaises(TypeError) as cm:
Expand Down

0 comments on commit 9207687

Please sign in to comment.