Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More extensive typing stubs and associated refactoring #609

Merged
merged 30 commits into master from cleanup
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
018c3a1
refactor: sympy typing issue
MilesCranmer Apr 28, 2024
4bcc280
refactor: create `ArrayLike` type for type checking
MilesCranmer Apr 28, 2024
cb76a81
refactor: declare julia as `Any` to avoid typing issues
MilesCranmer Apr 28, 2024
dca10d6
fix: potential issue with non-standard random states
MilesCranmer Apr 28, 2024
2bd7782
refactor: improved type inference in return values
MilesCranmer Apr 28, 2024
7909e90
refactor: more type declarations
MilesCranmer Apr 28, 2024
7113eed
style: use pandas indexing for return values
MilesCranmer Apr 28, 2024
583beaf
refactor: typings for sympy export
MilesCranmer Apr 28, 2024
c2379a1
build: add mypy to dev environment
MilesCranmer Apr 28, 2024
47ad683
refactor: NDArray to ndarray for string type
MilesCranmer Apr 28, 2024
04aa23c
style: move attribute types to after docstring
MilesCranmer Apr 28, 2024
a5eaab9
refactor: help with type inference of `get_best`
MilesCranmer Apr 28, 2024
526d334
fix: type inference issue in return value of get_best
MilesCranmer Apr 28, 2024
96e5a0f
fix: upper bound of randint for windows
MilesCranmer Apr 28, 2024
dafd19b
refactor: move more latex code to export_latex
MilesCranmer Apr 28, 2024
f2a280c
Merge tag 'v0.18.4' into cleanup
MilesCranmer May 5, 2024
b958ebf
test: add stubs to dev deps
MilesCranmer May 5, 2024
fd4c500
fix: variety of typing information
MilesCranmer May 5, 2024
cd925dd
build: remove typing_extensions
MilesCranmer May 5, 2024
9854909
fix: selection_mask to be bool array
MilesCranmer May 5, 2024
810bea9
test: more typing info
MilesCranmer May 5, 2024
483a9b8
ci: ignore lock files
MilesCranmer May 5, 2024
505af8d
fix: boolean selection masks in pandas eval
MilesCranmer May 5, 2024
530ae99
refactor: runtime parameters into dataclass
MilesCranmer May 5, 2024
88d93a1
feat: add helper for specifying dtype of jl Array
MilesCranmer May 5, 2024
9f3b918
refactor: more robust parsing of operators
MilesCranmer May 5, 2024
7021459
docs: separate runtime params
MilesCranmer May 5, 2024
db5f4d5
refactor: standardize constant
MilesCranmer May 5, 2024
76f0b3f
fix: type compat for 3.8
MilesCranmer May 5, 2024
291dc85
Merge branch 'master' into cleanup
MilesCranmer Jun 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ site
**/*.code-workspace
**/*.tar.gz
venv
requirements-dev.lock
requirements.lock
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ dependencies:
- scikit-learn>=1.0.0,<2.0.0
- pyjuliacall>=0.9.15,<0.10.0
- click>=7.0.0,<9.0.0
- typing_extensions>=4.0.0,<5.0.0
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ dev-dependencies = [
"pre-commit>=3.7.0",
"ipython>=8.23.0",
"ipykernel>=6.29.4",
"mypy>=1.10.0",
"jax[cpu]>=0.4.26",
"torch>=2.3.0",
"pandas-stubs>=2.2.1.240316",
"types-pytz>=2024.1.0.20240417",
"types-openpyxl>=3.1.0.20240428",
]
21 changes: 17 additions & 4 deletions pysr/denoising.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Functions for denoising data during preprocessing."""

from typing import Optional, Tuple, cast

import numpy as np
from numpy import ndarray


def denoise(X, y, Xresampled=None, random_state=None):
def denoise(
X: ndarray,
y: ndarray,
Xresampled: Optional[ndarray] = None,
random_state: Optional[np.random.RandomState] = None,
) -> Tuple[ndarray, ndarray]:
"""Denoise the dataset using a Gaussian process."""
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
Expand All @@ -15,12 +23,17 @@ def denoise(X, y, Xresampled=None, random_state=None):
gpr.fit(X, y)

if Xresampled is not None:
return Xresampled, gpr.predict(Xresampled)
return Xresampled, cast(ndarray, gpr.predict(Xresampled))

return X, gpr.predict(X)
return X, cast(ndarray, gpr.predict(X))


def multi_denoise(X, y, Xresampled=None, random_state=None):
def multi_denoise(
X: ndarray,
y: ndarray,
Xresampled: Optional[ndarray] = None,
random_state: Optional[np.random.RandomState] = None,
):
"""Perform `denoise` along each column of `y` independently."""
y = np.stack(
[
Expand Down
12 changes: 12 additions & 0 deletions pysr/export_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,15 @@ def sympy2multilatextable(
]

return "\n\n".join(latex_tables)


def with_preamble(table_string: str) -> str:
    r"""Prefix a LaTeX table with the package imports it relies on.

    Parameters
    ----------
    table_string : str
        LaTeX source of the table. It is appended verbatim after the
        preamble lines.

    Returns
    -------
    str
        Newline-joined text consisting of the ``\usepackage{breqn}`` and
        ``\usepackage{booktabs}`` lines, a blank line, a literal ``...``
        line, another blank line, and finally `table_string`.
    """
    # NOTE(review): the "..." line looks like a placeholder for the rest of
    # the user's document between the preamble and the table — confirm.
    preamble_lines = [
        r"\usepackage{breqn}",
        r"\usepackage{booktabs}",
        "",
        "...",
        "",
    ]
    return "\n".join([*preamble_lines, table_string])
12 changes: 10 additions & 2 deletions pysr/export_numpy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Code for exporting discovered expressions to numpy"""

