Skip to content

Commit

Permalink
FIX: truncated DMD; dimensions in DMDwC + ADD: reconstruct full A and B
Browse files Browse the repository at this point in the history
  • Loading branch information
MarekWadinger committed Mar 13, 2024
1 parent 19456d3 commit c5b97f3
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 59 deletions.
170 changes: 124 additions & 46 deletions river/decomposition/odmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ def modes(self) -> np.ndarray:
"""Reconstruct high dimensional DMD modes"""
_, Phi = self.eig
if self.r < self.m:
return self._svd._U @ np.diag(self._svd._S) @ Phi
# Sign of eigenvectors and singular vectors may change based on underlying algorithm initialization
# TODO: verify sign of singular values
return np.abs(self._svd._U) @ np.diag(self._svd._S) @ np.abs(Phi)
else:
return Phi
return np.abs(Phi)

@property
def xi(self) -> np.ndarray:
Expand Down Expand Up @@ -237,14 +239,25 @@ def _truncate_w_svd(
):
U_prev = self._svd._U
if svd_modify == "update":
self._svd.update(x.reshape(1, -1))
self._svd.update(x)
elif svd_modify == "revert":
self._svd.revert(x.reshape(1, -1))
self._svd.revert(x)
_U = self._svd._U
_UU = _U.T @ U_prev
x = _U.T @ x
y = _U.T @ y
self.A = _UU @ self.A @ _UU.T
x = x @ _U
# p != self.m and p == self.A.shape[0] in case of DMDwC
p = self.A.shape[0]
y = y @ _U[: y.shape[1], :p]
# Check if A is square
if self.A.shape[0] == self.A.shape[1]:
self.A = _UU @ self.A @ _UU.T
# If A is not square, it is called by DMDwC
else:
_UUp = _UU[:p, :p]
_UUq = _UU[p:, p:]
self.A = np.hstack(
(_UUp @ self.A[:, :p] @ _UUp.T, _UUp @ self.A[:, p:] @ _UUq.T)
)
self._P = np.linalg.inv(_UU @ np.linalg.inv(self._P) @ _UU.T) / self.w

