Skip to content

Commit

Permalink
test: simplify custom complexity test
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 16, 2024
1 parent 09617a6 commit 06a77ea
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,26 +179,25 @@ def test_multioutput_custom_operator_quiet_custom_complexity(self):

def test_custom_variable_complexity(self):
for case in (1, 2):
y = self.X[:, [0, 1]] ** 2
y = self.X[:, [0, 1]]
model = PySRRegressor(
binary_operators=["*", "+"],
binary_operators=["+"],
verbosity=0,
**self.default_test_kwargs,
early_stop_condition=f"stop_if(l, c) = l < 1e-8 && c <= {7 if case == 1 else 5}",
early_stop_condition=f"stop_if_{case}(l, c) = l < 1e-8 && c <= {3 if case == 1 else 2}",
)
if case == 1:
complexity_of_variables = [2, 3] + [
100 for _ in range(self.X.shape[1] - 2)
]
complexity_of_variables = [2, 3]
elif case == 2:
complexity_of_variables = 2
model.fit(self.X, y, complexity_of_variables=complexity_of_variables)
equations = model.equations_
self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
model.fit(
self.X[:, [0, 1]], y, complexity_of_variables=complexity_of_variables
)
self.assertLessEqual(model.get_best()[0]["loss"], 1e-8)
self.assertLessEqual(model.get_best()[1]["loss"], 1e-8)

self.assertEqual(model.get_best()[0]["complexity"], 5)
self.assertEqual(model.get_best()[1]["complexity"], 7 if case == 1 else 5)
self.assertEqual(model.get_best()[0]["complexity"], 2)
self.assertEqual(model.get_best()[1]["complexity"], 3 if case == 1 else 2)

def test_multioutput_weighted_with_callable_temp_equation(self):
X = self.X.copy()
Expand Down

0 comments on commit 06a77ea

Please sign in to comment.