Skip to content

Commit

Permalink
Dev (#32)
Browse files Browse the repository at this point in the history
* Fixed FutureDeprecationWarning arising from hashing of a tracer. This was hashing of the value i in PLSBase._step_3

* Updated workflows to run tests on latest macos instead of macos-12.

* Bumped ikpls version number

* Updated workflows to fall back to macos-12. Also fixed error for reverse mode differentiable JAX.
  • Loading branch information
Sm00thix authored Jun 28, 2024
1 parent ee3de3e commit d5a5fb6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ikpls/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.2.post3"
__version__ = "1.2.3"
47 changes: 44 additions & 3 deletions ikpls/jax_ikpls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def _step_1(self, X: jax.Array, Y: jax.Array):

@partial(jax.jit, static_argnums=(0, 2, 3))
def _step_2(
self, XTY: jax.Array, M: jax.Array, K: jax.Array
self, XTY: jax.Array, M: int, K: int
) -> Tuple[jax.Array, DTypeLike]:
"""
The second step of the PLS algorithm. Computes the next weight vector and the
Expand Down Expand Up @@ -516,8 +516,7 @@ def _step_2(
norm = eig_vals[-1]
return w, norm

@partial(jax.jit, static_argnums=(0, 1))
def _step_3(self, i: int, w: jax.Array, P: jax.Array, R: jax.Array) -> jax.Array:
def _step_3_base(self, i: int, w: jax.Array, P: jax.Array, R: jax.Array) -> jax.Array:
"""
The third step of the PLS algorithm. Computes the orthogonal weight vectors.
Expand Down Expand Up @@ -553,6 +552,48 @@ def _step_3(self, i: int, w: jax.Array, P: jax.Array, R: jax.Array) -> jax.Array
r, P, w, R = jax.lax.fori_loop(0, i, self._step_3_body, (r, P, w, R))
return r

def _step_3(self, i: int, w: jax.Array, P: jax.Array, R: jax.Array) -> jax.Array:
"""
This is an API to the third step of the PLS algorithm. Computes the orthogonal
weight vectors.
Parameters
----------
i : int
The current component number in the PLS algorithm.
w : Array of shape (K, 1)
The current weight vector.
P : Array of shape (A, K)
The loadings matrix for the predictor variables.
R : Array of shape (A, K)
The weights matrix to compute scores `T` directly from the original
predictor variables.
Returns
-------
r : Array of shape (K, 1)
The orthogonal weight vector for the current component.
Notes
-----
This method compiles _step_3_base which in turn computes the orthogonal weight
vector `r` for the current component in the PLS algorithm. It is a key step for
calculating the loadings and weights matrices.
See Also
--------
_step_3_base : The third step of the PLS algorithm.
"""
if self.reverse_differentiable:
jax.jit(self._step_3_base, static_argnums=(0, 1))
return self._step_3_base(i, w, P, R)
else:
jax.jit(self._step_3_base, static_argnums=(0))
return self._step_3_base(i, w, P, R)

@partial(jax.jit, static_argnums=0)
def _step_3_body(
self, j: int, carry: Tuple[jax.Array, jax.Array, jax.Array, jax.Array]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ikpls"
version = "1.2.2.post3"
version = "1.2.3"
description = "Improved Kernel PLS and Fast Cross-Validation."
authors = ["Sm00thix <oleemail@icloud.com>"]
maintainers = ["Sm00thix <oleemail@icloud.com>"]
Expand Down

0 comments on commit d5a5fb6

Please sign in to comment.