Merge pull request #658 from MilesCranmer/fix-number-symbol
BREAKING: Disable automatic sympy simplification
MilesCranmer authored Jun 22, 2024
2 parents 06ca0e3 + 36e5dde commit 7fc7b82
Showing 7 changed files with 90 additions and 16 deletions.
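In practical terms, equations are now converted to SymPy with evaluate=False, so the expression structure produced by the search is preserved: repeated multiplications are no longer collapsed into powers (see the LaTeX-table test changes below), and sub-expressions such as exp(sign(0.448)) are not folded into exact constants (see test_avoid_simplification). A minimal sketch of the new behaviour, assuming a recent SymPy; the example equation is illustrative and not part of this commit:

# Hedged sketch: pysr2sympy now keeps the equation structure as written.
from pysr.export_sympy import pysr2sympy

ex = pysr2sympy("x0 * x0 * x1", feature_names_in=["x0", "x1"])
print(ex)  # stays close to x0*x0*x1 rather than being simplified to x0**2*x1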
Dockerfile: 4 changes (2 additions, 2 deletions)
@@ -1,8 +1,8 @@
# This builds a dockerfile containing a working copy of PySR
# with all pre-requisites installed.

-ARG JLVERSION=1.9.4
-ARG PYVERSION=3.11.6
+ARG JLVERSION=1.10.4
+ARG PYVERSION=3.12.2
ARG BASE_IMAGE=bullseye

FROM julia:${JLVERSION}-${BASE_IMAGE} AS jl
pysr/export_jax.py: 4 changes (3 additions, 1 deletion)
@@ -55,7 +55,9 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
    if issubclass(expr.func, sympy.Float):
        parameters.append(float(expr))
        return f"parameters[{len(parameters) - 1}]"
-    elif issubclass(expr.func, sympy.Rational):
+    elif issubclass(expr.func, sympy.Rational) or issubclass(
+        expr.func, sympy.NumberSymbol
+    ):
        return f"{float(expr)}"
    elif issubclass(expr.func, sympy.Integer):
        return f"{int(expr)}"
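To illustrate the effect of the NumberSymbol branch above, here is a small hedged sketch (not part of this commit): an exact constant such as sympy.pi now lowers to a plain float literal in the generated JAX code, so it needs no extra operator mapping.

# Hedged sketch: exact constants like pi become float literals in the JAX export.
import numpy as np
import sympy
from jax import numpy as jnp
from pysr import sympy2jax

x1 = sympy.symbols("x1")
expr = sympy.pi + x1                      # Pi is a sympy.NumberSymbol subclass
f, params = sympy2jax(expr, [x1])
X = np.random.randn(10, 1)
out = np.array(f(jnp.array(X), params))
np.testing.assert_allclose(out, np.pi + X[:, 0], atol=1e-4)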
pysr/export_sympy.py: 7 changes (6 additions, 1 deletion)
@@ -87,7 +87,12 @@ def pysr2sympy(
        **sympy_mappings,
    }

-    return sympify(equation, locals=local_sympy_mappings)
+    try:
+        return sympify(equation, locals=local_sympy_mappings, evaluate=False)
+    except TypeError as e:
+        if "got an unexpected keyword argument 'evaluate'" in str(e):
+            return sympify(equation, locals=local_sympy_mappings)
+        raise TypeError(f"Error processing equation '{equation}'") from e


def assert_valid_sympy_symbol(var_name: str) -> None:
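As a quick illustration of why evaluate=False is passed, a hedged sketch (not from this commit; the printed forms assume a recent SymPy where parse-time evaluation of these functions is suppressed):

# Hedged sketch: default sympify collapses the expression, evaluate=False keeps it.
import sympy

eq = "exp(sign(0.44796443)) + 1.5*x1"
print(sympy.sympify(eq))                  # sign(...) -> 1, exp(1) -> E, so roughly: 1.5*x1 + E
print(sympy.sympify(eq, evaluate=False))  # structure preserved: exp(sign(0.44796443)) + 1.5*x1

The try/except in the diff falls back to a plain sympify call on SymPy versions whose sympify does not accept the evaluate keyword.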
pysr/export_torch.py: 5 changes (5 additions, 0 deletions)
@@ -116,6 +116,11 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
                self._value = int(expr)
                self._torch_func = lambda: self._value
                self._args = ()
+            elif issubclass(expr.func, sympy.NumberSymbol):
+                # Can get here from exp(1) or exact pi
+                self._value = float(expr)
+                self._torch_func = lambda: self._value
+                self._args = ()
            elif issubclass(expr.func, sympy.Symbol):
                self._name = expr.name
                self._torch_func = lambda value: value
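The comment in the new branch refers to values like sympy.exp(1), which auto-evaluates to the exact constant E; a small hedged sketch (not from this commit) of how such a value is detected and lowered:

# Hedged sketch: E's head is a NumberSymbol subclass, so it maps to a plain float.
import sympy

expr = sympy.exp(1)                               # auto-evaluates to E
print(issubclass(expr.func, sympy.NumberSymbol))  # True
print(float(expr))                                # 2.718281828459045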
pysr/test/test.py: 6 changes (3 additions, 3 deletions)
@@ -674,7 +674,7 @@ def test_pickle_with_temp_equation_file(self):
        pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])

        y_predictions2 = model2.predict(X)
-        np.testing.assert_array_equal(y_predictions, y_predictions2)
+        np.testing.assert_array_almost_equal(y_predictions, y_predictions2)

    def test_scikit_learn_compatibility(self):
        """Test PySRRegressor compatibility with scikit-learn."""
@@ -1039,7 +1039,7 @@ def test_multi_output(self):
middle_part_2 = r"""
$y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
$y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
-$y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
+$y_{1} = x_{0} x_{0} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
"""
true_latex_table_str = "\n\n".join(
self.create_true_latex(part, include_score=True)
@@ -1092,7 +1092,7 @@ def test_latex_break_long_equation(self):
middle_part = r"""
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
-\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
+\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0} x_{0} x_{0} + x_{0} x_{0} x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + 5.20 \sin{\left(- 2.60 x_{0} + 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
"""
true_latex_table_str = (
TRUE_PREAMBLE
pysr/test/test_jax.py: 44 changes (36 additions, 8 deletions)
@@ -5,27 +5,29 @@
import pandas as pd
import sympy

+import pysr
from pysr import PySRRegressor, sympy2jax


class TestJAX(unittest.TestCase):
    def setUp(self):
        np.random.seed(0)
+        from jax import numpy as jnp
+
+        self.jnp = jnp

    def test_sympy2jax(self):
-        from jax import numpy as jnp
        from jax import random

        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
-        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
+        true = 1.0 * self.jnp.cos(X[:, 0]) + X[:, 1]
        f, params = sympy2jax(cosx, [x, y, z])
-        self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
+        self.assertTrue(self.jnp.all(self.jnp.isclose(f(X, params), true)).item())

    def test_pipeline_pandas(self):
-        from jax import numpy as jnp

        X = pd.DataFrame(np.random.randn(100, 10))
        y = np.ones(X.shape[0])
@@ -52,14 +54,12 @@ def test_pipeline_pandas(self):
        jformat = model.jax()

        np.testing.assert_almost_equal(
-            np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
+            np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
            np.square(np.cos(X.values[:, 1])), # Select feature 1
            decimal=3,
        )

    def test_pipeline(self):
-        from jax import numpy as jnp
-
        X = np.random.randn(100, 10)
        y = np.ones(X.shape[0])
        model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
@@ -81,11 +81,39 @@ def test_pipeline(self):
        jformat = model.jax()

        np.testing.assert_almost_equal(
-            np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
+            np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
            np.square(np.cos(X[:, 1])), # Select feature 1
            decimal=3,
        )

+    def test_avoid_simplification(self):
+        ex = pysr.export_sympy.pysr2sympy(
+            "square(exp(sign(0.44796443))) + 1.5 * x1",
+            feature_names_in=["x1"],
+            extra_sympy_mappings={"square": lambda x: x**2},
+        )
+        f, params = pysr.export_jax.sympy2jax(ex, [sympy.symbols("x1")])
+        key = np.random.RandomState(0)
+        X = key.randn(10, 1)
+        np.testing.assert_almost_equal(
+            np.array(f(self.jnp.array(X), params)),
+            np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
+            decimal=3,
+        )
+
+    def test_issue_656(self):
+        import sympy
+
+        E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
+        f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
+        key = np.random.RandomState(0)
+        X = key.randn(10, 1)
+        np.testing.assert_almost_equal(
+            np.array(f(self.jnp.array(X), params)),
+            np.exp(1) + X[:, 0],
+            decimal=3,
+        )
+
    def test_feature_selection_custom_operators(self):
        rstate = np.random.RandomState(0)
        X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
pysr/test/test_torch.py: 36 changes (35 additions, 1 deletion)
@@ -4,6 +4,7 @@
import pandas as pd
import sympy

+import pysr
from pysr import PySRRegressor, sympy2torch


@@ -153,10 +154,43 @@ def test_custom_operator(self):
            decimal=3,
        )

+    def test_avoid_simplification(self):
+        # SymPy should not simplify without permission
+        torch = self.torch
+        ex = pysr.export_sympy.pysr2sympy(
+            "square(exp(sign(0.44796443))) + 1.5 * x1",
+            # ^ Normally this would become exp1 and require
+            # its own mapping
+            feature_names_in=["x1"],
+            extra_sympy_mappings={"square": lambda x: x**2},
+        )
+        m = pysr.export_torch.sympy2torch(ex, ["x1"])
+        rng = np.random.RandomState(0)
+        X = rng.randn(10, 1)
+        np.testing.assert_almost_equal(
+            m(torch.tensor(X)).detach().numpy(),
+            np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
+            decimal=3,
+        )
+
+    def test_issue_656(self):
+        # Should correctly map numeric symbols to floats
+        E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
+        m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
+        X = np.random.randn(10, 1)
+        np.testing.assert_almost_equal(
+            m(self.torch.tensor(X)).detach().numpy(),
+            np.exp(1) + X[:, 0],
+            decimal=3,
+        )
+
    def test_feature_selection_custom_operators(self):
        rstate = np.random.RandomState(0)
        X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
-        cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
+
+        def cos_approx(x):
+            return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
+
        y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])

        model = PySRRegressor(
