Skip to content

Commit

Permalink
Major restructuring of tests. Also implemented low weight norm warnin…
Browse files Browse the repository at this point in the history
…g for JAX algorithms
  • Loading branch information
Sm00thix committed Oct 25, 2023
1 parent ae2c95c commit 5837812
Show file tree
Hide file tree
Showing 5 changed files with 517 additions and 622 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- uses: actions/cache@v3
with:
path: ${{ matrix.path }}
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -57,4 +50,4 @@ jobs:
- name: Test with pytest
run: |
pip install pytest pytest-cov
pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=algorithms/ --cov-report=xml --cov-report=html
python${{ matrix.python-version }} -m pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=algorithms/ --cov-report=xml --cov-report=html
4 changes: 3 additions & 1 deletion algorithms/jax_ikpls_alg_1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from algorithms.jax_ikpls_base import PLSBase
import jax
from jax.experimental import host_callback
import jax.numpy as jnp
from functools import partial
from typing import Tuple
Expand Down Expand Up @@ -53,7 +54,8 @@ def _main_loop_body(
]:
print("Tracing loop body...")
# step 2
w = self._step_2(XTY, M, K)
w, norm = self._step_2(XTY, M, K)
host_callback.id_tap(self.weight_warning, [i, norm])
# step 3
r = self._step_3(i, w, P, R)
# step 4
Expand Down
4 changes: 3 additions & 1 deletion algorithms/jax_ikpls_alg_2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from algorithms.jax_ikpls_base import PLSBase
import jax
from jax.experimental import host_callback
import jax.numpy as jnp
from functools import partial
from typing import Tuple
Expand Down Expand Up @@ -53,7 +54,8 @@ def _main_loop_body(
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
print("Tracing loop body...")
# step 2
w = self._step_2(XTY, M, K)
w, norm = self._step_2(XTY, M, K)
host_callback.id_tap(self.weight_warning, [i, norm])
# step 3
r = self._step_3(i, w, P, R)
# step 4
Expand Down
18 changes: 15 additions & 3 deletions algorithms/jax_ikpls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Tuple, Callable, Union, Any
from tqdm import tqdm
import numpy as np
import warnings


class PLSBase(abc.ABC):
Expand All @@ -16,6 +17,13 @@ class PLSBase(abc.ABC):
def __init__(self) -> None:
self.name = "PLS"

def weight_warning(self, arg, _transforms):
i, norm = arg
if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
warnings.warn(
f"Weight is close to zero. Results with A = {i} component(s) or higher may be unstable."
)

@partial(jax.jit, static_argnums=0)
def compute_regression_coefficients(
self, b_last: jnp.ndarray, r: jnp.ndarray, q: jnp.ndarray
Expand Down Expand Up @@ -64,7 +72,9 @@ def _step_1(self):
pass

@partial(jax.jit, static_argnums=(0, 2, 3))
def _step_2(self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
def _step_2(
self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.float64]:
print("Tracing step 2...")
if M == 1:
norm = jla.norm(XTY)
Expand All @@ -76,12 +86,14 @@ def _step_2(self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarr
q = eig_vecs[:, -1:]
q = q.reshape(-1, 1)
w = XTY @ q
w = w / jla.norm(w)
norm = jla.norm(w)
w = w / norm
else:
XTYYTX = XTY @ XTY.T
eig_vals, eig_vecs = jla.eigh(XTYYTX)
w = eig_vecs[:, -1:]
return w
norm = eig_vals[-1]
return w, norm

@partial(jax.jit, static_argnums=(0,))
def _step_3(
Expand Down
Loading

0 comments on commit 5837812

Please sign in to comment.