Skip to content

Commit

Permalink
remove squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
fhchl committed Jan 26, 2024
1 parent 9c77256 commit 1c8d855
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 29 deletions.
27 changes: 5 additions & 22 deletions dynax/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@
from .util import broadcast_right, dim2shape


try:
# TODO: remove when upgrading to diffrax > v0.2
DefaultAdjoint = dfx.NoAdjoint
except AttributeError:
DefaultAdjoint = dfx.DirectAdjoint


class AbstractEvolution(eqx.Module):
"""Abstract base-class for evolutions."""

def __call__(self, x0: ArrayLike, t: ArrayLike, u: ArrayLike, **kwargs):
def __call__(self, x0: ArrayLike, t: Array, u: Array, **kwargs):
raise NotImplementedError


Expand All @@ -45,11 +38,10 @@ class Flow(AbstractEvolution):
def __call__(
self,
x0: ArrayLike,
t: ArrayLike,
u: Optional[ArrayLike] = None,
t: Array,
u: Optional[Array] = None,
ufun: Optional[Callable[[float], Array]] = None,
ucoeffs: Optional[tuple[PyTree, PyTree, PyTree, PyTree]] = None,
squeeze: bool = True,
**diffeqsolve_kwargs,
) -> tuple[Array, Array]:
"""Solve initial value problem for state and output trajectories."""
Expand Down Expand Up @@ -90,7 +82,7 @@ def __call__(
stepsize_controller=self.step,
saveat=dfx.SaveAt(ts=t),
max_steps=50 * len(t),
adjoint=DefaultAdjoint(),
adjoint=dfx.DirectAdjoint(),
dt0=self.dt0 if self.dt0 is not None else t[1],
)
diffeqsolve_default_options |= diffeqsolve_kwargs
Expand All @@ -108,11 +100,6 @@ def __call__(
# Compute output
y = jax.vmap(self.system.output)(x, u, t)

# Remove singleton dimensions
if squeeze:
x = x.squeeze()
y = y.squeeze()

return x, y


Expand All @@ -127,7 +114,6 @@ def __call__(
t: Optional[Array] = None,
u: Optional[Array] = None,
num_steps: Optional[int] = None,
squeeze: bool = True,
):
"""Solve discrete map."""
x0 = jnp.asarray(x0)
Expand Down Expand Up @@ -164,8 +150,5 @@ def scan_fun(state, input):

# Compute output
y = jax.vmap(self.system.output)(x)
if squeeze:
# Remove singleton dimensions
x = x.squeeze()
y = y.squeeze()

return x, y
10 changes: 5 additions & 5 deletions dynax/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def input_output_linearize(
def feedbacklaw(x: Array, z: Array, v: float) -> float:
y_reldeg_ref = cAn.dot(z) + cAnm1b * v
y_reldeg = Lfnh(x)
return ((y_reldeg_ref - y_reldeg) / LgLfnm1h(x)).squeeze()
return (y_reldeg_ref - y_reldeg) / LgLfnm1h(x)

else:
msg = f"asymptotic must be of length {reldeg=} but, {len(asymptotic)=}"
Expand All @@ -113,16 +113,16 @@ def feedbacklaw(x: Array, z: Array, v: float) -> float:
y_reldeg = Lfnh(x)
ae0s = jnp.array(
[
ai * (Lfih(x) - cAi.dot(z))
ai * (cAi.dot(z) - Lfih(x))
for ai, Lfih, cAi in zip(alphas, Lfihs, cAis)
]
)
error = (y_reldeg_ref - y_reldeg - jnp.sum(ae0s))
error = (y_reldeg_ref - y_reldeg + jnp.sum(ae0s))
if reg is None:
return (error / LgLfnm1h(x)).squeeze()
return error / LgLfnm1h(x)
else:
l = LgLfnm1h(x)
return (error * l / (l + reg)).squeeze()
return error * l / (l + reg)

return feedbacklaw

Expand Down
6 changes: 4 additions & 2 deletions tests/test_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,18 @@ def test_discrete_forward_model():
B = jnp.array([[0], [1]])
C = jnp.array([[1, 0]])
D = jnp.zeros((1, 1))

# test just input
sys = LinearSystem(A, B, C, D)
model = Map(sys)
x, y = model(x0, u=u) # ours
scipy_sys = dlti(A, B, C, D)
_, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0)
npt.assert_allclose(scipy_y[:, 0], y, **tols)
npt.assert_allclose(scipy_y, y, **tols)
npt.assert_allclose(scipy_x, x, **tols)

# test input and time (results should be same)
x, y = model(x0, u=u, t=t)
scipy_t, scipy_y, scipy_x = dlsim(scipy_sys, u, x0=x0, t=t)
npt.assert_allclose(scipy_y[:, 0], y, **tols)
npt.assert_allclose(scipy_y, y, **tols)
npt.assert_allclose(scipy_x, x, **tols)

0 comments on commit 1c8d855

Please sign in to comment.