Skip to content

Commit

Permalink
feat: linearize discrete-time systems
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Jan 29, 2024
1 parent d2951b1 commit 289f8f9
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 21 deletions.
3 changes: 3 additions & 0 deletions dynax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map
from .interpolation import spline_it as spline_it
from .linearize import (
discrete_input_output_linearize as discrete_input_output_linearize,
discrete_relative_degree as discrete_relative_degree,
DiscreteLinearizingSystem as DiscreteLinearizingSystem,
input_output_linearize as input_output_linearize,
LinearizingSystem as LinearizingSystem,
relative_degree as relative_degree,
Expand Down
4 changes: 1 addition & 3 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.config import _validate_default_device
from jaxtyping import Array, ArrayLike, PyTree

from .interpolation import spline_it
Expand Down Expand Up @@ -149,6 +147,6 @@ def scan_fun(state, input):
_, x = jax.lax.scan(scan_fun, x0, inputs, length=num_steps)

# Compute output
y = jax.vmap(self.system.output)(x)
y = jax.vmap(self.system.output)(x, u, t)

return x, y
143 changes: 130 additions & 13 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Functions related to feedback linearization of nonlinear systems."""

from collections.abc import Callable
from functools import partial
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jaxtyping import Array

from .derivative import lie_derivative
Expand Down Expand Up @@ -127,6 +129,116 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float:
return feedbacklaw


def prop(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array:
"""Propagates system n steps."""
# TODO: replace by lax.scan
if n == 0:
return x
return prop(f, n - 1, f(x, u), u)


def discrete_relative_degree(
sys: DynamicalSystem,
xs: Array,
us: Array,
max_reldeg=10,
output: Optional[int] = None,
):
"""Estimate relative degree of discrete-time system on region xs.
Source: Lee, Linearization of Nonlinear Control Systems (2022), Def. 7.7
"""
f = sys.vector_field
h = sys.output

y_depends_u = jax.grad(lambda n, x, u: h(prop(f, n, x, u)), 2)

for n in range(1, max_reldeg + 1):
res = jax.vmap(partial(y_depends_u, n))(xs, us)
if np.all(res == 0):
continue
elif np.all(res != 0):
return n
else:
raise RuntimeError("sys has ill defined relative degree.")
raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.")


def discrete_input_output_linearize(
sys: DynamicalSystem,
reldeg: int,
ref: LinearSystem,
output: Optional[int] = None,
solver=None,
) -> Callable[[Array, Array, float, float], float]:
"""Construct input-output linearizing feedback law for a discrete-time system."""
# Lee 2022, Chap. 7.4
f = lambda x, u: sys.vector_field(x, u)
h = sys.output
A, b, c = ref.A, ref.B, ref.C
if sys.n_inputs != ref.n_inputs != 1:
raise ValueError("Systems must have single input.")
if output is None:
if not (sys.n_outputs == ref.n_outputs and sys.n_outputs in ["scalar", 1]):
raise ValueError("Systems must be single output and `output` is None.")
else:
_h = h
h = lambda x: _h(x)[output]
c = ref.C[output]

if solver is None:
solver = optx.Newton(rtol=1e-6, atol=1e-6)

cAn = c.dot(np.linalg.matrix_power(A, reldeg))
cAnm1b = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b)

def feedbacklaw(x: Array, z: Array, v: float, u_prev: float):
y_reldeg_ref = cAn.dot(z) + cAnm1b * v
fn = lambda u, args: (h(prop(f, reldeg, x, u)) - y_reldeg_ref).squeeze()
# Catch https://github.com/patrick-kidger/diffrax/issues/296
u = jax.lax.cond(
fn(u_prev, None) == 0,
lambda: u_prev,
lambda: optx.root_find(fn, solver, u_prev).value,
)
return u.squeeze()

return feedbacklaw


class DiscreteLinearizingSystem(DynamicalSystem):
r"""Dynamics computing linearizing feedback as output."""

sys: ControlAffine
refsys: LinearSystem
feedbacklaw: Callable

n_inputs = "scalar"

def __init__(self, sys, refsys, reldeg, linearizing_output=None):
if sys.n_inputs != "scalar":
raise ValueError("Only single input systems supported.")
self.sys = sys
self.refsys = refsys
self.n_states = self.sys.n_states + self.refsys.n_states + 1
self.feedbacklaw = discrete_input_output_linearize(
sys, reldeg, refsys, linearizing_output
)

def vector_field(self, x, u, t=None):
jax.debug.print("{}", x[0])
x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1]
vn = self.feedbacklaw(x, z, u, v_last)
xn = self.sys.vector_field(x, vn)
zn = self.refsys.vector_field(z, u)
return jnp.concatenate((xn, zn, jnp.array([vn])))

def output(self, x, u=None, t=None):
v = x[-1]
return v


class LinearizingSystem(DynamicalSystem):
r"""Coupled ODE of nonlinear dynamics, linear reference and io linearizing law.
Expand All @@ -145,33 +257,38 @@ class LinearizingSystem(DynamicalSystem):

sys: ControlAffine
refsys: LinearSystem
feedbacklaw: Optional[Callable] = None

def __init__(self, sys, refsys, reldeg, feedbacklaw=None, linearizing_output=None):
if sys.n_inputs > 1:
raise ValueError("Only single input systems supported.")
feedbacklaw: Callable[[Array, Array, float], float]

n_inputs = "scalar"

def __init__(
self,
sys: ControlAffine,
refsys: LinearSystem,
reldeg: int,
feedbacklaw: Optional[Callable] = None,
linearizing_output: Optional[int] = None,
):
self.sys = sys
self.refsys = refsys
self.n_inputs = "scalar"
self.n_states = (
self.sys.n_states + self.refsys.n_states
) # FIXME: support "scalar"
self.feedbacklaw = feedbacklaw
if feedbacklaw is None:
if callable(feedbacklaw):
self.feedbacklaw = feedbacklaw
else:
self.feedbacklaw = input_output_linearize(
sys, reldeg, refsys, linearizing_output
)

def vector_field(self, x, u=None, t=None):
x, z = x[: self.sys.n_states], x[self.sys.n_states :]
if u is None:
u = 0.0
y = self.feedbacklaw(x, z, u)
dx = self.sys.vector_field(x, y)
dz = self.refsys.vector_field(z, u)
return jnp.concatenate((dx, dz))

def output(self, x, u=None, t=None):
def output(self, x, u, t=None):
x, z = x[: self.sys.n_states], x[self.sys.n_states :]
y = self.feedbacklaw(x, z, u)
return y
ur = self.feedbacklaw(x, z, u)
return ur
4 changes: 2 additions & 2 deletions dynax/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jax import Array
from jax.typing import ArrayLike

from .util import dim2shape, ssmatrix
from .util import dim2shape


def _linearize(f, h, x0, u0, t0):
Expand Down Expand Up @@ -398,7 +398,7 @@ class DynamicStateFeedbackSystem(DynamicalSystem):
ẋ &= f_1(x, v(x, z, u), t) \\
ż &= f_2(z, r, t) \\
y &= h(x, u, t)
y &= h_1(x, u, t)
"""

Expand Down
53 changes: 52 additions & 1 deletion tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@

from dynax import (
ControlAffine,
discrete_relative_degree,
DiscreteLinearizingSystem,
DynamicalSystem,
DynamicStateFeedbackSystem,
Flow,
input_output_linearize,
LinearSystem,
Map,
relative_degree,
)
from dynax.example_models import NonlinearDrag, Sastry9_9
from dynax.linearize import input_output_linearize, is_controllable, relative_degree
from dynax.linearize import (
is_controllable,
)


tols = dict(rtol=1e-04, atol=1e-06)
Expand Down Expand Up @@ -45,6 +52,20 @@ def test_relative_degree():
assert relative_degree(sys, xs) == 1


def test_discrete_relative_degree():
xs = np.random.normal(size=(100, 2))
us = np.random.normal(size=(100, 1))

sys = SpringMassDamperWithOutput(out=0)
assert discrete_relative_degree(sys, xs, us) == 2

with npt.assert_raises(RuntimeError):
discrete_relative_degree(sys, xs, us, max_reldeg=1)

sys = SpringMassDamperWithOutput(out=1)
assert discrete_relative_degree(sys, xs, us) == 1


def test_is_controllable():
n = 3
A = np.diag(np.arange(n))
Expand Down Expand Up @@ -126,3 +147,33 @@ def test_input_output_linearize_multiple_outputs():
y_ref = Flow(ref)(np.zeros(sys.n_states), t, u)[1]
y = Flow(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1]
npt.assert_allclose(y_ref[:, out_idx], y[:, out_idx], **tols)


class Lee7_4_5(DynamicalSystem):
n_states = 2
n_inputs = "scalar"

def vector_field(self, x, u, t=None):
x1, x2 = x
return 0.1 * jnp.array([x1 + x1**3 + x2, x2 + x2**3 + u])

def output(self, x, u=None, t=None):
return x[0]


def test_discrete_input_output_linearize():
sys = Lee7_4_5()
refsys = sys.linearize()
xs = np.random.normal(size=(100, 2))
us = np.random.normal(size=100)
reldeg = discrete_relative_degree(sys, xs, us)
assert reldeg == 2

feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg)
t = np.linspace(0, 0.001, 10)
u = np.cos(t) * 0.1
_, v = Map(feedback_sys)(np.zeros(2 + 2 + 1), t, u)
_, y = Map(sys)(np.zeros(2), t, u)
_, y_ref = Map(refsys)(np.zeros(2), t, u)

npt.assert_allclose(y_ref, y, **tols)
4 changes: 2 additions & 2 deletions tests/test_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_discrete_forward_model():
x, y = model(x0, u=u) # ours
scipy_sys = dlti(A, B, C, D)
_, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0)
npt.assert_allclose(scipy_y[:, 0], y, **tols)
npt.assert_allclose(scipy_y, y, **tols)
npt.assert_allclose(scipy_x, x, **tols)
# test input and time (results should be same)
x, y = model(x0, u=u, t=t)
scipy_t, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0, t=t)
npt.assert_allclose(scipy_y[:, 0], y, **tols)
npt.assert_allclose(scipy_y, y, **tols)
npt.assert_allclose(scipy_x, x, **tols)

0 comments on commit 289f8f9

Please sign in to comment.