Skip to content

Commit

Permalink
refactor: more robust parsing of operators
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 5, 2024
1 parent 88d93a1 commit 9f3b918
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def jl_array(x, dtype=None):
return jl_convert(jl.Array[dtype], x)


def jl_is_function(f):
return jl.seval("op -> op isa Function")(f)
def jl_is_function(f) -> bool:
return cast(bool, jl.seval("op -> op isa Function")(f))


def jl_serialize(obj: Any) -> NDArray[np.uint8]:
Expand Down
21 changes: 18 additions & 3 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -44,6 +44,7 @@
_load_cluster_manager,
jl_array,
jl_deserialize,
jl_is_function,
jl_serialize,
)
from .julia_import import SymbolicRegression, jl
Expand Down Expand Up @@ -1695,11 +1696,25 @@ def _run(
optimize=self.weight_optimize,
)

jl_binary_operators: list[Any] = []
jl_unary_operators: list[Any] = []
for input_list, output_list, name in [
(binary_operators, jl_binary_operators, "binary"),
(unary_operators, jl_unary_operators, "unary"),
]:
for op in input_list:
jl_op = jl.seval(op)
if not jl_is_function(jl_op):
raise ValueError(
f"When building `{name}_operators`, `'{op}'` did not return a Julia function"
)
output_list.append(jl_op)

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
options = SymbolicRegression.Options(
binary_operators=jl.seval(str(binary_operators).replace("'", "")),
unary_operators=jl.seval(str(unary_operators).replace("'", "")),
binary_operators=jl_array(jl_binary_operators, dtype=jl.Function),
unary_operators=jl_array(jl_unary_operators, dtype=jl.Function),
bin_constraints=jl_array(bin_constraints),
una_constraints=jl_array(una_constraints),
complexity_of_operators=complexity_of_operators,
Expand Down
10 changes: 10 additions & 0 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,16 @@ def test_load_model_simple(self):
)
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))

def test_jl_function_error(self):
# TODO: Move this to better class
with self.assertRaises(ValueError) as cm:
PySRRegressor(unary_operators=["1"]).fit([[1]], [1])

self.assertIn(
"When building `unary_operators`, `'1'` did not return a Julia function",
str(cm.exception),
)


def manually_create_model(equations, feature_names=None):
if feature_names is None:
Expand Down

0 comments on commit 9f3b918

Please sign in to comment.