Skip to content

Commit

Permalink
remove x0 parameters from fit functions
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Feb 4, 2024
1 parent 7749cd7 commit 1e856a1
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ experiments
.coverage
htmlcov
build
docs/source/_build
_build
docs/generated
*.pytest_cache
.pytype
Expand Down
44 changes: 18 additions & 26 deletions dynax/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _compute_covariance(

def _least_squares(
f: Callable[[Array], Array],
x0: NDArray,
init_params: Array,
bounds: tuple[list, list],
reg_term: Optional[Callable[[Array], Array]] = None,
x_scale: bool = True,
Expand All @@ -105,7 +105,7 @@ def _least_squares(
# Add regularization term
_f = f
_reg_term = reg_term # https://github.com/python/mypy/issues/7268
f = lambda x: jnp.concatenate((_f(x), _reg_term(x)))
f = lambda params: jnp.concatenate((_f(params), _reg_term(params)))

if verbose_mse:
# Scale cost to mean-squared error
Expand All @@ -117,15 +117,17 @@ def f(params):

if x_scale:
# Scale parameters and bounds by initial values
norm = np.where(np.asarray(x0) != 0, np.abs(x0), 1)
x0 = x0 / norm
norm = np.where(np.asarray(init_params) != 0, np.abs(init_params), 1)
init_params = init_params / norm
___f = f
f = lambda x: ___f(x * norm)
f = lambda params: ___f(params * norm)
bounds = (np.array(bounds[0]) / norm, np.array(bounds[1]) / norm)

fun = MemoizeJac(jax.jit(lambda x: value_and_jacfwd(f, x)))
jac = fun.derivative
res = least_squares(fun, x0, bounds=bounds, jac=jac, x_scale="jac", **kwargs)
res = least_squares(
fun, init_params, bounds=bounds, jac=jac, x_scale="jac", **kwargs
)

if x_scale:
# Unscale parameters
Expand All @@ -139,8 +141,8 @@ def f(params):

if reg_term is not None:
# Remove regularization from residuals and Jacobian and cost
res.fun = res.fun[: -len(x0)]
res.jac = res.jac[: -len(x0)]
res.fun = res.fun[: -len(init_params)]
res.jac = res.jac[: -len(init_params)]
res.cost = np.sum(res.fun**2) / 2

return res
Expand All @@ -150,7 +152,6 @@ def fit_least_squares(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: Optional[ArrayLike] = None,
u: Optional[ArrayLike] = None,
batched: bool = False,
sigma: Optional[ArrayLike] = None,
Expand All @@ -169,11 +170,6 @@ def fit_least_squares(
t = jnp.asarray(t)
y = jnp.asarray(y)

if x0 is not None:
x0 = jnp.asarray(x0)
else:
x0 = model.system.initial_state

if batched:
# First axis holds experiments, second axis holds time.
std_y = np.std(y, axis=1, keepdims=True)
Expand Down Expand Up @@ -216,7 +212,7 @@ def residual_term(params):
# this can use pmap, if batch size is smaller than CPU cores
model = jax.vmap(model)
# FIXME: ucoeffs not supported for Map
_, pred_y = model(t=t, ucoeffs=ucoeffs, initial_state=x0)
_, pred_y = model(t=t, ucoeffs=ucoeffs)
res = (y - pred_y) * weight
return res.reshape(-1)

Expand Down Expand Up @@ -250,7 +246,6 @@ def fit_multiple_shooting(
model: AbstractEvolution,
t: ArrayLike,
y: ArrayLike,
x0: Optional[ArrayLike] = None,
u: Optional[Union[Callable[[float], Array], ArrayLike]] = None,
num_shots: int = 1,
continuity_penalty: float = 0.1,
Expand Down Expand Up @@ -278,11 +273,6 @@ def fit_multiple_shooting(
t = jnp.asarray(t)
y = jnp.asarray(y)

if x0 is not None:
x0 = jnp.asarray(x0)
else:
x0 = model.system.initial_state

if u is None:
msg = (
f"t, y must have same number of samples, but have shapes "
Expand All @@ -308,11 +298,13 @@ def fit_multiple_shooting(
t = t[:num_samples]
y = y[:num_samples]

n_states = len(model.system.initial_state)

# TODO: use numpy for everything that is not jitted
# Divide signals into segments.
ts = _moving_window(t, num_samples_per_segment, num_samples_per_segment - 1)
ys = _moving_window(y, num_samples_per_segment, num_samples_per_segment - 1)
x0s = np.zeros((num_shots - 1, len(x0)))
x0s = np.zeros((num_shots - 1, n_states))

ucoeffs = None
if u is not None:
Expand All @@ -329,8 +321,8 @@ def fit_multiple_shooting(
std_y = np.std(y, axis=0)
parameter_bounds = _get_bounds(model)
state_bounds = (
(num_shots - 1) * len(x0) * [-np.inf],
(num_shots - 1) * len(x0) * [np.inf],
(num_shots - 1) * n_states * [-np.inf],
(num_shots - 1) * n_states * [np.inf],
)
bounds = (
state_bounds[0] + parameter_bounds[0],
Expand All @@ -339,7 +331,7 @@ def fit_multiple_shooting(

def residuals(params):
x0s, model = unravel(params)
x0s = jnp.concatenate((x0[None], x0s), axis=0)
x0s = jnp.concatenate((model.system.initial_state[None], x0s), axis=0)
xs_pred, ys_pred = jax.vmap(model)(t=ts0, ucoeffs=ucoeffs, initial_state=x0s)
# output residual
res_y = ((ys - ys_pred) / std_y).reshape(-1)
Expand All @@ -353,7 +345,7 @@ def residuals(params):
res = _least_squares(residuals, init_params, bounds, x_scale=False, **kwargs)

x0s, res.result = unravel(res.x)
res.x0s = np.asarray(jnp.concatenate((x0[None], x0s), axis=0))
res.x0s = jnp.concatenate((res.result.system.initial_state[None], x0s), axis=0)
res.ts = np.asarray(ts)
res.ts0 = np.asarray(ts0)

Expand Down
15 changes: 9 additions & 6 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ class Flow(AbstractEvolution):
"""Evolution for continuous-time dynamical systems."""

solver: dfx.AbstractAdaptiveSolver = eqx.static_field(default_factory=dfx.Dopri5)
step: dfx.AbstractStepSizeController = eqx.static_field(
default_factory=dfx.ConstantStepSize
) # TODO: replace with adaptive step size
dt0: Optional[float] = eqx.static_field(default=None)
stepsize_controller: dfx.AbstractStepSizeController = eqx.static_field(
default_factory=lambda: dfx.ConstantStepSize()
)

def __call__(
self,
Expand Down Expand Up @@ -95,11 +94,15 @@ def __call__(
# Solve ODE.
diffeqsolve_default_options = dict(
solver=self.solver,
stepsize_controller=self.step,
stepsize_controller=self.stepsize_controller,
saveat=dfx.SaveAt(ts=t),
max_steps=50 * len(t), # completely arbitrary number of steps
adjoint=dfx.DirectAdjoint(),
dt0=self.dt0 if self.dt0 is not None else t[1],
dt0=(
t[1]
if isinstance(self.stepsize_controller, dfx.ConstantStepSize)
else None
),
)
diffeqsolve_default_options |= diffeqsolve_kwargs
vector_field = lambda t, x, self: self.system.vector_field(x, _ufun(t), t)
Expand Down
2 changes: 1 addition & 1 deletion dynax/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class LotkaVolterra(DynamicalSystem):
gamma: float = non_negative_field()
delta: float = non_negative_field()

initial_state = jnp.ones(2)
initial_state = jnp.ones(2) * 0.5
n_inputs = 0

def vector_field(self, x, u=None, t=None):
Expand Down
79 changes: 37 additions & 42 deletions examples/fit_ode.ipynb

Large diffs are not rendered by default.

76 changes: 38 additions & 38 deletions tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,27 @@
from dynax.example_models import LotkaVolterra, NonlinearDrag, SpringMassDamper


tols = dict(rtol=1e-05, atol=1e-05)
tols = dict(rtol=1e-02, atol=1e-04)


@pytest.mark.parametrize("outputs", [[0], [0, 1]])
def test_fit_least_squares(outputs):
# data
t = np.linspace(0, 2, 200)
t = np.linspace(0, 1, 100)
u = (
np.sin(1 * 2 * np.pi * t)
0.1 * np.sin(1 * 2 * np.pi * t)
+ np.sin(0.1 * 2 * np.pi * t)
+ np.sin(10 * 2 * np.pi * t)
)
x0 = jnp.array([1.0, 0.0])
true_model = Flow(NonlinearDrag(1.0, 2.0, 3.0, 4.0, outputs))
_, y_true = true_model(t, u, x0)
true_model = Flow(
NonlinearDrag(1.0, 2.0, 3.0, 4.0, outputs),
)
_, y_true = true_model(t, u)
# fit
init_model = Flow(NonlinearDrag(1.0, 1.0, 1.0, 1.0, outputs))
pred_model = fit_least_squares(init_model, t, y_true, x0, u).result
init_model = Flow(NonlinearDrag(1.0, 2.0, 3.0, 4.0, outputs))
pred_model = fit_least_squares(init_model, t, y_true, u, verbose=2).result
# check result
_, y_pred = pred_model(t, u, x0)
_, y_pred = pred_model(t, u)
npt.assert_allclose(y_pred, y_true, **tols)
npt.assert_allclose(
jax.tree_util.tree_flatten(pred_model)[0],
Expand All @@ -49,7 +50,7 @@ def test_fit_least_squares(outputs):

def test_fit_least_squares_on_batch():
# data
t = np.linspace(0, 2, 200)
t = np.linspace(0, 1, 100)
us = np.stack(
(
np.sin(1 * 2 * np.pi * t),
Expand All @@ -58,16 +59,18 @@ def test_fit_least_squares_on_batch():
),
axis=0,
)
x0 = np.array([1.0, 0.0])
x0s = np.repeat(x0[None], us.shape[0], axis=0)
ts = np.repeat(t[None], us.shape[0], axis=0)
true_model = Flow(NonlinearDrag(1.0, 2.0, 3.0, 4.0))
_, ys = jax.vmap(true_model)(ts, us, x0s)
true_model = Flow(
NonlinearDrag(1.0, 2.0, 3.0, 4.0),
)
_, ys = jax.vmap(true_model)(ts, us)
# fit
init_model = Flow(NonlinearDrag(1.0, 1.0, 1.0, 1.0))
pred_model = fit_least_squares(init_model, ts, ys, x0s, us, batched=True).result
init_model = Flow(
NonlinearDrag(1.0, 2.0, 3.0, 4.0),
)
pred_model = fit_least_squares(init_model, ts, ys, us, batched=True).result
# check result
_, ys_pred = jax.vmap(pred_model)(ts, us, x0s)
_, ys_pred = jax.vmap(pred_model)(ts, us)
npt.assert_allclose(ys_pred, ys, **tols)
npt.assert_allclose(
jax.tree_util.tree_flatten(pred_model)[0],
Expand All @@ -80,7 +83,9 @@ def test_can_compute_jacfwd_with_implicit_methods():
# don't get caught by https://github.com/patrick-kidger/diffrax/issues/135
t = jnp.linspace(0, 1, 10)
x0 = jnp.array([1.0, 0.0])
solver_opt = dict(solver=Kvaerno5(), step=PIDController(atol=1e-6, rtol=1e-3))
solver_opt = dict(
solver=Kvaerno5(), stepsize_controller=PIDController(atol=1e-6, rtol=1e-3)
)

def fun(m, r, k, x0=x0, solver_opt=solver_opt, t=t):
model = Flow(SpringMassDamper(m, r, k), **solver_opt)
Expand All @@ -94,19 +99,18 @@ def fun(m, r, k, x0=x0, solver_opt=solver_opt, t=t):
def test_fit_with_bounded_parameters():
# data
t = jnp.linspace(0, 1, 100)
x0 = jnp.array([0.5, 0.5])
solver_opt = dict(step=PIDController(rtol=1e-5, atol=1e-7))
solver_opt = dict(stepsize_controller=PIDController(rtol=1e-5, atol=1e-7))
true_model = Flow(
LotkaVolterra(alpha=2 / 3, beta=4 / 3, gamma=1.0, delta=1.0), **solver_opt
)
x_true, _ = true_model(t, initial_state=x0)
x_true, _ = true_model(t)
# fit
init_model = Flow(
LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt
)
pred_model = fit_least_squares(init_model, t, x_true, x0).result
pred_model = fit_least_squares(init_model, t, x_true).result
# check result
x_pred, _ = pred_model(t, initial_state=x0)
x_pred, _ = pred_model(t)
npt.assert_allclose(x_pred, x_true, **tols)
npt.assert_allclose(
jax.tree_util.tree_flatten(pred_model)[0],
Expand Down Expand Up @@ -134,7 +138,7 @@ def vector_field(self, x, u=None, t=None):

# data
t = jnp.linspace(0, 1, 100)
solver_opt = dict(step=PIDController(rtol=1e-5, atol=1e-7))
solver_opt = dict(stepsize_controller=PIDController(rtol=1e-5, atol=1e-7))
true_model = Flow(
LotkaVolterraBounded(
alpha=2 / 3, beta=4 / 3, delta_gamma=jnp.array([1.0, 1.0])
Expand All @@ -159,25 +163,23 @@ def vector_field(self, x, u=None, t=None):
@pytest.mark.parametrize("num_shots", [1, 2, 3])
def test_fit_multiple_shooting_with_input(num_shots):
# data
t = jnp.linspace(0, 10, 10000)
t = jnp.linspace(0, 1, 200)
u = jnp.sin(1 * 2 * np.pi * t)
x0 = jnp.array([1.0, 0.0])
true_model = Flow(SpringMassDamper(1.0, 2.0, 3.0))
x_true, _ = true_model(t, u, initial_state=x0)
x_true, _ = true_model(t, u)
# fit
init_model = Flow(SpringMassDamper(1.0, 1.0, 1.0))
pred_model = fit_multiple_shooting(
init_model,
t,
x_true,
x0,
u,
continuity_penalty=1,
num_shots=num_shots,
verbose=2,
).result
# check result
x_pred, _ = pred_model(t, u, initial_state=x0)
x_pred, _ = pred_model(t, u)
npt.assert_allclose(x_pred, x_true, **tols)
npt.assert_allclose(
jax.tree_util.tree_flatten(pred_model)[0],
Expand All @@ -189,22 +191,21 @@ def test_fit_multiple_shooting_with_input(num_shots):
@pytest.mark.parametrize("num_shots", [1, 2, 3])
def test_fit_multiple_shooting_without_input(num_shots):
# data
t = jnp.linspace(0, 1, 1000)
x0 = jnp.array([0.5, 0.5])
solver_opt = dict(step=PIDController(rtol=1e-3, atol=1e-6))
t = jnp.linspace(0, 1, 200)
solver_opt = dict(stepsize_controller=PIDController(rtol=1e-3, atol=1e-6))
true_model = Flow(
LotkaVolterra(alpha=2 / 3, beta=4 / 3, gamma=1.0, delta=1.0), **solver_opt
)
x_true, _ = true_model(t, initial_state=x0)
x_true, _ = true_model(t)
# fit
init_model = Flow(
LotkaVolterra(alpha=1.0, beta=1.0, gamma=1.5, delta=2.0), **solver_opt
)
pred_model = fit_multiple_shooting(
init_model, t, x_true, x0, num_shots=num_shots, continuity_penalty=1
init_model, t, x_true, num_shots=num_shots, continuity_penalty=1
).result
# check result
x_pred, _ = pred_model(t, initial_state=x0)
x_pred, _ = pred_model(t)
npt.assert_allclose(x_pred, x_true, atol=1e-3, rtol=1e-3)
npt.assert_allclose(
jax.tree_util.tree_flatten(pred_model)[0],
Expand All @@ -228,15 +229,14 @@ def test_csd_matching():
np.random.seed(123)
# model
sys = SpringMassDamper(1.0, 1.0, 1.0)
model = Flow(sys, step=PIDController(rtol=1e-4, atol=1e-6))
x0 = np.zeros(jnp.shape(sys.initial_state))
model = Flow(sys, stepsize_controller=PIDController(rtol=1e-4, atol=1e-6))
# input
duration = 1000
sr = 50
t = np.arange(int(duration * sr)) / sr
u = np.random.normal(size=len(t))
# output
_, y = model(t, u, initial_state=x0)
_, y = model(t, u)
# fit
init_sys = SpringMassDamper(1.0, 1.0, 1.0)
fitted_sys = fit_csd_matching(init_sys, u, y, sr, nperseg=1024, verbose=1).result
Expand Down
Loading

0 comments on commit 1e856a1

Please sign in to comment.