return x, y
Expand Down Expand Up @@ -329,7 +342,6 @@ def update(
epsilon = 1e-15
alpha = 1.0 / epsilon
self._P = alpha * np.identity(self.r)

if self.r < self.m:
x_, y_ = self._truncate_w_svd(x_, y_, svd_modify="update")

Expand Down Expand Up @@ -383,12 +395,12 @@ def revert(
if isinstance(y, dict):
y = np.array(list(y.values()))
if len(y.shape) == 1:
y = y.reshape(1, -1)
y_ = y.reshape(1, -1)
else:
y_ = y

if self.r < self.m:
x_, y_ = self._truncate_w_svd(x_, y_, svd_modify=None)
x_, y_ = self._truncate_w_svd(x_, y_, svd_modify="revert")

# Apply exponential weighting factor
if self.exponential_weighting:
Expand Down Expand Up @@ -469,7 +481,8 @@ def learn_many(
self.m = X.shape[1]
if self.r == 0:
self.r = self.m
assert n >= self.m and np.linalg.matrix_rank(X) == self.m

assert np.linalg.matrix_rank(X) >= self.m
# Exponential weighting factor - older snapshots are weighted less
if self.exponential_weighting:
weights = (np.sqrt(self.w) ** np.arange(n - 1, -1, -1))[
Expand All @@ -478,14 +491,33 @@ def learn_many(
else:
weights = np.ones((n, 1))
Xqhat, Yqhat = weights * X, weights * Y
# Perform truncated DMD
if self.r < self.m:
self._svd.learn_many(Xqhat)
_U, _S, _V = self._svd._U, self._svd._S, self._svd._V
self.A = _U.T @ Yqhat.T @ _V.T @ np.diag(1 / _S)

_m = Yqhat.shape[1]
_l = self.m - _m

# DMDwC, A = U.T @ K @ U; B = U.T @ K [Proctor (2016)]
if _l != 0:
_UU = _U.T @ np.row_stack([_U[:_m], np.eye(_l, self.r)])
# DMD, A = U.T @ K @ U
else:
_UU = np.eye(self.r)

# TODO: Verify if equivalent to Proctor (2016). They compute U_hat from SVD(Y), we select the first r columns of U
self.A = (
_U.T[:, : Yqhat.shape[1]]
@ Yqhat.T
@ _V.T
@ np.diag(1 / _S)
) @ _UU
self._P = np.linalg.inv(_U.T @ Xqhat.T @ Xqhat @ _U) / self.w
# Perform exact DMD
else:
self.A = Yqhat.T.dot(np.linalg.pinv(Xqhat.T))
self._P = np.linalg.inv(Xqhat.T.dot(Xqhat)) / self.w
self._P = np.linalg.inv(Xqhat.T @ Xqhat) / self.w

# Store the last p snapshots for xi computation
self._Y = Yqhat
Expand All @@ -507,10 +539,15 @@ def predict_one(self, x: dict | np.ndarray) -> np.ndarray:
Returns:
np.ndarray: The predicted next state.
"""
# Map A back to original space
if self.r < self.m:
A = self._svd._U @ self.A @ self._svd._U.T
else:
A = self.A
mat = np.zeros((2, self.m))
mat[0, :] = x if isinstance(x, np.ndarray) else list(x.values())
for s in range(1, 2):
mat[s, :] = (self.A @ mat[s - 1, :]).real
mat[s, :] = (A @ mat[s - 1, :]).real
return mat[-1, :]

def predict_many(self, x: dict | np.ndarray, forecast: int) -> np.ndarray:
Expand Down Expand Up @@ -606,7 +643,8 @@ class OnlineDMDwC(OnlineDMD):
Args:
B: control matrix, size n by m. If None, the control matrix will be
identified from the snapshots. Defaults to None.
r: number of modes to keep. If 0 (default), all modes are kept.
p: truncation of states. If 0 (default), compute exact DMD.
q: truncation of control. If 0 (default), compute exact DMD.
w: weighting factor in (0,1]. Smaller value allows more adpative
learning, but too small weighting may result in model identification
instability (relies only on limited recent snapshots).
Expand All @@ -619,7 +657,7 @@ class OnlineDMDwC(OnlineDMD):
seed: random seed for reproducibility (initialize A with random values)
Attributes:
m: state dimension x(t) as in z(t) = f(z(t-1)) or y(t) = f(t, x(t))
m: augumented state dimension. if B is None, m = x.shape[1], else m = x.shape[1] + u.shape[1]
n_seen: number of seen samples (read-only), reverted if windowed
A: DMD matrix, size n by n
_P: inverse of covariance matrix of X
Expand All @@ -645,7 +683,7 @@ class OnlineDMDwC(OnlineDMD):
>>> df = pd.DataFrame({"w1": w1[:-1], "w2": w2[:-1]})
>>> U = pd.DataFrame({"u": u_[:-2]})
>>> model = OnlineDMDwC(r=2, w=0.1, initialize=0)
>>> model = OnlineDMDwC(p=2, q=1, w=0.1, initialize=4)
>>> X, Y = df.iloc[:-1], df.shift(-1).iloc[:-1]
>>> for (_, x), (_, y), (_, u) in zip(X.iterrows(), Y.iterrows(), U.iterrows()):
Expand All @@ -661,7 +699,7 @@ class OnlineDMDwC(OnlineDMD):
Supports mini-batch learning:
>>> from river.utils import Rolling
>>> model = Rolling(OnlineDMDwC(r=2, w=1.0), 10)
>>> model = Rolling(OnlineDMDwC(p=2, q=1, w=1.0), 10)
>>> X, Y = df.iloc[:-1], df.shift(-1).iloc[:-1]
>>> for (_, x), (_, y), (_, u) in zip(X.iterrows(), Y.iterrows(), U.iterrows()):
Expand Down Expand Up @@ -707,23 +745,45 @@ class OnlineDMDwC(OnlineDMD):
def __init__(
self,
B: np.ndarray | None = None,
r: int = 0,
p: int = 0,
q: int = 0, # TODO: fix case when q is 0
w: float = 1.0,
initialize: int = 1,
exponential_weighting: bool = False,
seed: int | None = None,
) -> None:
super().__init__(
r,
p + q,
w,
initialize,
exponential_weighting,
seed,
)
self.p = p
self.q = q
self.B = B
self.known_B = B is not None
self.l: int

def _reconstruct_AB(self):
# self.m stores augumented state dimension
_m = self.m - self.l if not self.known_B else self.m
if self.r < self.m:
A = (
self._svd._U[:_m, : self.p]
@ self.A
@ self._svd._U[:_m, : self.p].T
)
B = (
self._svd._U[:_m, : self.p]
@ self.B
@ self._svd._U[-self.q :, -self.l :]
)
else:
A = self.A
B = self.B
return A, B

def _update_many(
self,
X: np.ndarray | pd.DataFrame,
Expand Down Expand Up @@ -787,26 +847,33 @@ def learn_many( # type: ignore # TODO: fix override OnlineDMD.learn_many
Y = Y - self.B @ U
else:
X = np.hstack((X, U))
if not self.known_B and self.B is not None:
self.A = np.hstack((self.A, self.B))
if self.B is not None: # If learn_many is not called first
self.A = np.hstack((self.A, self.B))

self.l = U.shape[1]
super().learn_many(X, Y)
self.m = self.m - self.l # PATCH: overwrite change of parent

if not self.known_B:
self.B = self.A[:, -self.l :]
self.A = self.A[:, : -self.l]
self.B = self.A[: self.p, -self.l :]
self.A = self.A[: self.p, : -self.l]

def _init_update(self):
if not self.known_B and self.initialize < self.m + self.l:
def _init_update(self) -> None:
if self.initialize < self.m:
warnings.warn(
f"Initialization is under-constrained. Changed initialize to {self.m + self.l}."
f"Initialization is under-constrained. Changed initialize to {self.m}."
)
self.initialize = self.m + self.l
# TODO: find out whether should be set in init or here
self.B = np.random.randn(self.m, self.l)
self.initialize = self.m
if self.p == 0:
self.p = self.m
if self.q == 0:
self.q = self.l

self.A = np.random.randn(self.p, self.p)
self.B = np.random.randn(self.p, self.q)
self._U_init = np.zeros((self.initialize, self.l))
super()._init_update()
self._X_init = np.empty((self.initialize, self.m - self.l))
self._Y_init = np.empty((self.initialize, self.m - self.l))
self._Y = np.empty((0, self.m - self.l))

def update( # type: ignore # TODO: fix override OnlineDMD.update
self,
Expand Down Expand Up @@ -836,16 +903,19 @@ def update( # type: ignore # TODO: fix override OnlineDMD.update
super().update(x, y)
else:
if self.n_seen == 0:
self.m = len(x)
self.m = len(x) if self.known_B else len(x) + len(u)
self.l = len(u)
self._init_update()

if bool(self.initialize) and self.n_seen <= self.initialize - 1:
if self.initialize and self.n_seen <= self.initialize - 1:
# Accumulate buffer of past snapshots for initialization
self._X_init[self.n_seen, :] = x
self._Y_init[self.n_seen, :] = y
self._U_init[self.n_seen, :] = u
# Run the initialization after collecting enough snapshots
if self.n_seen == self.initialize - 1:
self.learn_many(self._X_init, self._Y_init, self._U_init)
# Subtract the number of seen samples to avoid doubling
self.n_seen -= self._X_init.shape[1]

else:
Expand All @@ -858,9 +928,10 @@ def update( # type: ignore # TODO: fix override OnlineDMD.update

super().update(x, y)

if not self.known_B:
self.B = self.A[:, -self.l :]
self.A = self.A[:, : -self.l]
# In case that learn_many was called, A is already square
if self.A.shape[0] < self.A.shape[1]:
self.B = self.A[: self.p, -self.q :]
self.A = self.A[: self.p, : -self.q]

self.n_seen += 1

Expand Down Expand Up @@ -905,8 +976,8 @@ def revert( # type: ignore # TODO: fix override OnlineDMD.revert
super().revert(x, y)

if not self.known_B:
self.B = self.A[:, -self.l :]
self.A = self.A[:, : -self.l]
self.B = self.A[: self.p, -self.l :]
self.A = self.A[: self.p, : -self.l]

def predict_one( # type: ignore # TODO: fix override OnlineDMD.predict_one
self, x: dict | np.ndarray, u: dict | np.ndarray
Expand All @@ -923,12 +994,15 @@ def predict_one( # type: ignore # TODO: fix override OnlineDMD.predict_one
"""
if isinstance(u, dict):
u = np.array(list(u.values()))
_m = len(x)
A, B = self._reconstruct_AB()

mat = np.zeros((2, self.m))
mat = np.zeros((2, _m))
mat[0, :] = x if isinstance(x, np.ndarray) else list(x.values())
for s in range(1, 2):
action = (self.B @ u).real
mat[s, :] = (self.A @ mat[s - 1, :]).real + action
action = (B @ u).real
# TODO: map A back to original space
mat[s, :] = (A @ mat[s - 1, :]).real + action
return mat[-1, :]

def predict_many( # type: ignore # TODO: fix override OnlineDMD.predict_many
Expand All @@ -953,12 +1027,14 @@ def predict_many( # type: ignore # TODO: fix override OnlineDMD.predict_many
"""
if isinstance(U, pd.DataFrame):
U = U.values
_m = len(x)
A, B = self._reconstruct_AB()

mat = np.zeros((forecast + 1, self.m))
mat = np.zeros((forecast + 1, _m))
mat[0, :] = x if isinstance(x, np.ndarray) else list(x.values())
for s in range(1, forecast + 1):
action = (self.B @ U[s - 1, :]).real
mat[s, :] = (self.A @ mat[s - 1, :]).real + action
action = (B @ U[s - 1, :]).real
mat[s, :] = (A @ mat[s - 1, :]).real + action
return mat[1:, :]

def truncation_error( # type: ignore # TODO: fix override OnlineDMD.truncation_error
Expand All @@ -977,5 +1053,7 @@ def truncation_error( # type: ignore # TODO: fix override OnlineDMD.truncation
Returns:
float: Truncation error of the DMD model
"""
Y_hat = self.A @ X.T + self.B @ U.T

A, B = self._reconstruct_AB()
Y_hat = A @ X.T + B @ U.T
return float(np.linalg.norm(Y - Y_hat.T) / np.linalg.norm(Y))
Loading

0 comments on commit c5b97f3

Please sign in to comment.