Added a bunch of doc strings. Also added a verbose argument to the JAX implementations to allow controlling when to print that functions will be JIT compiled. Also made some minor other changes.
Sm00thix committed Nov 2, 2023
1 parent 72af6cc commit 00b258b
Showing 5 changed files with 618 additions and 274 deletions.
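The headline change is the new `verbose` flag. A minimal sketch of its intended effect (the array shapes and data are arbitrary; only the class name and module path are taken from this commit):

import jax
import jax.numpy as jnp
import numpy as np
from algorithms.jax_ikpls_alg_1 import PLS

jax.config.update("jax_enable_x64", True)  # the docstrings recommend at least float64

X = jnp.asarray(np.random.uniform(size=(100, 10)))  # (N, K) predictors
Y = jnp.asarray(np.random.uniform(size=(100, 2)))   # (N, M) responses

pls = PLS(verbose=True)
pls.fit(X, Y, A=5)            # first call: sub-functions announce that they will be JIT compiled
pls.fit(X, Y, A=5)            # same input shapes: no retracing, so no further messages
pls.fit(X[:50], Y[:50], A=5)  # new shapes trigger recompilation, which verbose makes visible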
212 changes: 151 additions & 61 deletions algorithms/jax_ikpls_alg_1.py
@@ -8,33 +8,47 @@

class PLS(PLSBase):
"""
Implements partial least-squares regression using Improved Kernel PLS by Dayal and MacGregor: https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23
Description
-----------
Implements partial least-squares regression using Improved Kernel PLS Algorithm #1 by Dayal and MacGregor: https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23.
Parameters:
differentiable: Bool. Whether to make the implementation end-to-end differentiable. The differentiable version is slightly slower. Results among the two versions are identical. Defaults to False
Parameters
----------
`reverse_differentiable` : bool, optional (default=False). Whether to make the implementation end-to-end differentiable. The differentiable version is slightly slower. Results from the two versions are identical.
`verbose` : bool, optional (default=False). If True, each sub-function prints a message when it is about to be JIT compiled. This is useful for tracking whether recompilation is triggered by passing inputs with different shapes.
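Examples
--------
A hedged sketch of taking gradients through the differentiable version (`X`, `Y`, and the loss are illustrative placeholders, not part of this commit):
>>> pls = PLS(reverse_differentiable=True)
>>> loss_fn = lambda X, Y: jnp.sum(pls.stateless_fit(X, Y, 2)[0] ** 2)
>>> dX = jax.grad(loss_fn, argnums=0)(X, Y)  # gradients of the loss w.r.t. X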
"""

def __init__(self, differentiable: bool = False) -> None:
super().__init__(differentiable=differentiable)
def __init__(
self, reverse_differentiable: bool = False, verbose: bool = False
) -> None:
name = "Improved Kernel PLS Algorithm #1"
super().__init__(
name=name, reverse_differentiable=reverse_differentiable, verbose=verbose
)

@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
def _get_initial_matrices(
self, A: int, K: int, M: int, N: int
) -> Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
print("Tracing initial matrices...")
if self.verbose:
print(f"_get_initial_matrices for {self.name} will be JIT compiled...")
B, W, P, Q, R = super()._get_initial_matrices(A, K, M)
T = jnp.empty(shape=(A, N), dtype=jnp.float64)
return B, W, P, Q, R, T

@partial(jax.jit, static_argnums=(0,))
def _step_1(self, X: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
print("Tracing step 1...")
if self.verbose:
print(f"_step_1 for {self.name} will be JIT compiled...")
return self._compute_initial_XTY(X.T, Y)

@partial(jax.jit, static_argnums=(0,))
def _step_4(self, X: jnp.ndarray, XTY: jnp.ndarray, r: jnp.ndarray):
print("Tracing step 4...")
if self.verbose:
print(f"_step_4 for {self.name} will be JIT compiled...")
t = X @ r
tT = t.T
tTt = tT @ t
@@ -57,10 +71,11 @@ def _main_loop_body(
) -> Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
print("Tracing loop body...")
if self.verbose:
print(f"_main_loop_body for {self.name} will be JIT compiled...")
# step 2
w, norm = self._step_2(XTY, M, K)
host_callback.id_tap(self.weight_warning, [i, norm])
host_callback.id_tap(self._weight_warning, [i, norm])
# step 3
if differentiable:
r = self._step_3(A, w, P, R)
@@ -73,62 +88,137 @@
return XTY, w, p, q, r, t

def fit(self, X: jnp.ndarray, Y: jnp.ndarray, A: int) -> None:
self.B, _W, _P, _Q, _R, _T = self.stateless_fit(X, Y, A)
self.W = _W.T
self.P = _P.T
self.Q = _Q.T
self.R = _R.T
self.T = _T.T

"""
Description
-----------
Fits Improved Kernel PLS Algorithm #1 on `X` and `Y` using `A` components.
Parameters
----------
`X` : Array of shape (N, K)
Predictor variables. The precision should be at least float64 for reliable results.
`Y` : Array of shape (N, M)
Response variables. The precision should be at least float64 for reliable results.
`A` : int
Number of components in the PLS model.
Assigns
-------
`self.B` : Array of shape (A, K, M)
PLS regression coefficients tensor.
`self.W` : Array of shape (K, A)
PLS weights matrix for X.
`self.P` : Array of shape (K, A)
PLS loadings matrix for X.
`self.Q` : Array of shape (M, A)
PLS loadings matrix for Y.
`self.R` : Array of shape (K, A)
PLS weights matrix to compute scores T directly from original X.
`self.T` : Array of shape (N, A)
PLS scores matrix of X.
Returns
-------
`None`.
Warns
-----
`UserWarning`.
If at any point during iteration over the number of components `A`, the residual goes below machine precision for jnp.float64.
See Also
--------
`stateless_fit` : Performs the same operation but returns the output matrices instead of storing them in the class instance.
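Examples
--------
A hedged sketch (`X` of shape (N, K) and `Y` of shape (N, M) are illustrative placeholders):
>>> pls = PLS()
>>> pls.fit(X, Y, A=5)
>>> pls.B.shape == (5, X.shape[1], Y.shape[1])  # (A, K, M)
True
>>> pls.T.shape == (X.shape[0], 5)  # (N, A)
True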
"""
self.B, W, P, Q, R, T = self.stateless_fit(X, Y, A)
self.W = W.T
self.P = P.T
self.Q = Q.T
self.R = R.T
self.T = T.T

@partial(jax.jit, static_argnums=(0, 3))
def stateless_fit(
self, X: jnp.ndarray, Y: jnp.ndarray, A: int
) -> Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
"""
Parameters:
X: Predictor variables matrix (N x K)
Y: Response variables matrix (N x M)
A: Number of components in the PLS model
Contains:
B: PLS regression coefficients matrix (A x K x M)
W: PLS weights matrix for X (K x A)
P: PLS loadings matrix for X (K x A)
Q: PLS Loadings matrix for Y (M x A)
R: PLS weights matrix to compute scores T directly from original X (K x A)
T: PLS scores matrix of X (N x A)
Parameters
----------
`X` : Array of shape (N, K)
Predictor variables. The precision should be at least float64 for reliable results.
`Y` : Array of shape (N, M)
Response variables. The precision should be at least float64 for reliable results.
`A` : int
Number of components in the PLS model.
Returns
-------
`B` : Array of shape (A, K, M)
PLS regression coefficients tensor.
`W` : Array of shape (A, K)
PLS weights matrix for X.
`P` : Array of shape (A, K)
PLS loadings matrix for X.
`Q` : Array of shape (A, M)
PLS loadings matrix for Y.
`R` : Array of shape (A, K)
PLS weights matrix to compute scores T directly from original X.
`T` : Array of shape (A, N)
PLS scores matrix of X.
Warns
-----
`UserWarning`.
If at any point during iteration over the number of components `A`, the residual goes below machine precision for jnp.float64.
See Also
--------
`fit` : Performs the same operation but stores the output matrices in the class instance instead of returning them.
Notes
-----
For optimization purposes, the internal representation of all matrices (except B) is transposed from the usual representation.
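Examples
--------
A hedged sketch of the transposed layout described in the Notes (`X` and `Y` are illustrative placeholders):
>>> pls = PLS()
>>> B, W, P, Q, R, T = pls.stateless_fit(X, Y, 5)
>>> W.shape == (5, X.shape[1])  # (A, K): transposed relative to `self.W`, which is (K, A)
True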
"""

@partial(jax.jit, static_argnums=(2, 3))
def helper(
X: jnp.ndarray, Y: jnp.ndarray, A: int, differentiable: bool
) -> Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
# Get shapes
N, K = X.shape
M = Y.shape[1]

# Initialize matrices
B, W, P, Q, R, T = self._get_initial_matrices(A, K, M, N)

# step 1
XTY = self._step_1(X, Y)

for i in range(A):
XTY, w, p, q, r, t = self._main_loop_body(
A, i, X, XTY, M, K, P, R, differentiable
)
W = W.at[i].set(w.squeeze())
P = P.at[i].set(p.squeeze())
Q = Q.at[i].set(q.squeeze())
R = R.at[i].set(r.squeeze())
T = T.at[i].set(t.squeeze())
b = self.compute_regression_coefficients(B[i - 1], r, q)
B = B.at[i].set(b)

return B, W, P, Q, R, T

return helper(X=X, Y=Y, A=A, differentiable=self.differentiable)
if self.verbose:
print(f"stateless_fit for {self.name} will be JIT compiled...")

# Get shapes
N, K = X.shape
M = Y.shape[1]

# Initialize matrices
B, W, P, Q, R, T = self._get_initial_matrices(A, K, M, N)

# step 1
XTY = self._step_1(X, Y)

for i in range(A):
XTY, w, p, q, r, t = self._main_loop_body(
A, i, X, XTY, M, K, P, R, self.reverse_differentiable
)
W = W.at[i].set(w.squeeze())
P = P.at[i].set(p.squeeze())
Q = Q.at[i].set(q.squeeze())
R = R.at[i].set(r.squeeze())
T = T.at[i].set(t.squeeze())
b = self._compute_regression_coefficients(B[i - 1], r, q)
B = B.at[i].set(b)

return B, W, P, Q, R, T
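For context, a hedged sketch of how the returned coefficient tensor is conventionally applied for prediction (this assumes appropriately centered data and follows from the documented shapes, not from anything in this diff):

Y_hat_all = X @ pls.B[-1]  # predictions from all A components, shape (N, M)
Y_hat_two = X @ pls.B[1]   # predictions from only the first two components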