From 289f8f9299466ff46269600763e072a9db342bb5 Mon Sep 17 00:00:00 2001 From: fhchl Date: Wed, 9 Aug 2023 17:13:41 +0200 Subject: [PATCH] feat: linearize discrete-time systems --- dynax/__init__.py | 3 + dynax/evolution.py | 4 +- dynax/linearize.py | 143 ++++++++++++++++++++++++++++++++++++---- dynax/system.py | 4 +- tests/test_linearize.py | 53 ++++++++++++++- tests/test_systems.py | 4 +- 6 files changed, 190 insertions(+), 21 deletions(-) diff --git a/dynax/__init__.py b/dynax/__init__.py index 2266dbc..641d3cb 100644 --- a/dynax/__init__.py +++ b/dynax/__init__.py @@ -12,6 +12,9 @@ from .evolution import AbstractEvolution as AbstractEvolution, Flow as Flow, Map as Map from .interpolation import spline_it as spline_it from .linearize import ( + discrete_input_output_linearize as discrete_input_output_linearize, + discrete_relative_degree as discrete_relative_degree, + DiscreteLinearizingSystem as DiscreteLinearizingSystem, input_output_linearize as input_output_linearize, LinearizingSystem as LinearizingSystem, relative_degree as relative_degree, diff --git a/dynax/evolution.py b/dynax/evolution.py index f823b19..31d61ee 100644 --- a/dynax/evolution.py +++ b/dynax/evolution.py @@ -4,8 +4,6 @@ import equinox as eqx import jax import jax.numpy as jnp -import numpy as np -from jax._src.config import _validate_default_device from jaxtyping import Array, ArrayLike, PyTree from .interpolation import spline_it @@ -149,6 +147,6 @@ def scan_fun(state, input): _, x = jax.lax.scan(scan_fun, x0, inputs, length=num_steps) # Compute output - y = jax.vmap(self.system.output)(x) + y = jax.vmap(self.system.output)(x, u, t) return x, y diff --git a/dynax/linearize.py b/dynax/linearize.py index 771815f..b2ff67b 100644 --- a/dynax/linearize.py +++ b/dynax/linearize.py @@ -1,11 +1,13 @@ """Functions related to feedback linearization of nonlinear systems.""" from collections.abc import Callable +from functools import partial from typing import Optional, Sequence import jax import jax.numpy as jnp import numpy as np +import optimistix as optx from jaxtyping import Array from .derivative import lie_derivative @@ -127,6 +129,116 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float: return feedbacklaw +def prop(f: Callable[[Array, float], Array], n: int, x: Array, u: float) -> Array: + """Propagates system n steps.""" + # TODO: replace by lax.scan + if n == 0: + return x + return prop(f, n - 1, f(x, u), u) + + +def discrete_relative_degree( + sys: DynamicalSystem, + xs: Array, + us: Array, + max_reldeg=10, + output: Optional[int] = None, +): + """Estimate relative degree of discrete-time system on region xs. + + Source: Lee, Linearization of Nonlinear Control Systems (2022), Def. 7.7 + + """ + f = sys.vector_field + h = sys.output + + y_depends_u = jax.grad(lambda n, x, u: h(prop(f, n, x, u)), 2) + + for n in range(1, max_reldeg + 1): + res = jax.vmap(partial(y_depends_u, n))(xs, us) + if np.all(res == 0): + continue + elif np.all(res != 0): + return n + else: + raise RuntimeError("sys has ill defined relative degree.") + raise RuntimeError("Could not estmate relative degree. Increase max_reldeg.") + + +def discrete_input_output_linearize( + sys: DynamicalSystem, + reldeg: int, + ref: LinearSystem, + output: Optional[int] = None, + solver=None, +) -> Callable[[Array, Array, float, float], float]: + """Construct input-output linearizing feedback law for a discrete-time system.""" + # Lee 2022, Chap. 7.4 + f = lambda x, u: sys.vector_field(x, u) + h = sys.output + A, b, c = ref.A, ref.B, ref.C + if sys.n_inputs != ref.n_inputs != 1: + raise ValueError("Systems must have single input.") + if output is None: + if not (sys.n_outputs == ref.n_outputs and sys.n_outputs in ["scalar", 1]): + raise ValueError("Systems must be single output and `output` is None.") + else: + _h = h + h = lambda x: _h(x)[output] + c = ref.C[output] + + if solver is None: + solver = optx.Newton(rtol=1e-6, atol=1e-6) + + cAn = c.dot(np.linalg.matrix_power(A, reldeg)) + cAnm1b = c.dot(np.linalg.matrix_power(A, reldeg - 1)).dot(b) + + def feedbacklaw(x: Array, z: Array, v: float, u_prev: float): + y_reldeg_ref = cAn.dot(z) + cAnm1b * v + fn = lambda u, args: (h(prop(f, reldeg, x, u)) - y_reldeg_ref).squeeze() + # Catch https://github.com/patrick-kidger/diffrax/issues/296 + u = jax.lax.cond( + fn(u_prev, None) == 0, + lambda: u_prev, + lambda: optx.root_find(fn, solver, u_prev).value, + ) + return u.squeeze() + + return feedbacklaw + + +class DiscreteLinearizingSystem(DynamicalSystem): + r"""Dynamics computing linearizing feedback as output.""" + + sys: ControlAffine + refsys: LinearSystem + feedbacklaw: Callable + + n_inputs = "scalar" + + def __init__(self, sys, refsys, reldeg, linearizing_output=None): + if sys.n_inputs != "scalar": + raise ValueError("Only single input systems supported.") + self.sys = sys + self.refsys = refsys + self.n_states = self.sys.n_states + self.refsys.n_states + 1 + self.feedbacklaw = discrete_input_output_linearize( + sys, reldeg, refsys, linearizing_output + ) + + def vector_field(self, x, u, t=None): + jax.debug.print("{}", x[0]) + x, z, v_last = x[: self.sys.n_states], x[self.sys.n_states : -1], x[-1] + vn = self.feedbacklaw(x, z, u, v_last) + xn = self.sys.vector_field(x, vn) + zn = self.refsys.vector_field(z, u) + return jnp.concatenate((xn, zn, jnp.array([vn]))) + + def output(self, x, u=None, t=None): + v = x[-1] + return v + + class LinearizingSystem(DynamicalSystem): r"""Coupled ODE of nonlinear dynamics, linear reference and io linearizing law. @@ -145,33 +257,38 @@ class LinearizingSystem(DynamicalSystem): sys: ControlAffine refsys: LinearSystem - feedbacklaw: Optional[Callable] = None - - def __init__(self, sys, refsys, reldeg, feedbacklaw=None, linearizing_output=None): - if sys.n_inputs > 1: - raise ValueError("Only single input systems supported.") + feedbacklaw: Callable[[Array, Array, float], float] + + n_inputs = "scalar" + + def __init__( + self, + sys: ControlAffine, + refsys: LinearSystem, + reldeg: int, + feedbacklaw: Optional[Callable] = None, + linearizing_output: Optional[int] = None, + ): self.sys = sys self.refsys = refsys - self.n_inputs = "scalar" self.n_states = ( self.sys.n_states + self.refsys.n_states ) # FIXME: support "scalar" - self.feedbacklaw = feedbacklaw - if feedbacklaw is None: + if callable(feedbacklaw): + self.feedbacklaw = feedbacklaw + else: self.feedbacklaw = input_output_linearize( sys, reldeg, refsys, linearizing_output ) def vector_field(self, x, u=None, t=None): x, z = x[: self.sys.n_states], x[self.sys.n_states :] - if u is None: - u = 0.0 y = self.feedbacklaw(x, z, u) dx = self.sys.vector_field(x, y) dz = self.refsys.vector_field(z, u) return jnp.concatenate((dx, dz)) - def output(self, x, u=None, t=None): + def output(self, x, u, t=None): x, z = x[: self.sys.n_states], x[self.sys.n_states :] - y = self.feedbacklaw(x, z, u) - return y + ur = self.feedbacklaw(x, z, u) + return ur diff --git a/dynax/system.py b/dynax/system.py index 52347f6..416d2fd 100644 --- a/dynax/system.py +++ b/dynax/system.py @@ -11,7 +11,7 @@ from jax import Array from jax.typing import ArrayLike -from .util import dim2shape, ssmatrix +from .util import dim2shape def _linearize(f, h, x0, u0, t0): @@ -398,7 +398,7 @@ class DynamicStateFeedbackSystem(DynamicalSystem): ẋ &= f_1(x, v(x, z, u), t) \\ ż &= f_2(z, r, t) \\ - y &= h(x, u, t) + y &= h_1(x, u, t) """ diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 342081e..15aea50 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -4,13 +4,20 @@ from dynax import ( ControlAffine, + discrete_relative_degree, + DiscreteLinearizingSystem, DynamicalSystem, DynamicStateFeedbackSystem, Flow, + input_output_linearize, LinearSystem, + Map, + relative_degree, ) from dynax.example_models import NonlinearDrag, Sastry9_9 -from dynax.linearize import input_output_linearize, is_controllable, relative_degree +from dynax.linearize import ( + is_controllable, +) tols = dict(rtol=1e-04, atol=1e-06) @@ -45,6 +52,20 @@ def test_relative_degree(): assert relative_degree(sys, xs) == 1 +def test_discrete_relative_degree(): + xs = np.random.normal(size=(100, 2)) + us = np.random.normal(size=(100, 1)) + + sys = SpringMassDamperWithOutput(out=0) + assert discrete_relative_degree(sys, xs, us) == 2 + + with npt.assert_raises(RuntimeError): + discrete_relative_degree(sys, xs, us, max_reldeg=1) + + sys = SpringMassDamperWithOutput(out=1) + assert discrete_relative_degree(sys, xs, us) == 1 + + def test_is_controllable(): n = 3 A = np.diag(np.arange(n)) @@ -126,3 +147,33 @@ def test_input_output_linearize_multiple_outputs(): y_ref = Flow(ref)(np.zeros(sys.n_states), t, u)[1] y = Flow(feedback_sys)(np.zeros(feedback_sys.n_states), t, u)[1] npt.assert_allclose(y_ref[:, out_idx], y[:, out_idx], **tols) + + +class Lee7_4_5(DynamicalSystem): + n_states = 2 + n_inputs = "scalar" + + def vector_field(self, x, u, t=None): + x1, x2 = x + return 0.1 * jnp.array([x1 + x1**3 + x2, x2 + x2**3 + u]) + + def output(self, x, u=None, t=None): + return x[0] + + +def test_discrete_input_output_linearize(): + sys = Lee7_4_5() + refsys = sys.linearize() + xs = np.random.normal(size=(100, 2)) + us = np.random.normal(size=100) + reldeg = discrete_relative_degree(sys, xs, us) + assert reldeg == 2 + + feedback_sys = DiscreteLinearizingSystem(sys, refsys, reldeg) + t = np.linspace(0, 0.001, 10) + u = np.cos(t) * 0.1 + _, v = Map(feedback_sys)(np.zeros(2 + 2 + 1), t, u) + _, y = Map(sys)(np.zeros(2), t, u) + _, y_ref = Map(refsys)(np.zeros(2), t, u) + + npt.assert_allclose(y_ref, y, **tols) diff --git a/tests/test_systems.py b/tests/test_systems.py index a064d25..cf4d407 100644 --- a/tests/test_systems.py +++ b/tests/test_systems.py @@ -137,10 +137,10 @@ def test_discrete_forward_model(): x, y = model(x0, u=u) # ours scipy_sys = dlti(A, B, C, D) _, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0) - npt.assert_allclose(scipy_y[:, 0], y, **tols) + npt.assert_allclose(scipy_y, y, **tols) npt.assert_allclose(scipy_x, x, **tols) # test input and time (results should be same) x, y = model(x0, u=u, t=t) scipy_t, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0, t=t) - npt.assert_allclose(scipy_y[:, 0], y, **tols) + npt.assert_allclose(scipy_y, y, **tols) npt.assert_allclose(scipy_x, x, **tols)