From 19456d38de1738d1e44b19e9a91802097e3d0527 Mon Sep 17 00:00:00 2001
From: marekwadinger
Date: Mon, 11 Mar 2024 08:43:14 +0900
Subject: [PATCH] UPDATE: DMD truncation + ADD: module's init; SVD sorting

---
 river/decomposition/__init__.py |  15 +++
 river/decomposition/odmd.py     | 197 ++++++++++++++++++++++----------
 river/decomposition/osvd.py     |  63 +++++++---
 3 files changed, 200 insertions(+), 75 deletions(-)
 create mode 100644 river/decomposition/__init__.py

diff --git a/river/decomposition/__init__.py b/river/decomposition/__init__.py
new file mode 100644
index 0000000000..e1c9752d62
--- /dev/null
+++ b/river/decomposition/__init__.py
@@ -0,0 +1,15 @@
+"""Decomposition.
+
+"""
+from __future__ import annotations
+
+from .odmd import OnlineDMD, OnlineDMDwC
+from .opca import OnlinePCA
+from .osvd import OnlineSVD
+
+__all__ = [
+    "OnlineSVD",
+    "OnlineDMD",
+    "OnlineDMDwC",
+    "OnlinePCA",
+]
diff --git a/river/decomposition/odmd.py b/river/decomposition/odmd.py
index 6d476f9774..eebb85bb7f 100644
--- a/river/decomposition/odmd.py
+++ b/river/decomposition/odmd.py
@@ -23,12 +23,15 @@
 from __future__ import annotations
 
 import warnings
+from typing import Literal
 
 import numpy as np
 import pandas as pd
 import scipy as sp
 from scipy.sparse.linalg._eigen.arpack.arpack import ArpackNoConvergence
 
+from .osvd import OnlineSVD
+
 __all__ = [
     "OnlineDMD",
     "OnlineDMDwC",
@@ -152,6 +155,8 @@ def __init__(
         seed: int | None = None,
     ) -> None:
         self.r = int(r)
+        if self.r != 0:
+            self._svd = OnlineSVD(n_components=self.r, force_orth=True)
         self.w = float(w)
         assert self.w > 0 and self.w <= 1
         self.initialize = int(initialize)
@@ -170,25 +175,26 @@ def __init__(
 
     @property
     def eig(self) -> tuple[np.ndarray, np.ndarray]:
         """Compute and return DMD eigenvalues and DMD modes at current step"""
+        # TODO: need to check if SVD is initialized in case r < m. Otherwise, transformation will fail.
         try:
-            Lambda, Phi = sp.sparse.linalg.eigs(self.A, k=self.r)
-        except ArpackNoConvergence:
+            Lambda, Phi = sp.linalg.eig(self.A, check_finite=False)
+        except (ArpackNoConvergence, sp.linalg.LinAlgError):
+            # dense eig may fail to converge; fall back to the Schur form
             Lambda, Phi = sp.linalg.schur(self.A, check_finite=False)
-        if self.r:
-            Lambda, Phi = Lambda[: self.r], Phi[:, : self.r]
+        # Sort the spectrum in descending order; np.sort orders complex
+        # values lexicographically, whereas builtin sorted() would raise a
+        # TypeError on complex eigenvalues.
+        if not np.array_equal(Lambda, np.sort(Lambda)[::-1]):
+            sort_idx = np.argsort(Lambda)[::-1]
+            Lambda = Lambda[sort_idx]
+            Phi = Phi[:, sort_idx]
         return Lambda, Phi
 
-    def _init_update(self) -> None:
-        if self.initialize > 0 and self.initialize < self.m:
-            warnings.warn(
-                f"Initialization is under-constrained. Set initialize={self.m} to supress this Warning."
-            )
-            self.initialize = self.m
-
-        self.A = np.random.randn(self.m, self.m)
-        self._X_init = np.empty((self.initialize, self.m))
-        self._Y_init = np.empty((self.initialize, self.m))
-        self._Y = np.empty((0, self.m))
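+    # When r < m, A acts on rank-r projections of the snapshots, so its
+    # eigenvectors Phi live in the reduced space; `modes` lifts them back
+    # to the full m-dimensional space through the online SVD factors.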
+    @property
+    def modes(self) -> np.ndarray:
+        """Reconstruct high dimensional DMD modes"""
+        _, Phi = self.eig
+        if self.r < self.m:
+            return self._svd._U @ np.diag(self._svd._S) @ Phi
+        else:
+            return Phi
 
     @property
     def xi(self) -> np.ndarray:
@@ -202,13 +208,62 @@ def xi(self) -> np.ndarray:
 
         def objective_function(x):
             return np.linalg.norm(
-                self._Y.T - Phi @ np.diag(x) @ C, "fro"
+                self._Y[:, : self.r].T - Phi @ np.diag(x) @ C, "fro"
             ) + 0.5 * np.linalg.norm(x, 1)
 
         # Minimize the objective function
-        xi = minimize(objective_function, np.ones(self.m)).x
+        xi = minimize(objective_function, np.ones(self.r)).x
         return xi
 
+    def _init_update(self) -> None:
+        if self.initialize > 0 and self.initialize < self.m:
+            warnings.warn(
+                f"Initialization is under-constrained. Set initialize={self.m} to suppress this Warning."
+            )
+            self.initialize = self.m
+        if self.r == 0:
+            self.r = self.m
+
+        self.A = np.random.randn(self.r, self.r)
+        self._X_init = np.empty((self.initialize, self.m))
+        self._Y_init = np.empty((self.initialize, self.m))
+        self._Y = np.empty((0, self.m))
+
+    def _truncate_w_svd(
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        svd_modify: Literal["update", "revert"] | None = None,
+    ):
+        U_prev = self._svd._U
+        if svd_modify == "update":
+            self._svd.update(x.reshape(1, -1))
+        elif svd_modify == "revert":
+            self._svd.revert(x.reshape(1, -1))
+        _U = self._svd._U
+        _UU = _U.T @ U_prev
+        # Project the row snapshots (shape (1, m)) onto the current left
+        # singular vectors to obtain their rank-r representations (1, r).
+        x = x @ _U
+        y = y @ _U
+        self.A = _UU @ self.A @ _UU.T
+        self._P = np.linalg.inv(_UU @ np.linalg.inv(self._P) @ _UU.T) / self.w
+
+        return x, y
+
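+    # Weighted recursive least-squares step shared by update, revert and
+    # _update_many (Zhang 2019): with Gamma = (W + X P X^T)^{-1}, where W
+    # is the inverse weight matrix (1.0 for a single unweighted pair,
+    # negative when reverting), A <- A + (Y^T - A X^T) Gamma (P X^T)^T
+    # and P <- (P - (P X^T) Gamma (P X^T)^T) / w.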
+    def _update_A_P(
+        self, X: np.ndarray, Y: np.ndarray, W: float | np.ndarray
+    ) -> None:
+        Xt = X.T
+        AX = self.A.dot(Xt)
+        PX = self._P.dot(Xt)
+        PXt = PX.T
+        Gamma = np.linalg.inv(W + X.dot(PX))
+        # update A on new data
+        self.A += (Y.T - AX).dot(Gamma).dot(PXt)
+        # update P, group Px*Px' to ensure positive definite
+        self._P = (self._P - PX.dot(Gamma).dot(PXt)) / self.w
+        # ensure P is SPD by taking its symmetric part
+        self._P = (self._P + self._P.T) / 2
+
     def update(
         self,
         x: dict | np.ndarray,
@@ -237,42 +292,50 @@ def update(
         if isinstance(x, dict):
             self.feature_names_in_ = list(x.keys())
             x = np.array(list(x.values()))
+        if len(x.shape) == 1:
+            x_ = x.reshape(1, -1)
+        else:
+            x_ = x
         if isinstance(y, dict):
             assert self.feature_names_in_ == list(y.keys())
             y = np.array(list(y.values()))
+        if len(y.shape) == 1:
+            y_ = y.reshape(1, -1)
+        else:
+            y_ = y
 
         # Initialize properties which depend on the shape of x
         if self.n_seen == 0:
             self.m = len(x)
             self._init_update()
+
+        # Collect buffer of past snapshots to compute xi
+        if self._Y.shape[0] <= self.n_seen:
+            self._Y = np.vstack([self._Y, y_])
+        elif self._Y.shape[0] > self.n_seen:
+            self._Y = self._Y[self.n_seen :, :]
+
+        # Initialize A and P with first self.initialize snapshot pairs
         if bool(self.initialize) and self.n_seen <= self.initialize - 1:
-            self._X_init[self.n_seen, :] = x
-            self._Y_init[self.n_seen, :] = y
+            self._X_init[self.n_seen, :] = x_
+            self._Y_init[self.n_seen, :] = y_
             if self.n_seen == self.initialize - 1:
                 self.learn_many(self._X_init, self._Y_init)
                 # revert the number of seen samples to avoid doubling
                 self.n_seen -= self._X_init.shape[0]
+        # Update incrementally if initialized
         else:
             if self.n_seen == 0:
                 epsilon = 1e-15
                 alpha = 1.0 / epsilon
-                self._P = alpha * np.identity(self.m)  # inverse of cov(X)
-            # compute P*x matrix vector product beforehand
-            Px = self._P.dot(x)
-            # compute gamma
-            gamma = 1.0 / (1.0 + x.dot(Px))
-            # update A
-            self.A += np.outer(gamma * (y - self.A.dot(x)), Px)
-            # update P, group Px*Px' to ensure positive definite
-            self._P = (self._P - gamma * np.outer(Px, Px)) / self.w
-            # ensure P is SPD by taking its symmetric part
-            self._P = (self._P + self._P.T) / 2
+                self._P = alpha * np.identity(self.r)
+
+            if self.r < self.m:
+                x_, y_ = self._truncate_w_svd(x_, y_, svd_modify="update")
+
+            self._update_A_P(x_, y_, 1.0)
 
         self.n_seen += 1
-        if self._Y.shape[0] < self.n_seen:
-            self._Y = np.vstack([self._Y, y])
-        elif self._Y.shape[0] > self.n_seen:
-            self._Y = self._Y[self.n_seen :, :]
 
     def learn_one(
         self,
@@ -292,8 +355,8 @@ def revert(
         Compatible with Rolling and TimeRolling wrappers.
 
         Args:
-            x: 1D array, shape (m, ), x(t) as in y(t) = f(t, x(t))
-            y: 1D array, shape (m, ), y(t) as in y(t) = f(t, x(t))
+            x: array, shape (1, m), x(t) as in y(t) = f(t, x(t))
+            y: array, shape (1, m), y(t) as in y(t) = f(t, x(t))
         """
         if self.n_seen < self.initialize:
             raise RuntimeError(
@@ -313,24 +376,28 @@ def revert(
 
         if isinstance(x, dict):
             x = np.array(list(x.values()))
+        if len(x.shape) == 1:
+            x_ = x.reshape(1, -1)
+        else:
+            x_ = x
         if isinstance(y, dict):
             y = np.array(list(y.values()))
+        if len(y.shape) == 1:
+            y_ = y.reshape(1, -1)
+        else:
+            y_ = y
+
+        if self.r < self.m:
+            x_, y_ = self._truncate_w_svd(x_, y_, svd_modify=None)
 
-        # compute P*x matrix vector product beforehand
         # Apply exponential weighting factor
         if self.exponential_weighting:
            weight = 1.0 / -(self.w**self.n_seen)
         else:
            weight = -1.0
-        Px = self._P.dot(x)
-        gamma = 1.0 / (weight + x.dot(Px))
-        # update A
-        Ax = self.A.dot(x)
-        self.A += np.outer(gamma * (y - Ax), Px)
-        # update P, group Px*Px' to ensure positive definite
-        self._P = (self._P - gamma * np.outer(Px, Px)) / self.w
-        # ensure P is SPD by taking its symmetric part
-        self._P = (self._P + self._P.T) / 2
+
+        self._update_A_P(x_, y_, weight)
+
         self.n_seen -= 1
 
     def _update_many(
@@ -359,16 +426,14 @@ def _update_many(
             weights = np.sqrt(self.w) ** np.arange(p - 1, -1, -1)
         else:
             weights = np.ones(p)
-        C = np.diag(weights)
+        # Zhang (2019): Gamma = (C^{-1} + X P X^T)^{-1}
+        C_inv = np.diag(np.reciprocal(weights))
 
-        Xt = X.T
-        AX = self.A.dot(Xt)
-        PX = self._P.dot(Xt)
-        PXt = PX.T
-        Gamma = np.linalg.inv(np.linalg.inv(C) + X.dot(PX))
-        self.A += (Y.T - AX).dot(Gamma).dot(PXt)
-        self._P = (self._P - PX.dot(Gamma).dot(PXt)) / self.w
-        self._P = (self._P + self._P.T) / 2
+        if isinstance(X, pd.DataFrame):
+            X = X.values
+        if isinstance(Y, pd.DataFrame):
+            Y = Y.values
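+        # Passing C_inv as W gives Gamma = (C^{-1} + X P X^T)^{-1}, the
+        # weighted mini-batch form of the update; for p = 1 and w = 1 it
+        # reduces to the scalar W = 1.0 used in `update`.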
+        self._update_A_P(X, Y, C_inv)
 
     def learn_many(
         self,
@@ -402,6 +467,8 @@ def learn_many(
         # Initialize A and P with first p snapshot pairs
         if not hasattr(self, "_P"):
             self.m = X.shape[1]
+            if self.r == 0:
+                self.r = self.m
             assert n >= self.m and np.linalg.matrix_rank(X) == self.m
             # Exponential weighting factor - older snapshots are weighted less
             if self.exponential_weighting:
@@ -411,11 +478,19 @@ def learn_many(
             else:
                 weights = np.ones((n, 1))
             Xqhat, Yqhat = weights * X, weights * Y
-            self.A = Yqhat.T.dot(np.linalg.pinv(Xqhat.T))
-            self._P = np.linalg.inv(Xqhat.T.dot(Xqhat)) / self.w
+            if self.r < self.m:
+                self._svd.learn_many(Xqhat)
+                _U, _S, _V = self._svd._U, self._svd._S, self._svd._V
+                self.A = _U.T @ Yqhat.T @ _V.T @ np.diag(1 / _S)
+                self._P = np.linalg.inv(_U.T @ Xqhat.T @ Xqhat @ _U) / self.w
+            else:
+                self.A = Yqhat.T.dot(np.linalg.pinv(Xqhat.T))
+                self._P = np.linalg.inv(Xqhat.T.dot(Xqhat)) / self.w
+
+            # Store the last p snapshots for xi computation
+            self._Y = Yqhat
             self.n_seen += n
             self.initialize = 0
-            self._Y = Y
         # Update incrementally if initialized
         # Zhang (2019): "single rank-s update is roughly the same as applying
         # the rank-1 formula s times"
@@ -486,8 +561,8 @@ def transform_one(self, x: dict | np.ndarray) -> np.ndarray:
         if isinstance(x, dict):
             x = np.array(list(x.values()))
 
-        _, Phi = self.eig
-        return Phi.T @ x
+        M = self.modes
+        return x @ M
 
     def transform_many(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
         """
@@ -502,8 +577,8 @@ def transform_many(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
         if isinstance(X, pd.DataFrame):
             X = X.values
 
-        _, Phi = self.eig
-        return Phi.T @ X
+        M = self.modes
+        return X @ M
 
 
 class OnlineDMDwC(OnlineDMD):
diff --git a/river/decomposition/osvd.py b/river/decomposition/osvd.py
index 27cdbdc69c..0257395263 100644
--- a/river/decomposition/osvd.py
+++ b/river/decomposition/osvd.py
@@ -56,7 +56,7 @@ class OnlineSVD(MiniBatchTransformer):
         force_orth: If True, the algorithm will force the singular vectors to be orthogonal. *Note*: Significantly increases the computational cost.
 
     Attributes:
-        n_components_: Desired dimensionality of output data.
+        n_components: Desired dimensionality of output data.
         initialize: Number of initial samples to use for the initialization of the algorithm. The value must be greater than `n_components`.
         feature_names_in_: List of input features.
         _U: Left singular vectors.
@@ -85,7 +85,6 @@ class OnlineSVD(MiniBatchTransformer):
 
     >>> svd.revert(X.iloc[-1].values.reshape(1, -1))
 
-    TODO: fix revert method - following test should pass
     >>> svd.transform_one(X.iloc[0].to_dict())
     {0: 2.3492, 1: 0.03840}
 
@@ -108,12 +107,13 @@ def __init__(
         initialize: int = 0,
         force_orth: bool = False,
     ):
-        self.n_components_ = n_components
+        self.n_components = n_components
         if initialize <= n_components:
             self.initialize = n_components + 1
         else:
             self.initialize = initialize
-        self.force_orth_ = force_orth
+        self.force_orth = force_orth
+        self.n_features_in_: int
 
         self.feature_names_in_: list
         self._U: np.ndarray
@@ -124,24 +124,51 @@ def _orthogonalize(self, U_, Sigma_, V_):
         UQ, UR = np.linalg.qr(U_, mode="complete")
         VQ, VR = np.linalg.qr(V_, mode="complete")
         tU_, tSigma_, tV_ = sp.sparse.linalg.svds(
-            (UR @ np.diag(Sigma_) @ VR), k=2
+            (UR @ np.diag(Sigma_) @ VR), k=self.n_components
         )
+        tU_, tSigma_, tV_ = self._sort_svd(tU_, tSigma_, tV_)
         return UQ @ tU_, tSigma_, VQ @ tV_
 
+    def _sort_svd(self, U, S, V):
+        """Sort the singular value decomposition in descending order.
+
+        As sparse SVD does not guarantee the order of the singular values, we
+        need to sort the singular value decomposition in descending order.
+        """
+        if not np.array_equal(S, sorted(S, reverse=True)):
+            sort_idx = np.argsort(S)[::-1]
+            S = S[sort_idx]
+            U = U[:, sort_idx]
+            V = V[sort_idx, :]
+        return U, S, V
+
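+    # Example: sparse svds may return S = [0.3, 2.1]; _sort_svd reorders it
+    # to [2.1, 0.3] and permutes the columns of U and the rows of V to match.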
+    def _truncate_svd(self):
+        """Truncate the singular value decomposition to `n_components`.
+
+        Full SVD returns the complete matrices U, S, and V already in the
+        correct order. When the full SVD is faster to obtain than the sparse
+        SVD, we combine its results with truncation to the leading components.
+        """
+        self._U = self._U[:, : self.n_components]
+        self._S = self._S[: self.n_components]
+        self._V = self._V[: self.n_components, :]
+
     def update(self, x: dict | np.ndarray):
         if isinstance(x, dict):
             self.feature_names_in_ = list(x.keys())
             x = np.array(list(x.values()))
+        x = x.reshape(1, -1)
 
         m = (x @ self._U).T
         p = x.T - self._U @ m
         P, _ = np.linalg.qr(p)
         Ra = P.T @ p
 
         z = np.zeros_like(m.T)
         K = np.block([[np.diag(self._S), m], [z, Ra]])
-        U_, Sigma_, V_ = sp.sparse.linalg.svds(K, k=self.n_components_)
+        U_, Sigma_, V_ = sp.sparse.linalg.svds(K, k=self.n_components)
+        U_, Sigma_, V_ = self._sort_svd(U_, Sigma_, V_)
 
         U_ = np.column_stack((self._U, P)) @ U_
-        V_ = V_[:, :2] @ self._V
-        if self.force_orth_ and not test_orthonormality(V_.T):
+        V_ = V_[:, : self.n_components] @ self._V
+        if self.force_orth and not test_orthonormality(V_.T):
             U_, Sigma_, V_ = self._orthogonalize(U_, Sigma_, V_)
 
         self._U, self._S, self._V = U_, Sigma_, V_
@@ -160,11 +187,12 @@ def revert(self, _: dict | np.ndarray):
             - np.row_stack((np.diag(self._S) @ n, 0.0))
             @ np.row_stack((n, np.sqrt(1 - n.T @ n))).T
         )
-        U_, Sigma_, V_ = sp.sparse.linalg.svds(K, k=2)
-        U_ = self._U @ U_[:2, :]
+        U_, Sigma_, V_ = sp.sparse.linalg.svds(K, k=self.n_components)
+        U_, Sigma_, V_ = self._sort_svd(U_, Sigma_, V_)
+        U_ = self._U @ U_[: self.n_components, :]
         V_ = V_ @ np.row_stack((self._V, Q.T))
 
-        if self.force_orth_ and not test_orthonormality(U_):
+        if self.force_orth:  # and not test_orthonormality(U_):
             U_, Sigma_, V_ = self._orthogonalize(U_, Sigma_, V_)
 
         self._U, self._S, self._V = U_, Sigma_, V_
@@ -184,9 +212,16 @@ def learn_many(self, X: np.ndarray | pd.DataFrame):
             for x in X:
                 self.learn_one(x.reshape(1, -1))
         else:
-            self._U, self._S, self._V = sp.sparse.linalg.svds(
-                X.T, k=self.n_components_
-            )
+            if self.n_components < self.n_features_in_:
+                self._U, self._S, self._V = sp.sparse.linalg.svds(
+                    X.T, k=self.n_components
+                )
+                self._U, self._S, self._V = self._sort_svd(
+                    self._U, self._S, self._V
+                )
+            else:
+                self._U, self._S, self._V = np.linalg.svd(
+                    X.T, full_matrices=False
+                )
 
     def transform_one(self, x: dict | np.ndarray) -> dict:
         if isinstance(x, dict):