Skip to content

Commit

Permalink
tests: estimation of initial state
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Feb 19, 2024
1 parent 59dbf14 commit 4a3812f
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -13,6 +14,7 @@
fit_least_squares,
fit_multiple_shooting,
Flow,
free_field,
non_negative_field,
transfer_function,
)
Expand Down Expand Up @@ -247,3 +249,40 @@ def test_csd_matching():
rtol=1e-1,
atol=1e-1,
)


def test_estimate_initial_state():
class NonlinearDragFreeInitialState(NonlinearDrag):
initial_state: Array = free_field(init=False)

def __post_init__(self):
self.initial_state = jnp.zeros(2)

# data
t = np.linspace(0, 2, 200)
u = (
np.sin(1 * 2 * np.pi * t)
+ np.sin(0.1 * 2 * np.pi * t)
+ np.sin(10 * 2 * np.pi * t)
)

# True model has nonzero initial state
true_initial_state = jnp.array([1.0, 0.5])
true_model = Flow(NonlinearDragFreeInitialState(1.0, 2.0, 3.0, 4.0, outputs=[0, 1]))
true_model = eqx.tree_at(
lambda t: t.system.initial_state, true_model, true_initial_state
)
_, y_true = true_model(t, u, true_initial_state)

# fit
init_model = Flow(NonlinearDragFreeInitialState(1.0, 1.0, 1.0, 1.0, outputs=[0, 1]))
pred_model = fit_least_squares(init_model, t, y_true, u=u).result

# check result
_, y_pred = pred_model(t, u)
npt.assert_allclose(y_pred, y_true, **tols)
npt.assert_allclose(
pred_model.system.initial_state,
true_initial_state,
**tols,
)

0 comments on commit 4a3812f

Please sign in to comment.