Differentiation of Multidimensional Arrays in SINDyDerivative #476

Merged · 11 commits · Jul 15, 2024
17 changes: 13 additions & 4 deletions pysindy/differentiation/sindy_derivative.py
@@ -3,6 +3,7 @@
 
 Some default values used here may differ from those used in :doc:`derivative:index`.
 """
+import numpy as np
 from derivative import methods
 from numpy import arange

@@ -33,7 +34,8 @@
         for acceptable keywords.
     """
 
-    def __init__(self, save_smooth=True, **kwargs):
+    def __init__(self, axis=0, save_smooth=True, **kwargs):
+        self.axis = axis
         self.kwargs = kwargs
         self.save_smooth = save_smooth

@@ -76,9 +78,16 @@
         differentiator = methods[self.kwargs["kind"]](
             **{k: v for k, v in self.kwargs.items() if k != "kind"}
         )
-        x_dot = differentiator.d(x, t, axis=0)
+        x = np.moveaxis(x, self.axis, 0)
+        x_shape = x.shape
+        flat_x = x.reshape((t.size, int(x.size / t.size)))
+        flat_x_dot = differentiator.d(flat_x, t, axis=0)
         if self.save_smooth:
-            self.smoothed_x_ = differentiator.x(x, t, axis=0)
+            self.smoothed_x_ = differentiator.x(flat_x, t, axis=0)
         else:
-            self.smoothed_x_ = x
+            self.smoothed_x_ = flat_x
+        x_dot = flat_x_dot.reshape(x_shape)
+        self.smoothed_x_ = self.smoothed_x_.reshape(x_shape)
+        x_dot = np.moveaxis(x_dot, 0, self.axis)
+        self.smoothed_x_ = np.moveaxis(self.smoothed_x_, 0, self.axis)
         return x_dot
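
For readers following the change: the new code handles an arbitrary time axis by moving it to the front, flattening all remaining axes into a single feature axis, differentiating along axis 0, and then undoing both operations. The snippet below is an illustrative sketch of that round trip, not part of the PR; it substitutes np.gradient for the derivative package's differentiator.d call so it runs with NumPy alone, and uses reshape((t.size, -1)), which is equivalent to the PR's int(x.size / t.size) computation.

import numpy as np

axis = -2                                    # the axis that indexes time
x = np.random.random(size=(10, 100, 2))      # e.g. (space, time, coordinate)
t = np.arange(0, 10, 0.1)

x_moved = np.moveaxis(x, axis, 0)            # bring the time axis to the front
x_shape = x_moved.shape
flat_x = x_moved.reshape((t.size, -1))       # collapse all non-time axes into one

flat_x_dot = np.gradient(flat_x, t, axis=0)  # stand-in for differentiator.d(flat_x, t, axis=0)

x_dot = flat_x_dot.reshape(x_shape)          # restore the original dimensions
x_dot = np.moveaxis(x_dot, 0, axis)          # and put the time axis back where it was

assert x_dot.shape == x.shape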
9 changes: 9 additions & 0 deletions test/test_differentiation.py
@@ -418,3 +418,12 @@ def test_centered_difference_noaxis_vs_axis(data_2d_resolved_pde):
         slow_differences_t,
         atol=atol,
     )
+
+
+def test_multidimensional_differentiation():
+    X = np.random.random(size=(10, 100, 2))
+    t = np.arange(0, 10, 0.1)
+
+    X_dot = SINDyDerivative(kind="kalman", axis=-2)._differentiate(X, t)
+
+    assert X_dot.shape == X.shape
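
For completeness, a minimal usage sketch of the new axis keyword (not part of the PR; it assumes pysindy and the derivative package are installed, and swaps the test's kalman method for derivative's spline method with its smoothing parameter s). It mirrors the _differentiate call used in the test above:

import numpy as np
from pysindy.differentiation import SINDyDerivative

x = np.random.random(size=(10, 100, 2))   # time runs along axis -2
t = np.arange(0, 10, 0.1)

# 'kind' and the remaining keyword arguments are forwarded to the
# derivative package; 'spline' with smoothing s is one of its methods.
method = SINDyDerivative(kind="spline", s=1e-2, axis=-2)
x_dot = method._differentiate(x, t)

assert x_dot.shape == x.shape             # the input layout is preserved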