Skip to content

Commit

Permalink
cln: extract common fit functionality to base
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Sep 24, 2024
1 parent 02221e0 commit 5182f24
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sklearn.utils.validation import check_is_fitted
from typing_extensions import Self

from .differentiation import BaseDifferentiation
from .differentiation import FiniteDifference
from .feature_library import BaseFeatureLibrary
from .feature_library import PolynomialLibrary
Expand Down Expand Up @@ -54,6 +55,18 @@ class _BaseSINDy(BaseEstimator, ABC):
def fit(self, x, t, *args, **kwargs) -> Self:
...

def _fit_shape(self):
"""Assign shape attributes for the system that are used post-fit"""
self.n_features_in_ = self.feature_library.n_features_in_
self.n_output_features_ = self.feature_library.n_output_features_
if self.feature_names is None:
feature_names = []
for i in range(self.n_features_in_ - self.n_control_features_):
feature_names.append("x" + str(i))
for i in range(self.n_control_features_):
feature_names.append("u" + str(i))
self.feature_names = feature_names

def equations(self, precision: int = 3) -> list[str]:
"""
Get the right hand sides of the SINDy model equations.
Expand Down Expand Up @@ -242,12 +255,12 @@ class SINDy(_BaseSINDy):

def __init__(
self,
optimizer=None,
feature_library=None,
differentiation_method=None,
feature_names=None,
t_default=1,
discrete_time=False,
optimizer: Optional[BaseOptimizer] = None,
feature_library: Optional[BaseFeatureLibrary] = None,
differentiation_method: Optional[BaseDifferentiation] = None,
feature_names: Optional[list[str]] = None,
t_default: float = 1,
discrete_time: bool = False,
):
if optimizer is None:
optimizer = STLSQ()
Expand Down Expand Up @@ -350,17 +363,6 @@ def fit(
self.model = Pipeline(steps)
self.model.fit(x, x_dot)

self.n_features_in_ = self.feature_library.n_features_in_
self.n_output_features_ = self.feature_library.n_output_features_

if self.feature_names is None:
feature_names = []
for i in range(self.n_features_in_ - self.n_control_features_):
feature_names.append("x" + str(i))
for i in range(self.n_control_features_):
feature_names.append("u" + str(i))
self.feature_names = feature_names

return self

def predict(self, x, u=None):
Expand Down

0 comments on commit 5182f24

Please sign in to comment.