diff --git a/pyproject.toml b/pyproject.toml index 8946355d..cd3603a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,10 @@ classifiers = [ ] readme = "README.rst" dependencies = [ + "jax>=0.4,<0.5", "scikit-learn>=1.1, !=1.5.0", "derivative>=0.6.2", + "typing_extensions", ] [project.optional-dependencies] @@ -63,7 +65,6 @@ cvxpy = [ ] sbr = [ "numpyro", - "jax", "arviz==0.17.1", "scipy<1.13.0" ] diff --git a/pysindy/_typing.py b/pysindy/_typing.py index 1ec0c446..d8eec0f7 100644 --- a/pysindy/_typing.py +++ b/pysindy/_typing.py @@ -4,4 +4,11 @@ # In python 3.12, use type statement # https://docs.python.org/3/reference/simple_stmts.html#the-type-statement NpFlt = np.floating[npt.NBitBase] -Float2D = np.ndarray[tuple[int, int], np.dtype[NpFlt]] +FloatDType = np.dtype[np.floating[npt.NBitBase]] +Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] +Float1D = np.ndarray[tuple[int], FloatDType] +Float2D = np.ndarray[tuple[int, int], FloatDType] +Float3D = np.ndarray[tuple[int, int, int], FloatDType] +Float4D = np.ndarray[tuple[int, int, int, int], FloatDType] +Float5D = np.ndarray[tuple[int, int, int, int, int], FloatDType] +FloatND = npt.NDArray[NpFlt] diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 54697da4..8f24dc9a 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -8,6 +8,7 @@ from typing import Optional from typing import Sequence +import jax import numpy as np from scipy import sparse from sklearn.base import TransformerMixin @@ -144,19 +145,28 @@ def x_sequence_or_item(wrapped_func): @wraps(wrapped_func) def func(self, x, *args, **kwargs): if isinstance(x, Sequence): - xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] + if isinstance(x[0], jax.Array): + xs = x + else: + xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] result = wrapped_func(self, xs, *args, **kwargs) - if isinstance(result, Sequence): # e.g. transform() returns x + # if transform() is a normal "return x" + if isinstance(result, Sequence) and isinstance(result[0], np.ndarray): return [AxesArray(xp, comprehend_axes(xp)) for xp in result] return result # e.g. fit() returns self else: - if not sparse.issparse(x): + if isinstance(x, jax.Array): + + def reconstructor(x): + return x + + elif not sparse.issparse(x) and isinstance(x, np.ndarray): x = AxesArray(x, comprehend_axes(x)) def reconstructor(x): return x - else: # sparse arrays + else: # sparse reconstructor = type(x) axes = comprehend_axes(x) wrap_axes(axes, x) diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index fda52b77..496841e8 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -12,6 +12,7 @@ from ..utils import AxesArray from ..utils import comprehend_axes from ..utils import wrap_axes +from ..utils._axis_conventions import AX_COORD from .base import BaseFeatureLibrary from .base import x_sequence_or_item @@ -160,7 +161,7 @@ def get_feature_names(self, input_features=None): return feature_names @x_sequence_or_item - def fit(self, x_full, y=None): + def fit(self, x_full: list[AxesArray], y=None): """ Compute number of output features. @@ -180,7 +181,7 @@ def fit(self, x_full, y=None): "Can't have include_interaction be False and interaction_only" " be True" ) - n_features = x_full[0].shape[x_full[0].ax_coord] + n_features = x_full[0].shape[AX_COORD] combinations = self._combinations( n_features, self.degree, @@ -217,7 +218,7 @@ def transform(self, x_full): axes = comprehend_axes(x) x = x.asformat("csc") wrap_axes(axes, x) - n_features = x.shape[x.ax_coord] + n_features = x.shape[AX_COORD] if n_features != self.n_features_in_: raise ValueError("x shape does not match training shape") diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py index ea4eef3f..2193952a 100644 --- a/pysindy/optimizers/base.py +++ b/pysindy/optimizers/base.py @@ -4,10 +4,13 @@ import abc import warnings from typing import Callable +from typing import NewType +from typing import Optional from typing import Tuple import numpy as np from scipy import sparse +from sklearn.base import BaseEstimator from sklearn.linear_model import LinearRegression from sklearn.linear_model._base import _preprocess_data from sklearn.utils.extmath import safe_sparse_dot @@ -15,9 +18,13 @@ from sklearn.utils.validation import check_X_y from .._typing import Float2D +from .._typing import FloatDType from ..utils import AxesArray from ..utils import drop_nan_samples +NFeat = NewType("NFeat", int) +NTarget = NewType("NTarget", int) + def _rescale_data(X, y, sample_weight): """Rescale data so as to support sample_weight""" @@ -32,14 +39,17 @@ def _rescale_data(X, y, sample_weight): return X, y -class ComplexityMixin: +class _BaseOptimizer(BaseEstimator, abc.ABC): + coef_: np.ndarray[tuple[NTarget, NFeat], FloatDType] + intercept_: np.ndarray[tuple[NTarget], FloatDType] + @property def complexity(self): check_is_fitted(self) return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) -class BaseOptimizer(LinearRegression, ComplexityMixin): +class BaseOptimizer(LinearRegression, _BaseOptimizer): """ Base class for SINDy optimizers. Subclasses must implement a _reduce method for carrying out the bulk of the work of @@ -89,6 +99,12 @@ class BaseOptimizer(LinearRegression, ComplexityMixin): """ + max_iter: int + normalize_columns: bool + initial_guess: Optional[np.ndarray[tuple[NTarget, NFeat], FloatDType]] + copy_X: bool + unbias: bool + def __init__( self, max_iter=20, diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index b4554d43..a5712634 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -5,33 +5,29 @@ from itertools import repeat from math import comb from typing import cast -from typing import NewType from typing import Optional from typing import TypeVar from typing import Union import cvxpy as cp import numpy as np -from numpy.typing import NBitBase from numpy.typing import NDArray from sklearn.exceptions import ConvergenceWarning +from .._typing import Float1D +from .._typing import Float2D +from .._typing import Float3D +from .._typing import Float4D +from .._typing import Float5D +from .._typing import Int1D from ..feature_library.polynomial_library import n_poly_features from ..feature_library.polynomial_library import PolynomialLibrary from ..utils import reorder_constraints +from .base import FloatDType +from .base import NFeat +from .base import NTarget from .constrained_sr3 import ConstrainedSR3 -AnyFloat = np.dtype[np.floating[NBitBase]] -Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] -Float1D = np.ndarray[tuple[int], AnyFloat] -Float2D = np.ndarray[tuple[int, int], AnyFloat] -Float3D = np.ndarray[tuple[int, int, int], AnyFloat] -Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat] -Float5D = np.ndarray[tuple[int, int, int, int, int], AnyFloat] -FloatND = NDArray[np.floating[NBitBase]] -NFeat = NewType("NFeat", int) -NTarget = NewType("NTarget", int) - class EnstrophyMat: """Pre-compute some useful factors of an enstrophy matrix @@ -601,7 +597,7 @@ def _solve_m_relax_and_split( self, trap_ctr: Float1D, prev_A: Float2D, - coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat], + coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType], ) -> tuple[Float1D, Float2D]: r"""Updates the trap center @@ -693,7 +689,7 @@ def _reduce(self, x, y): self.constraint_lhs = reorder_constraints( self.constraint_lhs, n_features, output_order="feature" ) - coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat] = self.coef_.T + coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType] = self.coef_.T # Print initial values for each term in the optimization if self.verbose: diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 431a34fc..5c982aa8 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -1,7 +1,10 @@ import sys import warnings +from abc import ABC +from abc import abstractmethod from itertools import product from typing import Collection +from typing import Optional from typing import Sequence from typing import Union @@ -13,9 +16,12 @@ from sklearn.metrics import r2_score from sklearn.pipeline import Pipeline from sklearn.utils.validation import check_is_fitted +from typing_extensions import Self +from .differentiation import BaseDifferentiation from .differentiation import FiniteDifference from .feature_library import PolynomialLibrary +from .feature_library.base import BaseFeatureLibrary try: # Waiting on PEP 690 to lazy import CVXPY from .optimizers import SINDyPI @@ -24,18 +30,119 @@ except ImportError: sindy_pi_flag = False from .optimizers import STLSQ +from .optimizers.base import _BaseOptimizer +from .optimizers.base import BaseOptimizer from .utils import AxesArray from .utils import comprehend_axes from .utils import concat_sample_axis from .utils import drop_nan_samples -from .utils import equations from .utils import SampleConcatter from .utils import validate_control_variables from .utils import validate_input from .utils import validate_no_reshape -class SINDy(BaseEstimator): +class _BaseSINDy(BaseEstimator, ABC): + + feature_library: BaseFeatureLibrary + optimizer: _BaseOptimizer + discrete_time: bool + model: Pipeline + feature_names: Optional[list[str]] + # Hacks to remove later + discrete_time: bool = False + n_control_features_: int = 0 + + @abstractmethod + def fit(self, x, t, *args, **kwargs) -> Self: + ... + + def _fit_shape(self): + """Assign shape attributes for the system that are used post-fit""" + self.n_features_in_ = self.feature_library.n_features_in_ + self.n_output_features_ = self.feature_library.n_output_features_ + if self.feature_names is None: + feature_names = [] + for i in range(self.n_features_in_ - self.n_control_features_): + feature_names.append("x" + str(i)) + for i in range(self.n_control_features_): + feature_names.append("u" + str(i)) + self.feature_names = feature_names + + def equations(self, precision: int = 3) -> list[str]: + """ + Get the right hand sides of the SINDy model equations. + + Parameters + ---------- + precision: int, optional (default 3) + Number of decimal points to include for each coefficient in the + equation. + + Returns + ------- + equations: list of strings + List of strings representing the SINDy model equations for each + input feature. + """ + check_is_fitted(self, "model") + if self.discrete_time: + sys_coord_names = [name + "[k]" for name in self.feature_names] + else: + sys_coord_names = self.feature_names + feat_names = self.feature_library.get_feature_names(sys_coord_names) + + def term(c, name): + rounded_coef = np.round(c, precision) + if rounded_coef == 0: + return "" + else: + return f"{c:.{precision}f} {name}" + + equations = [] + for coef_row in self.optimizer.coef_: + components = [term(c, i) for c, i in zip(coef_row, feat_names)] + eq = " + ".join(filter(bool, components)) + if not eq: + eq = f"{0:.{precision}f}" + equations.append(eq) + + return equations + + def print(self, precision: int = 3, **kwargs) -> None: + """Print the SINDy model equations. + + Parameters + ---------- + lhs: list of strings, optional (default None) + List of variables to print on the left-hand sides of the learned equations. + By default :code:`self.input_features` are used. + + precision: int, optional (default 3) + Precision to be used when printing out model coefficients. + + **kwargs: Additional keyword arguments passed to the builtin print function + """ + eqns = self.equations(precision) + for name, eqn in zip(self.feature_names, eqns, strict=True): + lhs = f"({name})'" + print(f"{lhs} = {eqn}", **kwargs) + + def get_feature_names(self): + """ + Get a list of names of features used by SINDy model. + + Returns + ------- + feats: list + A list of strings giving the names of the features in the feature + library, :code:`self.feature_library`. + """ + check_is_fitted(self, "model") + return self.feature_library.get_feature_names(input_features=self.feature_names) + + +class SINDy(_BaseSINDy): """ Sparse Identification of Nonlinear Dynamical Systems (SINDy). Uses sparse regression to learn a dynamical systems model from measurement data. @@ -150,12 +257,12 @@ class SINDy(BaseEstimator): def __init__( self, - optimizer=None, - feature_library=None, - differentiation_method=None, - feature_names=None, - t_default=1, - discrete_time=False, + optimizer: Optional[BaseOptimizer] = None, + feature_library: Optional[BaseFeatureLibrary] = None, + differentiation_method: Optional[BaseDifferentiation] = None, + feature_names: Optional[list[str]] = None, + t_default: float = 1, + discrete_time: bool = False, ): if optimizer is None: optimizer = STLSQ() @@ -257,17 +364,7 @@ def fit( x_dot = concat_sample_axis(x_dot) self.model = Pipeline(steps) self.model.fit(x, x_dot) - - self.n_features_in_ = self.feature_library.n_features_in_ - self.n_output_features_ = self.feature_library.n_output_features_ - - if self.feature_names is None: - feature_names = [] - for i in range(self.n_features_in_ - self.n_control_features_): - feature_names.append("x" + str(i)) - for i in range(self.n_control_features_): - feature_names.append("u" + str(i)) - self.feature_names = feature_names + self._fit_shape() return self @@ -324,33 +421,6 @@ def predict(self, x, u=None): return result[0] return result - def equations(self, precision=3): - """ - Get the right hand sides of the SINDy model equations. - - Parameters - ---------- - precision: int, optional (default 3) - Number of decimal points to include for each coefficient in the - equation. - - Returns - ------- - equations: list of strings - List of strings representing the SINDy model equations for each - input feature. - """ - check_is_fitted(self, "model") - if self.discrete_time: - base_feature_names = [f + "[k]" for f in self.feature_names] - else: - base_feature_names = self.feature_names - return equations( - self.model, - input_features=base_feature_names, - precision=precision, - ) - def print(self, lhs=None, precision=3, **kwargs): """Print the SINDy model equations. @@ -552,19 +622,6 @@ def coefficients(self): check_is_fitted(self, "model") return self.optimizer.coef_ - def get_feature_names(self): - """ - Get a list of names of features used by SINDy model. - - Returns - ------- - feats: list - A list of strings giving the names of the features in the feature - library, :code:`self.feature_library`. - """ - check_is_fitted(self, "model") - return self.feature_library.get_feature_names(input_features=self.feature_names) - def simulate( self, x0, diff --git a/pysindy/utils/__init__.py b/pysindy/utils/__init__.py index 29e9b905..5fa1f111 100644 --- a/pysindy/utils/__init__.py +++ b/pysindy/utils/__init__.py @@ -5,11 +5,9 @@ from .axes import wrap_axes from .base import capped_simplex_projection from .base import drop_nan_samples -from .base import equations from .base import flatten_2d_tall from .base import get_prox from .base import get_regularization -from .base import print_model from .base import reorder_constraints from .base import supports_multiple_targets from .base import validate_control_variables @@ -55,10 +53,8 @@ "comprehend_axes", "capped_simplex_projection", "drop_nan_samples", - "equations", "get_prox", "get_regularization", - "print_model", "reorder_constraints", "supports_multiple_targets", "validate_control_variables", diff --git a/pysindy/utils/_axis_conventions.py b/pysindy/utils/_axis_conventions.py new file mode 100644 index 00000000..98a7c582 --- /dev/null +++ b/pysindy/utils/_axis_conventions.py @@ -0,0 +1,2 @@ +AX_TIME = -2 +AX_COORD = -1 diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 3d5c9ee3..95cbf40a 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -1,5 +1,4 @@ import warnings -from itertools import repeat from typing import Callable from typing import Sequence from typing import Union @@ -62,7 +61,7 @@ def validate_input(x, t=T_DEFAULT): return x_new -def validate_no_reshape(x, t=T_DEFAULT): +def validate_no_reshape(x, t: Union[float, np.ndarray, object] = T_DEFAULT): """Check types and numerical sensibility of arguments. Args: @@ -73,7 +72,7 @@ def validate_no_reshape(x, t=T_DEFAULT): x as 2D array, with time dimension on first axis and coordinate index on second axis. """ - if not isinstance(x, np.ndarray): + if not hasattr(x, "shape"): raise TypeError("Input value must be array-like") check_array(x, ensure_2d=False, allow_nd=True) @@ -85,7 +84,7 @@ def validate_no_reshape(x, t=T_DEFAULT): if t <= 0: raise ValueError("t must be positive") # Only apply these tests if t is array-like - elif isinstance(t, np.ndarray): + elif hasattr(t, "shape"): if not len(t) == x.shape[-2]: raise ValueError("Length of t should match x.shape[-2].") if not np.all(t[:-1] < t[1:]): @@ -290,69 +289,6 @@ def f(x): return np.maximum(np.minimum(trimming_array - x, 1.0), 0.0) -def print_model( - coef, - input_features, - errors=None, - intercept=None, - error_intercept=None, - precision=3, - pm="±", -): - """ - Args: - coef: - input_features: - errors: - intercept: - sigma_intercept: - precision: - pm: - Returns: - """ - - def term(c, sigma, name): - rounded_coef = np.round(c, precision) - if rounded_coef == 0 and sigma is None: - return "" - elif sigma is None: - return f"{c:.{precision}f} {name}" - elif rounded_coef == 0 and np.round(sigma, precision) == 0: - return "" - else: - return f"({c:.{precision}f} {pm} {sigma:.{precision}f}) {name}" - - errors = errors if errors is not None else repeat(None) - components = [term(c, e, i) for c, e, i in zip(coef, errors, input_features)] - eq = " + ".join(filter(bool, components)) - - if not eq or intercept or error_intercept is not None: - intercept = intercept or 0 - intercept_str = term(intercept, error_intercept, "").strip() - if eq and intercept_str: - eq += " + " - eq += intercept_str - elif not eq: - eq = f"{intercept:.{precision}f}" - return eq - - -def equations(pipeline, input_features=None, precision=3, input_fmt=None): - input_features = pipeline.steps[0][1].get_feature_names(input_features) - if input_fmt: - input_features = [input_fmt(i) for i in input_features] - coef = pipeline.steps[-1][1].coef_ - intercept = pipeline.steps[-1][1].intercept_ - if np.isscalar(intercept): - intercept = intercept * np.ones(coef.shape[0]) - return [ - print_model( - coef[i], input_features, intercept=intercept[i], precision=precision - ) - for i in range(coef.shape[0]) - ] - - def supports_multiple_targets(estimator): """Checks whether estimator supports multiple targets.""" if isinstance(estimator, MultiOutputMixin): diff --git a/test/conftest.py b/test/conftest.py index 59fb61f7..ead11adf 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ """ from pathlib import Path +import jax import numpy as np import pytest from scipy.integrate import solve_ivp @@ -20,6 +21,8 @@ from pysindy.utils.odes import lorenz from pysindy.utils.odes import lorenz_control +jax.config.update("jax_platform_name", "cpu") + def pytest_addoption(parser): parser.addoption(