Skip to content

Commit

Permalink
Fixed FutureDeprecationWarning arising from hashing of a tracer. This…
Browse files Browse the repository at this point in the history
… was hashing of the value i in PLSBase._step_3
  • Loading branch information
Sm00thix committed Jun 27, 2024
1 parent ee3de3e commit 9cf6cc7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ikpls/jax_ikpls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def _step_1(self, X: jax.Array, Y: jax.Array):

@partial(jax.jit, static_argnums=(0, 2, 3))
def _step_2(
self, XTY: jax.Array, M: jax.Array, K: jax.Array
self, XTY: jax.Array, M: int, K: int
) -> Tuple[jax.Array, DTypeLike]:
"""
The second step of the PLS algorithm. Computes the next weight vector and the
Expand Down Expand Up @@ -516,7 +516,7 @@ def _step_2(
norm = eig_vals[-1]
return w, norm

@partial(jax.jit, static_argnums=(0, 1))
@partial(jax.jit, static_argnums=(0))
def _step_3(self, i: int, w: jax.Array, P: jax.Array, R: jax.Array) -> jax.Array:
"""
The third step of the PLS algorithm. Computes the orthogonal weight vectors.
Expand Down

0 comments on commit 9cf6cc7

Please sign in to comment.