import warnings
from typing import List, Union

import numpy as np
import pandas as pd
from sympy import lambdify
from numpy.typing import NDArray
from sympy import Expr, Symbol, lambdify


def sympy2numpy(eqn, sympy_symbols, *, selection=None):
Expand All @@ -14,6 +16,10 @@ def sympy2numpy(eqn, sympy_symbols, *, selection=None):
class CallableEquation:
"""Simple wrapper for numpy lambda functions built with sympy"""

_sympy: Expr
_sympy_symbols: List[Symbol]
_selection: Union[NDArray[np.bool_], None]

def __init__(self, eqn, sympy_symbols, selection=None):
self._sympy = eqn
self._sympy_symbols = sympy_symbols
Expand All @@ -29,15 +35,17 @@ def __call__(self, X):
return self._lambda(
**{k: X[k].values for k in map(str, self._sympy_symbols)}
) * np.ones(expected_shape)

if self._selection is not None:
if X.shape[1] != len(self._selection):
if X.shape[1] != self._selection.sum():
warnings.warn(
"`X` should be of shape (n_samples, len(self._selection)). "
"Automatically filtering `X` to selection. "
"Note: Filtered `X` column order may not match column order in fit "
"this may lead to incorrect predictions and other errors."
)
X = X[:, self._selection]

return self._lambda(*X.T) * np.ones(expected_shape)

@property
Expand Down
12 changes: 7 additions & 5 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sympy
from sympy import sympify

from .utils import ArrayLike

sympy_mappings = {
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
Expand All @@ -30,8 +32,8 @@
"acosh": lambda x: sympy.acosh(x),
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
Expand Down Expand Up @@ -60,21 +62,21 @@


def create_sympy_symbols_map(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
return {variable: sympy.Symbol(variable) for variable in feature_names_in}


def create_sympy_symbols(
feature_names_in: List[str],
feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
return [sympy.Symbol(variable) for variable in feature_names_in]


def pysr2sympy(
equation: str,
*,
feature_names_in: Optional[List[str]] = None,
feature_names_in: Optional[ArrayLike[str]] = None,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
if feature_names_in is None:
Expand Down
22 changes: 19 additions & 3 deletions pysr/feature_selection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
"""Functions for doing feature selection during preprocessing."""

from typing import Optional, cast

import numpy as np
from numpy import ndarray
from numpy.typing import NDArray

from .utils import ArrayLike


def run_feature_selection(X, y, select_k_features, random_state=None):
def run_feature_selection(
X: ndarray,
y: ndarray,
select_k_features: int,
random_state: Optional[np.random.RandomState] = None,
) -> NDArray[np.bool_]:
"""
Find most important features.

Expand All @@ -21,11 +32,16 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
selector = SelectFromModel(
clf, threshold=-np.inf, max_features=select_k_features, prefit=True
)
return selector.get_support(indices=True)
return cast(NDArray[np.bool_], selector.get_support(indices=False))


# This function is kept only because the module tests still use it.
def _handle_feature_selection(X, select_k_features, y, variable_names):
def _handle_feature_selection(
X: ndarray,
select_k_features: Optional[int],
y: ndarray,
variable_names: ArrayLike[str],
):
if select_k_features is not None:
selection = run_feature_selection(X, y, select_k_features)
print(f"Using features {[variable_names[i] for i in selection]}")
Expand Down
22 changes: 17 additions & 5 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Functions for initializing the Julia environment and installing deps."""

from typing import Any, Callable, Union, cast

import numpy as np
from juliacall import convert as jl_convert # type: ignore
from numpy.typing import NDArray

from .deprecated import init_julia, install
from .julia_import import jl

jl_convert = cast(Callable[[Any, Any], Any], jl_convert)

jl.seval("using Serialization: Serialization")
jl.seval("using PythonCall: PythonCall")

Expand All @@ -22,24 +27,31 @@ def _escape_filename(filename):
return str_repr


def _load_cluster_manager(cluster_manager):
def _load_cluster_manager(cluster_manager: str):
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
return jl.seval(f"addprocs_{cluster_manager}")


def jl_array(x):
def jl_array(x, dtype=None):
if x is None:
return None
return jl_convert(jl.Array, x)
elif dtype is None:
return jl_convert(jl.Array, x)
else:
return jl_convert(jl.Array[dtype], x)


def jl_is_function(f) -> bool:
    """Return whether `f` is a Julia ``Function``.

    The check is performed on the Julia side: the anonymous function
    ``op -> op isa Function`` is evaluated with `jl.seval` and then called
    with `f`.
    """
    is_function = jl.seval("op -> op isa Function")
    # `cast` only informs the type checker; it performs no runtime conversion,
    # so the value returned is whatever the Julia call hands back.
    return cast(bool, is_function(f))


def jl_serialize(obj):
def jl_serialize(obj: Any) -> NDArray[np.uint8]:
buf = jl.IOBuffer()
Serialization.serialize(buf, obj)
return np.array(jl.take_b(buf))


def jl_deserialize(s):
def jl_deserialize(s: Union[NDArray[np.uint8], None]):
if s is None:
return s
buf = jl.IOBuffer()
Expand Down
5 changes: 5 additions & 0 deletions pysr/julia_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import sys
import warnings
from types import ModuleType
from typing import cast

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
Expand Down Expand Up @@ -42,6 +44,9 @@

from juliacall import Main as jl # type: ignore

jl = cast(ModuleType, jl)


jl_version = (jl.VERSION.major, jl.VERSION.minor, jl.VERSION.patch)

jl.seval("using SymbolicRegression")
Expand Down
Loading
Loading