Skip to content

Commit

Permalink
test: list-like variable complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 16, 2024
1 parent 71cda07 commit 171306f
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ def test_multioutput_custom_operator_quiet_custom_complexity(self):
self.assertLessEqual(mse1, 1e-4)
self.assertLessEqual(mse2, 1e-4)

def test_custom_variable_complexity(self):
y = self.X[:, [0, 1]] ** 2
model = PySRRegressor(
binary_operators=["*", "+"],
verbosity=0,
**self.default_test_kwargs,
early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7",
)
model.fit(
self.X,
y,
complexity_of_variables=[2, 3] + [100 for _ in range(self.X.shape[1] - 2)],
)
equations = model.equations_
self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)

self.assertEqual(model.get_best()[0]["complexity"], 5)
self.assertEqual(model.get_best()[1]["complexity"], 7)

def test_multioutput_weighted_with_callable_temp_equation(self):
X = self.X.copy()
y = X[:, [0, 1]] ** 2
Expand Down Expand Up @@ -1053,8 +1073,14 @@ def test_unit_checks(self):
"""This just checks the number of units passed"""
use_custom_variable_names = False
variable_names = None
complexity_of_variables = 1
weights = None
args = (use_custom_variable_names, variable_names, weights)
args = (
use_custom_variable_names,
variable_names,
complexity_of_variables,
weights,
)
valid_units = [
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
Expand Down

0 comments on commit 171306f

Please sign in to comment.