From 79b5515d4ecbf2ef2a960cac8dba089c8b8c8178 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:53:10 -0700 Subject: [PATCH 1/9] feat: Create base class for all SINDy problems (except SINDyPI) This will eventually enable SSSindy.print, as well as a simpler discrete and SINDyPI class --- pysindy/feature_library/__init__.py | 2 + pysindy/pysindy.py | 125 ++++++++++++++++++---------- 2 files changed, 85 insertions(+), 42 deletions(-) diff --git a/pysindy/feature_library/__init__.py b/pysindy/feature_library/__init__.py index 47471956..41c100ee 100644 --- a/pysindy/feature_library/__init__.py +++ b/pysindy/feature_library/__init__.py @@ -1,3 +1,4 @@ +from .base import BaseFeatureLibrary from .base import ConcatLibrary from .base import TensoredLibrary from .custom_library import CustomLibrary @@ -11,6 +12,7 @@ from .weak_pde_library import WeakPDELibrary __all__ = [ + "BaseFeatureLibrary", "ConcatLibrary", "TensoredLibrary", "GeneralizedLibrary", diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 431a34fc..4e792fd5 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -1,9 +1,11 @@ +from abc import ABC, abstractmethod import sys import warnings from itertools import product from typing import Collection from typing import Sequence from typing import Union +from typing import Optional import numpy as np from scipy.integrate import odeint @@ -13,9 +15,10 @@ from sklearn.metrics import r2_score from sklearn.pipeline import Pipeline from sklearn.utils.validation import check_is_fitted +from typing_extensions import Self from .differentiation import FiniteDifference -from .feature_library import PolynomialLibrary +from .feature_library import PolynomialLibrary, BaseFeatureLibrary try: # Waiting on PEP 690 to lazy import CVXPY from .optimizers import SINDyPI @@ -23,7 +26,7 @@ sindy_pi_flag = True except ImportError: sindy_pi_flag = False -from .optimizers import STLSQ +from .optimizers import STLSQ, BaseOptimizer from .utils import AxesArray from .utils import comprehend_axes from .utils import concat_sample_axis @@ -35,7 +38,84 @@ from .utils import validate_no_reshape -class SINDy(BaseEstimator): +class _BaseSINDy(BaseEstimator, ABC): + + feature_library: BaseFeatureLibrary + optimizer: BaseOptimizer + discrete_time: bool + model: Pipeline + feature_names: Optional[list[str]] + + @abstractmethod + def fit(self, x, t, *args, **kwargs) -> Self: + ... + + + def equations(self, precision: int) -> list[str]: + """ + Get the right hand sides of the SINDy model equations. + + Parameters + ---------- + precision: int, optional (default 3) + Number of decimal points to include for each coefficient in the + equation. + + Returns + ------- + equations: list of strings + List of strings representing the SINDy model equations for each + input feature. + """ + check_is_fitted(self, "model") + if self.discrete_time: + base_feature_names = [name + "[k]" for name in self.feature_names] + else: + base_feature_names = self.feature_names + return equations( + self.model, + input_features=base_feature_names, + precision=precision, + ) + + + def print(self, precision: int, **kwargs) -> None: + """Print the SINDy model equations. + + Parameters + ---------- + lhs: list of strings, optional (default None) + List of variables to print on the left-hand sides of the learned equations. + By default :code:`self.input_features` are used. + + precision: int, optional (default 3) + Precision to be used when printing out model coefficients. + + **kwargs: Additional keyword arguments passed to the builtin print function + """ + eqns = self.equations(precision) + for name, eqn in zip(self.feature_names, eqns, strict=True): + if self.discrete_time: + lhs = f"({name})[k+1]" + else: + lhs = f"({name})'" + print(f"{lhs} = {eqn}", **kwargs) + + def get_feature_names(self): + """ + Get a list of names of features used by SINDy model. + + Returns + ------- + feats: list + A list of strings giving the names of the features in the feature + library, :code:`self.feature_library`. + """ + check_is_fitted(self, "model") + return self.feature_library.get_feature_names(input_features=self.feature_names) + + +class SINDy(_BaseSINDy): """ Sparse Identification of Nonlinear Dynamical Systems (SINDy). Uses sparse regression to learn a dynamical systems model from measurement data. @@ -324,32 +404,6 @@ def predict(self, x, u=None): return result[0] return result - def equations(self, precision=3): - """ - Get the right hand sides of the SINDy model equations. - - Parameters - ---------- - precision: int, optional (default 3) - Number of decimal points to include for each coefficient in the - equation. - - Returns - ------- - equations: list of strings - List of strings representing the SINDy model equations for each - input feature. - """ - check_is_fitted(self, "model") - if self.discrete_time: - base_feature_names = [f + "[k]" for f in self.feature_names] - else: - base_feature_names = self.feature_names - return equations( - self.model, - input_features=base_feature_names, - precision=precision, - ) def print(self, lhs=None, precision=3, **kwargs): """Print the SINDy model equations. @@ -552,19 +606,6 @@ def coefficients(self): check_is_fitted(self, "model") return self.optimizer.coef_ - def get_feature_names(self): - """ - Get a list of names of features used by SINDy model. - - Returns - ------- - feats: list - A list of strings giving the names of the features in the feature - library, :code:`self.feature_library`. - """ - check_is_fitted(self, "model") - return self.feature_library.get_feature_names(input_features=self.feature_names) - def simulate( self, x0, From 809eb61291cdab358df0a19e9c8e7a26c8393904 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:07:17 -0700 Subject: [PATCH 2/9] cln: Move SINDy.print() and equations() into the BaseSINDy SINDy.equations(), equations(), print_model() were written with extra generality that never ended up being used. This commit removes such generality and throws away these indirected helper functions in favor of a single method. --- pysindy/pysindy.py | 43 ++++++++++++++++++----------- pysindy/utils/__init__.py | 2 -- pysindy/utils/base.py | 57 +++++---------------------------------- 3 files changed, 33 insertions(+), 69 deletions(-) diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 4e792fd5..b1208426 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -1,11 +1,12 @@ -from abc import ABC, abstractmethod import sys import warnings +from abc import ABC +from abc import abstractmethod from itertools import product from typing import Collection +from typing import Optional from typing import Sequence from typing import Union -from typing import Optional import numpy as np from scipy.integrate import odeint @@ -18,7 +19,8 @@ from typing_extensions import Self from .differentiation import FiniteDifference -from .feature_library import PolynomialLibrary, BaseFeatureLibrary +from .feature_library import BaseFeatureLibrary +from .feature_library import PolynomialLibrary try: # Waiting on PEP 690 to lazy import CVXPY from .optimizers import SINDyPI @@ -31,7 +33,6 @@ from .utils import comprehend_axes from .utils import concat_sample_axis from .utils import drop_nan_samples -from .utils import equations from .utils import SampleConcatter from .utils import validate_control_variables from .utils import validate_input @@ -39,7 +40,7 @@ class _BaseSINDy(BaseEstimator, ABC): - + feature_library: BaseFeatureLibrary optimizer: BaseOptimizer discrete_time: bool @@ -50,8 +51,7 @@ class _BaseSINDy(BaseEstimator, ABC): def fit(self, x, t, *args, **kwargs) -> Self: ... - - def equations(self, precision: int) -> list[str]: + def equations(self, precision: int = 3) -> list[str]: """ Get the right hand sides of the SINDy model equations. @@ -69,17 +69,29 @@ def equations(self, precision: int) -> list[str]: """ check_is_fitted(self, "model") if self.discrete_time: - base_feature_names = [name + "[k]" for name in self.feature_names] + sys_coord_names = [name + "[k]" for name in self.feature_names] else: - base_feature_names = self.feature_names - return equations( - self.model, - input_features=base_feature_names, - precision=precision, - ) + sys_coord_names = self.feature_names + feat_names = self.feature_library.get_feature_names(sys_coord_names) + def term(c, name): + rounded_coef = np.round(c, precision) + if rounded_coef == 0: + return "" + else: + return f"{c:.{precision}f} {name}" + + equations = [] + for coef_row in self.optimizer.coef_: + components = [term(c, i) for c, i in zip(coef_row, feat_names)] + eq = " + ".join(filter(bool, components)) + if not eq: + eq = f"{0:.{precision}f}" + equations.append(eq) - def print(self, precision: int, **kwargs) -> None: + return equations + + def print(self, precision: int = 3, **kwargs) -> None: """Print the SINDy model equations. Parameters @@ -404,7 +416,6 @@ def predict(self, x, u=None): return result[0] return result - def print(self, lhs=None, precision=3, **kwargs): """Print the SINDy model equations. diff --git a/pysindy/utils/__init__.py b/pysindy/utils/__init__.py index 29e9b905..fea2db62 100644 --- a/pysindy/utils/__init__.py +++ b/pysindy/utils/__init__.py @@ -5,7 +5,6 @@ from .axes import wrap_axes from .base import capped_simplex_projection from .base import drop_nan_samples -from .base import equations from .base import flatten_2d_tall from .base import get_prox from .base import get_regularization @@ -55,7 +54,6 @@ "comprehend_axes", "capped_simplex_projection", "drop_nan_samples", - "equations", "get_prox", "get_regularization", "print_model", diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 3d5c9ee3..592a296a 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -1,5 +1,4 @@ import warnings -from itertools import repeat from typing import Callable from typing import Sequence from typing import Union @@ -293,66 +292,22 @@ def f(x): def print_model( coef, input_features, - errors=None, - intercept=None, - error_intercept=None, precision=3, - pm="±", ): - """ - Args: - coef: - input_features: - errors: - intercept: - sigma_intercept: - precision: - pm: - Returns: - """ - - def term(c, sigma, name): + def term(c, name): rounded_coef = np.round(c, precision) - if rounded_coef == 0 and sigma is None: - return "" - elif sigma is None: - return f"{c:.{precision}f} {name}" - elif rounded_coef == 0 and np.round(sigma, precision) == 0: + if rounded_coef == 0: return "" else: - return f"({c:.{precision}f} {pm} {sigma:.{precision}f}) {name}" + return f"{c:.{precision}f} {name}" - errors = errors if errors is not None else repeat(None) - components = [term(c, e, i) for c, e, i in zip(coef, errors, input_features)] + components = [term(c, i) for c, i in zip(coef, input_features)] eq = " + ".join(filter(bool, components)) - - if not eq or intercept or error_intercept is not None: - intercept = intercept or 0 - intercept_str = term(intercept, error_intercept, "").strip() - if eq and intercept_str: - eq += " + " - eq += intercept_str - elif not eq: - eq = f"{intercept:.{precision}f}" + if not eq: + eq = f"{0:.{precision}f}" return eq -def equations(pipeline, input_features=None, precision=3, input_fmt=None): - input_features = pipeline.steps[0][1].get_feature_names(input_features) - if input_fmt: - input_features = [input_fmt(i) for i in input_features] - coef = pipeline.steps[-1][1].coef_ - intercept = pipeline.steps[-1][1].intercept_ - if np.isscalar(intercept): - intercept = intercept * np.ones(coef.shape[0]) - return [ - print_model( - coef[i], input_features, intercept=intercept[i], precision=precision - ) - for i in range(coef.shape[0]) - ] - - def supports_multiple_targets(estimator): """Checks whether estimator supports multiple targets.""" if isinstance(estimator, MultiOutputMixin): From 02221e04ca72ca536f018ff45e7306d623176733 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:50:54 -0700 Subject: [PATCH 3/9] cln: remove discrete special case from BaseSINDy.print() (still in SINDy.print()) --- pysindy/pysindy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index b1208426..e3b8de0b 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -46,6 +46,9 @@ class _BaseSINDy(BaseEstimator, ABC): discrete_time: bool model: Pipeline feature_names: Optional[list[str]] + # Hacks to remove later + discrete_time: bool = False + n_control_features_: int = 0 @abstractmethod def fit(self, x, t, *args, **kwargs) -> Self: @@ -107,10 +110,7 @@ def print(self, precision: int = 3, **kwargs) -> None: """ eqns = self.equations(precision) for name, eqn in zip(self.feature_names, eqns, strict=True): - if self.discrete_time: - lhs = f"({name})[k+1]" - else: - lhs = f"({name})'" + lhs = f"({name})'" print(f"{lhs} = {eqn}", **kwargs) def get_feature_names(self): From 92834842a2654da9b92763dccae2b17314c02426 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:05:58 -0700 Subject: [PATCH 4/9] cln: extract common fit functionality to base --- pysindy/pysindy.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index e3b8de0b..35d9cda0 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -18,6 +18,7 @@ from sklearn.utils.validation import check_is_fitted from typing_extensions import Self +from .differentiation import BaseDifferentiation from .differentiation import FiniteDifference from .feature_library import BaseFeatureLibrary from .feature_library import PolynomialLibrary @@ -54,6 +55,18 @@ class _BaseSINDy(BaseEstimator, ABC): def fit(self, x, t, *args, **kwargs) -> Self: ... + def _fit_shape(self): + """Assign shape attributes for the system that are used post-fit""" + self.n_features_in_ = self.feature_library.n_features_in_ + self.n_output_features_ = self.feature_library.n_output_features_ + if self.feature_names is None: + feature_names = [] + for i in range(self.n_features_in_ - self.n_control_features_): + feature_names.append("x" + str(i)) + for i in range(self.n_control_features_): + feature_names.append("u" + str(i)) + self.feature_names = feature_names + def equations(self, precision: int = 3) -> list[str]: """ Get the right hand sides of the SINDy model equations. @@ -242,12 +255,12 @@ class SINDy(_BaseSINDy): def __init__( self, - optimizer=None, - feature_library=None, - differentiation_method=None, - feature_names=None, - t_default=1, - discrete_time=False, + optimizer: Optional[BaseOptimizer] = None, + feature_library: Optional[BaseFeatureLibrary] = None, + differentiation_method: Optional[BaseDifferentiation] = None, + feature_names: Optional[list[str]] = None, + t_default: float = 1, + discrete_time: bool = False, ): if optimizer is None: optimizer = STLSQ() @@ -349,17 +362,7 @@ def fit( x_dot = concat_sample_axis(x_dot) self.model = Pipeline(steps) self.model.fit(x, x_dot) - - self.n_features_in_ = self.feature_library.n_features_in_ - self.n_output_features_ = self.feature_library.n_output_features_ - - if self.feature_names is None: - feature_names = [] - for i in range(self.n_features_in_ - self.n_control_features_): - feature_names.append("x" + str(i)) - for i in range(self.n_control_features_): - feature_names.append("u" + str(i)) - self.feature_names = feature_names + self._fit_shape() return self From 2b3871b458c03578e6f41d537acd28c4b63c0c90 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:38:06 -0700 Subject: [PATCH 5/9] bld: add typing_extensions requirement Also: * fix sphinx cross-linking error * remove unused print_model --- pyproject.toml | 1 + pysindy/feature_library/__init__.py | 2 -- pysindy/pysindy.py | 2 +- pysindy/utils/__init__.py | 2 -- pysindy/utils/base.py | 19 ------------------- 5 files changed, 2 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8946355d..26e9d4ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ readme = "README.rst" dependencies = [ "scikit-learn>=1.1, !=1.5.0", "derivative>=0.6.2", + "typing_extensions", ] [project.optional-dependencies] diff --git a/pysindy/feature_library/__init__.py b/pysindy/feature_library/__init__.py index 41c100ee..47471956 100644 --- a/pysindy/feature_library/__init__.py +++ b/pysindy/feature_library/__init__.py @@ -1,4 +1,3 @@ -from .base import BaseFeatureLibrary from .base import ConcatLibrary from .base import TensoredLibrary from .custom_library import CustomLibrary @@ -12,7 +11,6 @@ from .weak_pde_library import WeakPDELibrary __all__ = [ - "BaseFeatureLibrary", "ConcatLibrary", "TensoredLibrary", "GeneralizedLibrary", diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index 35d9cda0..de502b5e 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -20,8 +20,8 @@ from .differentiation import BaseDifferentiation from .differentiation import FiniteDifference -from .feature_library import BaseFeatureLibrary from .feature_library import PolynomialLibrary +from .feature_library.base import BaseFeatureLibrary try: # Waiting on PEP 690 to lazy import CVXPY from .optimizers import SINDyPI diff --git a/pysindy/utils/__init__.py b/pysindy/utils/__init__.py index fea2db62..5fa1f111 100644 --- a/pysindy/utils/__init__.py +++ b/pysindy/utils/__init__.py @@ -8,7 +8,6 @@ from .base import flatten_2d_tall from .base import get_prox from .base import get_regularization -from .base import print_model from .base import reorder_constraints from .base import supports_multiple_targets from .base import validate_control_variables @@ -56,7 +55,6 @@ "drop_nan_samples", "get_prox", "get_regularization", - "print_model", "reorder_constraints", "supports_multiple_targets", "validate_control_variables", diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 592a296a..ef9a5af9 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -289,25 +289,6 @@ def f(x): return np.maximum(np.minimum(trimming_array - x, 1.0), 0.0) -def print_model( - coef, - input_features, - precision=3, -): - def term(c, name): - rounded_coef = np.round(c, precision) - if rounded_coef == 0: - return "" - else: - return f"{c:.{precision}f} {name}" - - components = [term(c, i) for c, i in zip(coef, input_features)] - eq = " + ".join(filter(bool, components)) - if not eq: - eq = f"{0:.{precision}f}" - return eq - - def supports_multiple_targets(estimator): """Checks whether estimator supports multiple targets.""" if isinstance(estimator, MultiOutputMixin): From dac37020c0a1857181fd95705814ff0d3cd0cd64 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 7 Oct 2024 08:49:08 -0700 Subject: [PATCH 6/9] CLN: Extract _BaseOptimizer, the most basal SINDy optimizer This includes things like reduce() and coef_ --- pysindy/feature_library/base.py | 3 ++- pysindy/feature_library/polynomial_library.py | 2 +- pysindy/optimizers/base.py | 20 +++++++++++++++++-- pysindy/optimizers/trapping_sr3.py | 6 +++--- pysindy/pysindy.py | 6 ++++-- pysindy/utils/base.py | 6 +++--- 6 files changed, 31 insertions(+), 12 deletions(-) diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 54697da4..671386cb 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -146,7 +146,8 @@ def func(self, x, *args, **kwargs): if isinstance(x, Sequence): xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] result = wrapped_func(self, xs, *args, **kwargs) - if isinstance(result, Sequence): # e.g. transform() returns x + # if transform() is a normal "return x" + if isinstance(result, Sequence) and isinstance(result[0], np.ndarray): return [AxesArray(xp, comprehend_axes(xp)) for xp in result] return result # e.g. fit() returns self else: diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index fda52b77..72e357a7 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -160,7 +160,7 @@ def get_feature_names(self, input_features=None): return feature_names @x_sequence_or_item - def fit(self, x_full, y=None): + def fit(self, x_full: list[AxesArray], y=None): """ Compute number of output features. diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py index 01526f45..56995a22 100644 --- a/pysindy/optimizers/base.py +++ b/pysindy/optimizers/base.py @@ -4,10 +4,13 @@ import abc import warnings from typing import Callable +from typing import NewType from typing import Tuple import numpy as np +from numpy.typing import NBitBase from scipy import sparse +from sklearn.base import BaseEstimator from sklearn.linear_model import LinearRegression from sklearn.linear_model._base import _preprocess_data from sklearn.utils.extmath import safe_sparse_dot @@ -17,6 +20,10 @@ from ..utils import AxesArray from ..utils import drop_nan_samples +AnyFloat = np.dtype[np.floating[NBitBase]] +NFeat = NewType("NFeat", int) +NTarget = NewType("NTarget", int) + def _rescale_data(X, y, sample_weight): """Rescale data so as to support sample_weight""" @@ -31,14 +38,17 @@ def _rescale_data(X, y, sample_weight): return X, y -class ComplexityMixin: +class _BaseOptimizer(BaseEstimator, abc.ABC): + coef_: np.ndarray[tuple[NTarget, NFeat], AnyFloat] + intercept_: np.ndarray[tuple[NTarget], AnyFloat] + @property def complexity(self): check_is_fitted(self) return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) -class BaseOptimizer(LinearRegression, ComplexityMixin): +class BaseOptimizer(LinearRegression, _BaseOptimizer): """ Base class for SINDy optimizers. Subclasses must implement a _reduce method for carrying out the bulk of the work of @@ -88,6 +98,12 @@ class BaseOptimizer(LinearRegression, ComplexityMixin): """ + max_iter: int + normalize_columns: bool + initial_guess: None | np.ndarray[tuple[NTarget, NFeat], AnyFloat] + copy_X: bool + unbias: bool + def __init__( self, max_iter=20, diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index b4554d43..fb68e1e6 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -19,9 +19,11 @@ from ..feature_library.polynomial_library import n_poly_features from ..feature_library.polynomial_library import PolynomialLibrary from ..utils import reorder_constraints +from .base import AnyFloat +from .base import NFeat +from .base import NTarget from .constrained_sr3 import ConstrainedSR3 -AnyFloat = np.dtype[np.floating[NBitBase]] Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] Float1D = np.ndarray[tuple[int], AnyFloat] Float2D = np.ndarray[tuple[int, int], AnyFloat] @@ -29,8 +31,6 @@ Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat] Float5D = np.ndarray[tuple[int, int, int, int, int], AnyFloat] FloatND = NDArray[np.floating[NBitBase]] -NFeat = NewType("NFeat", int) -NTarget = NewType("NTarget", int) class EnstrophyMat: diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index de502b5e..5c982aa8 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -29,7 +29,9 @@ sindy_pi_flag = True except ImportError: sindy_pi_flag = False -from .optimizers import STLSQ, BaseOptimizer +from .optimizers import STLSQ +from .optimizers.base import _BaseOptimizer +from .optimizers.base import BaseOptimizer from .utils import AxesArray from .utils import comprehend_axes from .utils import concat_sample_axis @@ -43,7 +45,7 @@ class _BaseSINDy(BaseEstimator, ABC): feature_library: BaseFeatureLibrary - optimizer: BaseOptimizer + optimizer: _BaseOptimizer discrete_time: bool model: Pipeline feature_names: Optional[list[str]] diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index ef9a5af9..a0760fc2 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -61,7 +61,7 @@ def validate_input(x, t=T_DEFAULT): return x_new -def validate_no_reshape(x, t=T_DEFAULT): +def validate_no_reshape(x, t: float | np.ndarray = T_DEFAULT): """Check types and numerical sensibility of arguments. Args: @@ -72,7 +72,7 @@ def validate_no_reshape(x, t=T_DEFAULT): x as 2D array, with time dimension on first axis and coordinate index on second axis. """ - if not isinstance(x, np.ndarray): + if not hasattr(x, "shape"): raise TypeError("Input value must be array-like") check_array(x, ensure_2d=False, allow_nd=True) @@ -84,7 +84,7 @@ def validate_no_reshape(x, t=T_DEFAULT): if t <= 0: raise ValueError("t must be positive") # Only apply these tests if t is array-like - elif isinstance(t, np.ndarray): + elif hasattr(t, "shape"): if not len(t) == x.shape[-2]: raise ValueError("Length of t should match x.shape[-2].") if not np.all(t[:-1] < t[1:]): From 417bd5a4669fc8d4923e3d203f34df6413c3734a Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 16 Oct 2024 15:31:17 -0700 Subject: [PATCH 7/9] TST: Only test on CPU (faster) --- test/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 59fb61f7..ead11adf 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ """ from pathlib import Path +import jax import numpy as np import pytest from scipy.integrate import solve_ivp @@ -20,6 +21,8 @@ from pysindy.utils.odes import lorenz from pysindy.utils.odes import lorenz_control +jax.config.update("jax_platform_name", "cpu") + def pytest_addoption(parser): parser.addoption( From f84df31d1c248ccd1b3cda320282c3e04585f49d Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 16 Oct 2024 15:51:18 -0700 Subject: [PATCH 8/9] TYP: Change | to 3.9 syntax, move stuff to _typing --- pysindy/_typing.py | 9 ++++++++- pysindy/optimizers/base.py | 10 +++++----- pysindy/optimizers/trapping_sr3.py | 22 +++++++++------------- pysindy/utils/base.py | 2 +- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pysindy/_typing.py b/pysindy/_typing.py index 1ec0c446..d8eec0f7 100644 --- a/pysindy/_typing.py +++ b/pysindy/_typing.py @@ -4,4 +4,11 @@ # In python 3.12, use type statement # https://docs.python.org/3/reference/simple_stmts.html#the-type-statement NpFlt = np.floating[npt.NBitBase] -Float2D = np.ndarray[tuple[int, int], np.dtype[NpFlt]] +FloatDType = np.dtype[np.floating[npt.NBitBase]] +Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] +Float1D = np.ndarray[tuple[int], FloatDType] +Float2D = np.ndarray[tuple[int, int], FloatDType] +Float3D = np.ndarray[tuple[int, int, int], FloatDType] +Float4D = np.ndarray[tuple[int, int, int, int], FloatDType] +Float5D = np.ndarray[tuple[int, int, int, int, int], FloatDType] +FloatND = npt.NDArray[NpFlt] diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py index f4e8e9d3..2193952a 100644 --- a/pysindy/optimizers/base.py +++ b/pysindy/optimizers/base.py @@ -5,10 +5,10 @@ import warnings from typing import Callable from typing import NewType +from typing import Optional from typing import Tuple import numpy as np -from numpy.typing import NBitBase from scipy import sparse from sklearn.base import BaseEstimator from sklearn.linear_model import LinearRegression @@ -18,10 +18,10 @@ from sklearn.utils.validation import check_X_y from .._typing import Float2D +from .._typing import FloatDType from ..utils import AxesArray from ..utils import drop_nan_samples -AnyFloat = np.dtype[np.floating[NBitBase]] NFeat = NewType("NFeat", int) NTarget = NewType("NTarget", int) @@ -40,8 +40,8 @@ def _rescale_data(X, y, sample_weight): class _BaseOptimizer(BaseEstimator, abc.ABC): - coef_: np.ndarray[tuple[NTarget, NFeat], AnyFloat] - intercept_: np.ndarray[tuple[NTarget], AnyFloat] + coef_: np.ndarray[tuple[NTarget, NFeat], FloatDType] + intercept_: np.ndarray[tuple[NTarget], FloatDType] @property def complexity(self): @@ -101,7 +101,7 @@ class BaseOptimizer(LinearRegression, _BaseOptimizer): max_iter: int normalize_columns: bool - initial_guess: None | np.ndarray[tuple[NTarget, NFeat], AnyFloat] + initial_guess: Optional[np.ndarray[tuple[NTarget, NFeat], FloatDType]] copy_X: bool unbias: bool diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index fb68e1e6..a5712634 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -5,33 +5,29 @@ from itertools import repeat from math import comb from typing import cast -from typing import NewType from typing import Optional from typing import TypeVar from typing import Union import cvxpy as cp import numpy as np -from numpy.typing import NBitBase from numpy.typing import NDArray from sklearn.exceptions import ConvergenceWarning +from .._typing import Float1D +from .._typing import Float2D +from .._typing import Float3D +from .._typing import Float4D +from .._typing import Float5D +from .._typing import Int1D from ..feature_library.polynomial_library import n_poly_features from ..feature_library.polynomial_library import PolynomialLibrary from ..utils import reorder_constraints -from .base import AnyFloat +from .base import FloatDType from .base import NFeat from .base import NTarget from .constrained_sr3 import ConstrainedSR3 -Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] -Float1D = np.ndarray[tuple[int], AnyFloat] -Float2D = np.ndarray[tuple[int, int], AnyFloat] -Float3D = np.ndarray[tuple[int, int, int], AnyFloat] -Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat] -Float5D = np.ndarray[tuple[int, int, int, int, int], AnyFloat] -FloatND = NDArray[np.floating[NBitBase]] - class EnstrophyMat: """Pre-compute some useful factors of an enstrophy matrix @@ -601,7 +597,7 @@ def _solve_m_relax_and_split( self, trap_ctr: Float1D, prev_A: Float2D, - coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat], + coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType], ) -> tuple[Float1D, Float2D]: r"""Updates the trap center @@ -693,7 +689,7 @@ def _reduce(self, x, y): self.constraint_lhs = reorder_constraints( self.constraint_lhs, n_features, output_order="feature" ) - coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat] = self.coef_.T + coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType] = self.coef_.T # Print initial values for each term in the optimization if self.verbose: diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index a0760fc2..95cbf40a 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -61,7 +61,7 @@ def validate_input(x, t=T_DEFAULT): return x_new -def validate_no_reshape(x, t: float | np.ndarray = T_DEFAULT): +def validate_no_reshape(x, t: Union[float, np.ndarray, object] = T_DEFAULT): """Check types and numerical sensibility of arguments. Args: From 22329ddbeb65f464357c8f7b3793b055ee11d706 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:51:33 -0700 Subject: [PATCH 9/9] feat: Accomodate jax arrays This was thought to be easy, because in many cases jax arrays were an almost drop-in replacement for numpy arrays. However, they are far less amenable to subclassing. Why does this matter? The codebase gained a lot of readability with AxesArray allowing arrays to dynamically know what their axes meant, even after indexing changed their shape. However, extending AxesArray to dynamically subclass either numpy.ndarray or jax.Array is impossible - even a static subclass of the latter is impossible. Long term, we will need our own metadata type that carries around an array, it's type package (numpy or jax.numpy or cvxpy.numpy), its bidirectional mapping between axis index and axis meaning, and maybe even something from sympy. Short term, we should expose our general expectations for axis definitions as global constants. This is still error prone, as the constants are incorrect for arrays that have changed shape due to indexing, but will be far more readable than magic numbers. --- pyproject.toml | 2 +- pysindy/feature_library/base.py | 15 ++++++++++++--- pysindy/feature_library/polynomial_library.py | 5 +++-- pysindy/utils/_axis_conventions.py | 2 ++ 4 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 pysindy/utils/_axis_conventions.py diff --git a/pyproject.toml b/pyproject.toml index 26e9d4ad..cd3603a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ ] readme = "README.rst" dependencies = [ + "jax>=0.4,<0.5", "scikit-learn>=1.1, !=1.5.0", "derivative>=0.6.2", "typing_extensions", @@ -64,7 +65,6 @@ cvxpy = [ ] sbr = [ "numpyro", - "jax", "arviz==0.17.1", "scipy<1.13.0" ] diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 671386cb..8f24dc9a 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -8,6 +8,7 @@ from typing import Optional from typing import Sequence +import jax import numpy as np from scipy import sparse from sklearn.base import TransformerMixin @@ -144,20 +145,28 @@ def x_sequence_or_item(wrapped_func): @wraps(wrapped_func) def func(self, x, *args, **kwargs): if isinstance(x, Sequence): - xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] + if isinstance(x[0], jax.Array): + xs = x + else: + xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] result = wrapped_func(self, xs, *args, **kwargs) # if transform() is a normal "return x" if isinstance(result, Sequence) and isinstance(result[0], np.ndarray): return [AxesArray(xp, comprehend_axes(xp)) for xp in result] return result # e.g. fit() returns self else: - if not sparse.issparse(x): + if isinstance(x, jax.Array): + + def reconstructor(x): + return x + + elif not sparse.issparse(x) and isinstance(x, np.ndarray): x = AxesArray(x, comprehend_axes(x)) def reconstructor(x): return x - else: # sparse arrays + else: # sparse reconstructor = type(x) axes = comprehend_axes(x) wrap_axes(axes, x) diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index 72e357a7..496841e8 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -12,6 +12,7 @@ from ..utils import AxesArray from ..utils import comprehend_axes from ..utils import wrap_axes +from ..utils._axis_conventions import AX_COORD from .base import BaseFeatureLibrary from .base import x_sequence_or_item @@ -180,7 +181,7 @@ def fit(self, x_full: list[AxesArray], y=None): "Can't have include_interaction be False and interaction_only" " be True" ) - n_features = x_full[0].shape[x_full[0].ax_coord] + n_features = x_full[0].shape[AX_COORD] combinations = self._combinations( n_features, self.degree, @@ -217,7 +218,7 @@ def transform(self, x_full): axes = comprehend_axes(x) x = x.asformat("csc") wrap_axes(axes, x) - n_features = x.shape[x.ax_coord] + n_features = x.shape[AX_COORD] if n_features != self.n_features_in_: raise ValueError("x shape does not match training shape") diff --git a/pysindy/utils/_axis_conventions.py b/pysindy/utils/_axis_conventions.py new file mode 100644 index 00000000..98a7c582 --- /dev/null +++ b/pysindy/utils/_axis_conventions.py @@ -0,0 +1,2 @@ +AX_TIME = -2 +AX_COORD = -1