From 5837812deae64076675a7b48e02df498dae7278c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ole=20Engstr=C3=B8m?=
Date: Wed, 25 Oct 2023 02:47:10 +0200
Subject: [PATCH] Major restructuring of tests. Also implemented low weight norm warning for JAX algorithms

---
 .github/workflows/python-package.yml |    9 +-
 algorithms/jax_ikpls_alg_1.py        |    4 +-
 algorithms/jax_ikpls_alg_2.py        |    4 +-
 algorithms/jax_ikpls_base.py         |   18 +-
 tests/test_consistency.py            | 1104 ++++++++++++--------
 5 files changed, 517 insertions(+), 622 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 5f48736..e6e7ec2 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -35,13 +35,6 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
 
-    - uses: actions/cache@v3
-      with:
-        path: ${{ matrix.path }}
-        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-        restore-keys: |
-          ${{ runner.os }}-pip-
-
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
@@ -57,4 +50,4 @@ jobs:
     - name: Test with pytest
       run: |
         pip install pytest pytest-cov
-        pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=algorithms/ --cov-report=xml --cov-report=html
\ No newline at end of file
+        python${{ matrix.python-version }} -m pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=algorithms/ --cov-report=xml --cov-report=html
\ No newline at end of file
diff --git a/algorithms/jax_ikpls_alg_1.py b/algorithms/jax_ikpls_alg_1.py
index 2e75872..a65bd5a 100644
--- a/algorithms/jax_ikpls_alg_1.py
+++ b/algorithms/jax_ikpls_alg_1.py
@@ -1,5 +1,6 @@
 from algorithms.jax_ikpls_base import PLSBase
 import jax
+from jax.experimental import host_callback
 import jax.numpy as jnp
 from functools import partial
 from typing import Tuple
@@ -53,7 +54,8 @@ def _main_loop_body(
     ]:
         print("Tracing loop body...")
         # step 2
-        w = self._step_2(XTY, M, K)
+        w, norm = self._step_2(XTY, M, K)
+        host_callback.id_tap(self.weight_warning, [i, norm])
         # step 3
         r = self._step_3(i, w, P, R)
         # step 4
diff --git a/algorithms/jax_ikpls_alg_2.py b/algorithms/jax_ikpls_alg_2.py
index 5e0b94a..207e6df 100644
--- a/algorithms/jax_ikpls_alg_2.py
+++ b/algorithms/jax_ikpls_alg_2.py
@@ -1,5 +1,6 @@
 from algorithms.jax_ikpls_base import PLSBase
 import jax
+from jax.experimental import host_callback
 import jax.numpy as jnp
 from functools import partial
 from typing import Tuple
@@ -53,7 +54,8 @@ def _main_loop_body(
     ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
         print("Tracing loop body...")
         # step 2
-        w = self._step_2(XTY, M, K)
+        w, norm = self._step_2(XTY, M, K)
+        host_callback.id_tap(self.weight_warning, [i, norm])
         # step 3
         r = self._step_3(i, w, P, R)
         # step 4
diff --git a/algorithms/jax_ikpls_base.py b/algorithms/jax_ikpls_base.py
index 66bb6a0..b029442 100644
--- a/algorithms/jax_ikpls_base.py
+++ b/algorithms/jax_ikpls_base.py
@@ -6,6 +6,7 @@
 from typing import Tuple, Callable, Union, Any
 from tqdm import tqdm
 import numpy as np
+import warnings
 
 
 class PLSBase(abc.ABC):
@@ -16,6 +17,13 @@ class PLSBase(abc.ABC):
     def __init__(self) -> None:
         self.name = "PLS"
 
+    def weight_warning(self, arg, _transforms):
+        i, norm = arg
+        if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+            warnings.warn(
+                f"Weight is close to zero. Results with A = {i} component(s) or higher may be unstable."
+ ) + @partial(jax.jit, static_argnums=0) def compute_regression_coefficients( self, b_last: jnp.ndarray, r: jnp.ndarray, q: jnp.ndarray @@ -64,7 +72,9 @@ def _step_1(self): pass @partial(jax.jit, static_argnums=(0, 2, 3)) - def _step_2(self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray: + def _step_2( + self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.float64]: print("Tracing step 2...") if M == 1: norm = jla.norm(XTY) @@ -76,12 +86,14 @@ def _step_2(self, XTY: jnp.ndarray, M: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarr q = eig_vecs[:, -1:] q = q.reshape(-1, 1) w = XTY @ q - w = w / jla.norm(w) + norm = jla.norm(w) + w = w / norm else: XTYYTX = XTY @ XTY.T eig_vals, eig_vecs = jla.eigh(XTYYTX) w = eig_vecs[:, -1:] - return w + norm = eig_vals[-1] + return w, norm @partial(jax.jit, static_argnums=(0,)) def _step_3( diff --git a/tests/test_consistency.py b/tests/test_consistency.py index de1f370..0958ee0 100644 --- a/tests/test_consistency.py +++ b/tests/test_consistency.py @@ -4,6 +4,7 @@ from algorithms.numpy_ikpls import PLS as NpPLS # import load_data + from . import load_data import pytest @@ -26,9 +27,19 @@ def load_Y(self, values: list[str]) -> npt.NDArray[np.float_]: return target_values def fit_models(self, X, Y, n_components): + x_mean = X.mean(axis=0) + X -= x_mean + y_mean = Y.mean(axis=0) + Y -= y_mean + x_std = X.std(axis=0, ddof=1) + x_std[x_std == 0.0] = 1.0 + X /= x_std + y_std = Y.std(axis=0, ddof=1) + y_std[y_std == 0.0] = 1.0 + Y /= y_std jnp_X = jnp.array(X) jnp_Y = jnp.array(Y) - sk_pls = SkPLS(n_components=n_components) + sk_pls = SkPLS(n_components=n_components, scale=False) # Do not rescale again. np_pls_alg_1 = NpPLS(algorithm=1) np_pls_alg_2 = NpPLS(algorithm=2) jax_pls_alg_1 = JAX_Alg_1() @@ -105,27 +116,94 @@ def check_x_loadings( ): if n_good_components == -1: n_good_components = np_pls_alg_1.A - assert_allclose( - np.abs(np_pls_alg_1.P[..., :n_good_components]), - np.abs(sk_pls.x_loadings_[..., :n_good_components]), + # assert_allclose( + # np.abs(np_pls_alg_1.P[..., :n_good_components]), + # np.abs(sk_pls.x_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np_pls_alg_2.P[..., :n_good_components]), + # np.abs(sk_pls.x_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np.array(jax_pls_alg_1.P)[..., :n_good_components]), + # np.abs(sk_pls.x_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np.array(jax_pls_alg_2.P)[..., :n_good_components]), + # np.abs(sk_pls.x_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + + # We have rotational freedom. Therefore, check that loadings are parallel or antiparallel to eachother. + # We do this by taking the dot product between the normalized loadings of two different implementations and assert that they are either -1 or 1. 
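The comparison described in the comment above can be shown in miniature. PLS loadings are only determined up to a per-component sign, so two correct implementations may return any column negated; the tests therefore compare the cosine between matching columns against ±1 instead of comparing matrices elementwise. The following standalone sketch uses made-up data and is not part of the patch:

    import numpy as np
    from numpy import linalg as la
    from numpy.testing import assert_allclose

    rng = np.random.default_rng(0)
    P_ref = rng.normal(size=(50, 5))               # reference loadings, one column per component
    flips = np.array([1.0, -1.0, 1.0, -1.0, 1.0])  # arbitrary per-component sign flips
    P_other = P_ref * flips                        # same loadings up to sign

    # Cosine between matching columns: +1 if parallel, -1 if antiparallel.
    cos = np.sum(P_ref * P_other, axis=0) / (
        la.norm(P_ref, axis=0) * la.norm(P_other, axis=0)
    )
    assert_allclose(np.abs(cos), 1, atol=1e-8, rtol=1e-5)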
+ assert_allclose( + np.abs( + np.sum( + np_pls_alg_1.P[..., :n_good_components] + * sk_pls.x_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_1.P[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np_pls_alg_2.P[..., :n_good_components]), - np.abs(sk_pls.x_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np_pls_alg_2.P[..., :n_good_components] + * sk_pls.x_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_2.P[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_1.P)[..., :n_good_components]), - np.abs(sk_pls.x_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_1.P[..., :n_good_components]) + * sk_pls.x_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_1.P[..., :n_good_components]), axis=0) + * la.norm(sk_pls.x_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_2.P)[..., :n_good_components]), - np.abs(sk_pls.x_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_2.P[..., :n_good_components]) + * sk_pls.x_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_2.P[..., :n_good_components]), axis=0) + * la.norm(sk_pls.x_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) @@ -143,30 +221,96 @@ def check_y_loadings( ): if n_good_components == -1: n_good_components = np_pls_alg_1.A - assert_allclose( - np.abs(np_pls_alg_1.Q[..., :n_good_components]), - np.abs(sk_pls.y_loadings_[..., :n_good_components]), + # We have rotational freedom. Therefore, check that loadings are parallel or antiparallel to eachother. + # We do this by taking the dot product between the normalized loadings of two different implementations and assert that they are either -1 or 1. 
+ assert_allclose( + np.abs( + np.sum( + np_pls_alg_1.Q[..., :n_good_components] + * sk_pls.y_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_1.Q[..., :n_good_components], axis=0) + * la.norm(sk_pls.y_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np_pls_alg_2.Q[..., :n_good_components]), - np.abs(sk_pls.y_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np_pls_alg_2.Q[..., :n_good_components] + * sk_pls.y_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_2.Q[..., :n_good_components], axis=0) + * la.norm(sk_pls.y_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_1.Q)[..., :n_good_components]), - np.abs(sk_pls.y_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_1.Q[..., :n_good_components]) + * sk_pls.y_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_1.Q[..., :n_good_components]), axis=0) + * la.norm(sk_pls.y_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_2.Q)[..., :n_good_components]), - np.abs(sk_pls.y_loadings_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_2.Q[..., :n_good_components]) + * sk_pls.y_loadings_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_2.Q[..., :n_good_components]), axis=0) + * la.norm(sk_pls.y_loadings_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) + # assert_allclose( + # np.abs(np_pls_alg_1.Q[..., :n_good_components]), + # np.abs(sk_pls.y_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np_pls_alg_2.Q[..., :n_good_components]), + # np.abs(sk_pls.y_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np.array(jax_pls_alg_1.Q)[..., :n_good_components]), + # np.abs(sk_pls.y_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np.array(jax_pls_alg_2.Q)[..., :n_good_components]), + # np.abs(sk_pls.y_loadings_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) def check_x_rotations( self, @@ -181,27 +325,70 @@ def check_x_rotations( ): if n_good_components == -1: n_good_components = np_pls_alg_1.A - assert_allclose( - np.abs(np_pls_alg_1.R[..., :n_good_components]), - np.abs(sk_pls.x_rotations_[..., :n_good_components]), + + # We have rotational freedom. Therefore, check that rotations are parallel or antiparallel to eachother. + # We do this by taking the dot product between the normalized rotations of two different implementations and assert that they are either -1 or 1. 
+ assert_allclose( + np.abs( + np.sum( + np_pls_alg_1.R[..., :n_good_components] + * sk_pls.x_rotations_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_1.R[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_rotations_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np_pls_alg_2.R[..., :n_good_components]), - np.abs(sk_pls.x_rotations_[..., :n_good_components]), + np.abs( + np.sum( + np_pls_alg_2.R[..., :n_good_components] + * sk_pls.x_rotations_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_2.R[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_rotations_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_1.R[..., :n_good_components])), - np.abs(sk_pls.x_rotations_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_1.R[..., :n_good_components]) + * sk_pls.x_rotations_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_1.R[..., :n_good_components]), axis=0) + * la.norm(sk_pls.x_rotations_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_2.R[..., :n_good_components])), - np.abs(sk_pls.x_rotations_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_2.R[..., :n_good_components]) + * sk_pls.x_rotations_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_2.R[..., :n_good_components]), axis=0) + * la.norm(sk_pls.x_rotations_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) @@ -217,18 +404,52 @@ def check_x_scores( # X scores - not computed by IKPLS Algorithm #2 ): if n_good_components == -1: n_good_components = np_pls_alg_1.A - assert_allclose( - np.abs(np_pls_alg_1.T[..., :n_good_components]), - np.abs(sk_pls.x_scores_[..., :n_good_components]), + # We have rotational freedom. Therefore, check that scores are parallel or antiparallel to eachother. + # We do this by taking the dot product between the normalized scores of two different implementations and assert that they are either -1 or 1. + assert_allclose( + np.abs( + np.sum( + np_pls_alg_1.T[..., :n_good_components] + * sk_pls.x_scores_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np_pls_alg_1.T[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_scores_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) assert_allclose( - np.abs(np.array(jax_pls_alg_1.T[..., :n_good_components])), - np.abs(sk_pls.x_scores_[..., :n_good_components]), + np.abs( + np.sum( + np.array(jax_pls_alg_1.T)[..., :n_good_components] + * sk_pls.x_scores_[..., :n_good_components], + axis=0, + ) + / ( + la.norm(np.array(jax_pls_alg_1.T)[..., :n_good_components], axis=0) + * la.norm(sk_pls.x_scores_[..., :n_good_components], axis=0) + ) + ), + 1, atol=atol, rtol=rtol, ) + # assert_allclose( + # np.abs(np_pls_alg_1.T[..., :n_good_components]), + # np.abs(sk_pls.x_scores_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) + # assert_allclose( + # np.abs(np.array(jax_pls_alg_1.T[..., :n_good_components])), + # np.abs(sk_pls.x_scores_[..., :n_good_components]), + # atol=atol, + # rtol=rtol, + # ) def check_regression_matrices( self, @@ -284,6 +505,13 @@ def check_predictions( n_good_components = np_pls_alg_1.A # Check predictions for each and all possible number of components. 
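Many of the tolerance changes in this file swap pure absolute tolerances for relative ones, and the diagnostic added just below prints the largest absolute and relative deviations between the NumPy and scikit-learn predictions. For reference, numpy.testing.assert_allclose passes when |actual - desired| <= atol + rtol * |desired|. A small sketch with made-up numbers, not part of the patch:

    import numpy as np
    from numpy.testing import assert_allclose

    desired = np.array([1.0, 1e-3])
    actual = desired + np.array([5e-6, 5e-9])  # roughly 5e-6 relative error in both entries

    assert_allclose(actual, desired, atol=0, rtol=1e-5)  # passes: error < rtol * |desired|
    assert_allclose(actual, desired, atol=1e-5, rtol=0)  # passes: error < atol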
sk_all_preds = X @ sk_B + diff = ( + np_pls_alg_1.predict(X)[:n_good_components] + - sk_all_preds[:n_good_components] + ) + max_atol = np.amax(diff) + max_rtol = np.amax(diff / np.abs(sk_all_preds[:n_good_components])) + print(f"Max atol: {max_atol}\nMax rtol:{max_rtol}") assert_allclose( np_pls_alg_1.predict(X)[:n_good_components], sk_all_preds[:n_good_components], @@ -415,15 +643,7 @@ def test_pls_1(self): Test PLS1. """ X = self.load_X() - X -= np.mean( - X, axis=0 - ) # SkLearn's PLS implementation always centers its X input. This ensures that the X input is always centered for all algorithms. - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability Y = self.load_Y(["Protein"]) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability n_components = 25 assert Y.shape[1] == 1 ( @@ -487,7 +707,7 @@ def test_pls_1(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, + atol=0, rtol=1e-5, ) @@ -495,8 +715,8 @@ def test_pls_1(self): sk_pls=sk_pls, np_pls_alg_1=np_pls_alg_1, jax_pls_alg_1=jax_pls_alg_1, - atol=1e-3, - rtol=0, + atol=1e-8, + rtol=1e-5, ) self.check_regression_matrices( @@ -517,7 +737,7 @@ def test_pls_1(self): jax_pls_alg_2=jax_pls_alg_2, X=X, atol=1e-8, - rtol=0, + rtol=1e-5, ) # PLS1 is very numerically stable for protein. def test_pls_2_m_less_k(self): @@ -525,10 +745,6 @@ def test_pls_2_m_less_k(self): Test PLS2 where the number of targets is less than the number of features (M < K). """ X = self.load_X() - X -= np.mean( - X, axis=0 - ) # SkLearn's PLS implementation always centers its X input. This ensures that the X input is always centered for all algorithms. - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability Y = self.load_Y( [ "Rye_Midsummer", @@ -543,10 +759,6 @@ def test_pls_2_m_less_k(self): "Protein", ] ) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability assert Y.shape[1] > 1 assert Y.shape[1] < X.shape[1] n_components = 25 @@ -582,7 +794,7 @@ def test_pls_2_m_less_k(self): jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, atol=2e-3, - rtol=0, + rtol=1e-5, ) self.check_x_loadings( @@ -591,8 +803,8 @@ def test_pls_2_m_less_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, + atol=1e-8, + rtol=2e-5, ) self.check_y_loadings( @@ -601,8 +813,8 @@ def test_pls_2_m_less_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=0.14, - rtol=0, + atol=1e-8, + rtol=2e-5, ) self.check_x_rotations( @@ -611,16 +823,16 @@ def test_pls_2_m_less_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, + atol=0, + rtol=2e-5, ) self.check_x_scores( sk_pls=sk_pls, np_pls_alg_1=np_pls_alg_1, jax_pls_alg_1=jax_pls_alg_1, - atol=2e-3, - rtol=0, + atol=1e-8, + rtol=2e-5, ) self.check_regression_matrices( @@ -649,10 +861,6 @@ def test_pls_2_m_eq_k(self): """ X = self.load_X() X = X[..., :10] - X -= np.mean( - X, axis=0 - ) # SkLearn's PLS implementation always centers its X input. 
This ensures that the X input is always centered for all algorithms. - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability Y = self.load_Y( [ "Rye_Midsummer", @@ -667,10 +875,6 @@ def test_pls_2_m_eq_k(self): "Protein", ] ) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] n_components = 10 @@ -715,8 +919,8 @@ def test_pls_2_m_eq_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, + atol=1e-8, + rtol=1e-5, ) self.check_y_loadings( @@ -725,8 +929,8 @@ def test_pls_2_m_eq_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=0.23, - rtol=0, + atol=1e-8, + rtol=1e-5, ) self.check_x_rotations( @@ -735,16 +939,16 @@ def test_pls_2_m_eq_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, + atol=0, + rtol=1e-5, ) self.check_x_scores( sk_pls=sk_pls, np_pls_alg_1=np_pls_alg_1, jax_pls_alg_1=jax_pls_alg_1, - atol=3e-4, - rtol=0, + atol=1e-8, + rtol=1e-5, ) self.check_regression_matrices( @@ -753,8 +957,8 @@ def test_pls_2_m_eq_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=0.11, - rtol=0, + atol=1e-8, + rtol=0.1, ) self.check_predictions( sk_B=sk_B, @@ -773,10 +977,6 @@ def test_pls_2_m_greater_k(self): """ X = self.load_X() X = X[..., :9] - X -= np.mean( - X, axis=0 - ) # SkLearn's PLS implementation always centers its X input. This ensures that the X input is always centered for all algorithms. - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability Y = self.load_Y( [ "Rye_Midsummer", @@ -791,10 +991,6 @@ def test_pls_2_m_greater_k(self): "Protein", ] ) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability assert Y.shape[1] > 1 assert Y.shape[1] > X.shape[1] n_components = 9 @@ -833,122 +1029,6 @@ def test_pls_2_m_greater_k(self): rtol=0, ) - self.check_x_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, - ) - - self.check_y_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.17, - rtol=0, - ) - - self.check_x_rotations( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=2e-3, - rtol=0, - ) - - self.check_x_scores( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - atol=3e-4, - rtol=0, - ) - - self.check_regression_matrices( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.09, - rtol=0, - ) - self.check_predictions( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - X=X, - atol=2e-3, - rtol=0, - ) # PLS2 is not as numerically stable as PLS1. 
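The headline feature of this commit, the low-weight-norm warning for the JAX algorithms, relies on the host-callback pattern shown near the top of the patch: host_callback.id_tap ships concrete values out of jit-compiled code so an ordinary Python function can inspect them and warn. The snippet below is a minimal standalone sketch of that mechanism, assuming the jax.experimental.host_callback API this patch uses (newer JAX releases replace it with jax.debug.callback); function names and the warning text are illustrative only:

    import warnings

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import host_callback


    def weight_warning(arg, _transforms):
        # Runs on the host with concrete values, so it can branch and warn freely.
        i, norm = arg
        if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
            warnings.warn(f"Weight is close to zero at component {i}.")


    @jax.jit
    def normalize(i, w):
        norm = jnp.linalg.norm(w)
        # id_tap sends [i, norm] to the host without changing the traced computation.
        host_callback.id_tap(weight_warning, [i, norm])
        return w / norm


    normalize(0, jnp.zeros(3))    # zero-norm weight vector -> UserWarning on the host
    host_callback.barrier_wait()  # make sure any outstanding taps have been processed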
- - def test_early_stop_fitting_pls_1(self): - """ - The NumPy implementations will stop iterating through components if the residual comes close to 0. - """ - vectors = np.array( - [np.arange(2, 7, dtype=np.float64) ** (i + 1) for i in range(5)] - ) - X = np.tile(vectors, reps=(5000, 3)) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability - - Y = np.sum(X, axis=1, keepdims=True) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability - n_components = 10 - n_good_components = 4 # X has a rank of 4. - assert la.matrix_rank(X) == 4 - assert Y.shape[1] == 1 - ( - sk_pls, - sk_B, - np_pls_alg_1, - np_pls_alg_2, - jax_pls_alg_1, - jax_pls_alg_2, - ) = self.fit_models(X=X, Y=Y, n_components=n_components) - - self.check_equality_properties( - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - X=X, - atol=1e-1, - rtol=1e-5, - n_good_components=n_good_components, - ) - self.check_orthogonality_properties( - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-1, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_weights( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.014, - rtol=0, - n_good_components=n_good_components, - ) - self.check_x_loadings( sk_pls=sk_pls, np_pls_alg_1=np_pls_alg_1, @@ -957,7 +1037,6 @@ def test_early_stop_fitting_pls_1(self): jax_pls_alg_2=jax_pls_alg_2, atol=1e-8, rtol=1e-5, - n_good_components=n_good_components, ) self.check_y_loadings( @@ -968,7 +1047,6 @@ def test_early_stop_fitting_pls_1(self): jax_pls_alg_2=jax_pls_alg_2, atol=1e-8, rtol=1e-5, - n_good_components=n_good_components, ) self.check_x_rotations( @@ -977,264 +1055,16 @@ def test_early_stop_fitting_pls_1(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=5e-4, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_scores( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - atol=2e-4, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_regression_matrices( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - self.check_predictions( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - X=X, - atol=1e-8, - rtol=0, - n_good_components=n_good_components, - ) - - def test_early_stop_fitting_pls_2_m_less_k(self): - """ - The NumPy implementations will stop iterating through components if the residual comes close to 0. - """ - vectors = np.array( - [np.arange(2, 7, dtype=np.float64) ** (i + 1) for i in range(5)] - ) - X = np.tile(vectors, reps=(5000, 3)) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability - - Y = np.hstack( - (np.sum(X, axis=1, keepdims=True), np.sum(X, axis=1, keepdims=True) ** 2) - ) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. 
- Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability - n_components = 10 - n_good_components = 4 # X has a rank of 4. - assert la.matrix_rank(X) == 4 - assert Y.shape[1] > 1 - assert Y.shape[1] < X.shape[1] - ( - sk_pls, - sk_B, - np_pls_alg_1, - np_pls_alg_2, - jax_pls_alg_1, - jax_pls_alg_2, - ) = self.fit_models(X=X, Y=Y, n_components=n_components) - self.check_equality_properties( - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - X=X, - atol=1e-1, - rtol=1e-5, - n_good_components=n_good_components, - ) - self.check_orthogonality_properties( - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-1, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_weights( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_y_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=5e-8, + atol=0, rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_rotations( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.26, - rtol=0, - n_good_components=n_good_components, ) self.check_x_scores( sk_pls=sk_pls, np_pls_alg_1=np_pls_alg_1, jax_pls_alg_1=jax_pls_alg_1, - atol=2e-4, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_regression_matrices( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.33, - rtol=0, - n_good_components=n_good_components, - ) - self.check_predictions( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - X=X, atol=1e-8, - rtol=0, - n_good_components=n_good_components, - ) - - def test_early_stop_fitting_pls_2_m_eq_k(self): - """ - The NumPy implementations will stop iterating through components if the residual comes close to 0. - """ - vectors = np.array( - [np.arange(2, 7, dtype=np.float64) ** (i + 1) for i in range(5)] - ) - X = np.tile(vectors, reps=(5000, 2)) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability - - Y = np.hstack([np.sum(X, axis=1, keepdims=True) ** (i + 1) for i in range(10)]) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability - n_components = 10 - n_good_components = 4 # X has a rank of 4. 
- assert la.matrix_rank(X) == 4 - assert Y.shape[1] > 1 - assert Y.shape[1] == X.shape[1] - ( - sk_pls, - sk_B, - np_pls_alg_1, - np_pls_alg_2, - jax_pls_alg_1, - jax_pls_alg_2, - ) = self.fit_models(X=X, Y=Y, n_components=n_components) - self.check_equality_properties( - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - X=X, - atol=1e-1, rtol=1e-5, - n_good_components=n_good_components, - ) - self.check_orthogonality_properties( - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-1, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_weights( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_y_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=3e-7, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_rotations( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.33, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_scores( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - atol=2e-4, - rtol=0, - n_good_components=n_good_components, ) self.check_regression_matrices( @@ -1243,131 +1073,8 @@ def test_early_stop_fitting_pls_2_m_eq_k(self): np_pls_alg_2=np_pls_alg_2, jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, - atol=0.3, - rtol=0, - n_good_components=n_good_components, - ) - self.check_predictions( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - X=X, - atol=1e-8, - rtol=0, - n_good_components=n_good_components, - ) - - def test_early_stop_fitting_pls_2_m_greater_k(self): - """ - The NumPy implementations will stop iterating through components if the residual comes close to 0. - """ - vectors = np.array( - [np.arange(2, 7, dtype=np.float64) ** (i + 1) for i in range(5)] - ) - X = np.tile(vectors, reps=(5000, 3)) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) # Let's also scale them for better numerical stability - - Y = np.hstack([np.sum(X, axis=1, keepdims=True) ** (i + 1) for i in range(20)]) - Y -= np.mean( - Y, axis=0, keepdims=True - ) # SkLearn's PLS implementation always centers its Y input. This ensures that the Y input is always centered for all algorithms. - Y /= np.std(Y, axis=0, keepdims=True) # Scale for numerical stability - n_components = 10 - n_good_components = 4 # X has a rank of 4. 
- assert la.matrix_rank(X) == 4 - assert Y.shape[1] > 1 - assert Y.shape[1] > X.shape[1] - ( - sk_pls, - sk_B, - np_pls_alg_1, - np_pls_alg_2, - jax_pls_alg_1, - jax_pls_alg_2, - ) = self.fit_models(X=X, Y=Y, n_components=n_components) - self.check_equality_properties( - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - X=X, - atol=1e-1, - rtol=1e-5, - n_good_components=n_good_components, - ) - self.check_orthogonality_properties( - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-1, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_x_weights( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=1e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_y_loadings( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=6e-8, - rtol=1e-5, - n_good_components=n_good_components, - ) - - self.check_x_rotations( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.38, - rtol=0.69, - n_good_components=n_good_components, - ) - - self.check_x_scores( - sk_pls=sk_pls, - np_pls_alg_1=np_pls_alg_1, - jax_pls_alg_1=jax_pls_alg_1, - atol=2e-4, - rtol=0, - n_good_components=n_good_components, - ) - - self.check_regression_matrices( - sk_B=sk_B, - np_pls_alg_1=np_pls_alg_1, - np_pls_alg_2=np_pls_alg_2, - jax_pls_alg_1=jax_pls_alg_1, - jax_pls_alg_2=jax_pls_alg_2, - atol=0.22, - rtol=0, - n_good_components=n_good_components, + rtol=2e-2, ) self.check_predictions( sk_B=sk_B, @@ -1376,10 +1083,9 @@ def test_early_stop_fitting_pls_2_m_greater_k(self): jax_pls_alg_1=jax_pls_alg_1, jax_pls_alg_2=jax_pls_alg_2, X=X, - atol=1e-8, + atol=2e-3, rtol=0, - n_good_components=n_good_components, - ) + ) # PLS2 is not as numerically stable as PLS1. def test_sanity_check_pls_regression( self, @@ -1388,11 +1094,7 @@ def test_sanity_check_pls_regression( d = load_linnerud() X = d.data # Shape = (20,3) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) Y = d.target # Shape = (20,3) - Y -= np.mean(Y, axis=0) - Y /= np.std(Y, axis=0) n_components = X.shape[1] # 3 ( sk_pls, @@ -1569,14 +1271,8 @@ def test_sanity_check_pls_regression_constant_column_Y( d = load_linnerud() X = d.data # Shape = (20,3) - X -= np.mean(X, axis=0) - X /= np.std(X, axis=0) Y = d.target # Shape = (20,3) Y[:, 0] = 1 # Set the first column to a constant - Y -= np.mean(Y, axis=0) - Y[:, 1:] /= np.std( - Y[:, 1:], axis=0 - ) # The standard deviation of the first column is 0 by construction. Do not attempt division on that column. n_components = X.shape[1] # 3 ( sk_pls, @@ -1753,25 +1449,215 @@ def test_sanity_check_pls_regression_constant_column_Y( rtol=0, ) -# if __name__ == "__main__": -# tc = TestClass() -# tc.test_pls_1() -# tc.test_pls_2_m_less_k() -# tc.test_pls_2_m_eq_k() -# tc.test_pls_2_m_greater_k() -# tc.test_early_stop_fitting_pls_1() # Stop after 4 components. Here, own algorithms fails to stop early. Norm is constant at approx. 1e-14. -# tc.test_early_stop_fitting_pls_2_m_less_k() # Stop after 4 components. 
Here, own algorithms fails to stop early. Norm explodes. -# tc.test_early_stop_fitting_pls_2_m_eq_k() # Stop after 4 components -# tc.test_early_stop_fitting_pls_2_m_greater_k() # Fix denne. Lykkes ikke med at skaffe early stopping -# tc.test_sanity_check_pls_regression() -# tc.test_sanity_check_pls_regression_constant_column_Y() - -# TODO: Make individual tests for each of the internal matrices. (DONE) -# TODO: Add the remaining tests from SkLearn's test suite. + def test_pls_1_constant_y( + self, + ): # Taken from SkLearn's test suite and modified to include own algorithms. + """Checks warning when y is constant.""" + rng = np.random.RandomState(42) + X = rng.rand(100, 3) + Y = np.zeros(shape=(100, 1)) + n_components = 2 + + ## Taken from self.fit_models() to check each individual algorithm for early stopping. + x_mean = X.mean(axis=0) + X -= x_mean + y_mean = Y.mean(axis=0) + Y -= y_mean + x_std = X.std(axis=0, ddof=1) + x_std[x_std == 0.0] = 1.0 + X /= x_std + y_std = Y.std(axis=0, ddof=1) + y_std[y_std == 0.0] = 1.0 + Y /= y_std + jnp_X = jnp.array(X) + jnp_Y = jnp.array(Y) + sk_pls = SkPLS(n_components=n_components, scale=False) # Do not rescale again. + np_pls_alg_1 = NpPLS(algorithm=1) + np_pls_alg_2 = NpPLS(algorithm=2) + jax_pls_alg_1 = JAX_Alg_1() + jax_pls_alg_2 = JAX_Alg_2() + + assert Y.shape[1] == 1 + + sk_msg = "Y residual is constant at iteration" + with pytest.warns(UserWarning, match=sk_msg): + sk_pls.fit(X=X, Y=Y) + assert_allclose(sk_pls.x_rotations_, 0) + + msg = "Weight is close to zero." + with pytest.warns(UserWarning, match=msg): + np_pls_alg_1.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_1.R, 0) + with pytest.warns(UserWarning, match=msg): + np_pls_alg_2.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_2.R, 0) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_1.fit(X=jnp_X, Y=jnp_Y, A=n_components) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_2.fit(X=jnp_X, Y=jnp_Y, A=n_components) + + def test_pls_2_m_less_k_constant_y( + self, + ): # Taken from SkLearn's test suite and modified to include own algorithms. + """Checks warning when y is constant.""" + rng = np.random.RandomState(42) + X = rng.rand(100, 3) + Y = np.zeros(shape=(100, 2)) + n_components = 2 + + ## Taken from self.fit_models() to check each individual algorithm for early stopping. + x_mean = X.mean(axis=0) + X -= x_mean + y_mean = Y.mean(axis=0) + Y -= y_mean + x_std = X.std(axis=0, ddof=1) + x_std[x_std == 0.0] = 1.0 + X /= x_std + y_std = Y.std(axis=0, ddof=1) + y_std[y_std == 0.0] = 1.0 + Y /= y_std + jnp_X = jnp.array(X) + jnp_Y = jnp.array(Y) + sk_pls = SkPLS(n_components=n_components, scale=False) # Do not rescale again. + np_pls_alg_1 = NpPLS(algorithm=1) + np_pls_alg_2 = NpPLS(algorithm=2) + jax_pls_alg_1 = JAX_Alg_1() + jax_pls_alg_2 = JAX_Alg_2() + + assert Y.shape[1] > 1 + assert Y.shape[1] < X.shape[1] + + sk_msg = "Y residual is constant at iteration" + with pytest.warns(UserWarning, match=sk_msg): + sk_pls.fit(X=X, Y=Y) + assert_allclose(sk_pls.x_rotations_, 0) + + msg = "Weight is close to zero." 
+ with pytest.warns(UserWarning, match=msg): + np_pls_alg_1.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_1.R, 0) + with pytest.warns(UserWarning, match=msg): + np_pls_alg_2.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_2.R, 0) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_1.fit(X=jnp_X, Y=jnp_Y, A=n_components) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_2.fit(X=jnp_X, Y=jnp_Y, A=n_components) + + def test_pls_2_m_eq_k_constant_y( + self, + ): # Taken from SkLearn's test suite and modified to include own algorithms. + """Checks warning when y is constant.""" + rng = np.random.RandomState(42) + X = rng.rand(100, 3) + Y = np.zeros(shape=(100, 3)) + n_components = 2 + + ## Taken from self.fit_models() to check each individual algorithm for early stopping. + x_mean = X.mean(axis=0) + X -= x_mean + y_mean = Y.mean(axis=0) + Y -= y_mean + x_std = X.std(axis=0, ddof=1) + x_std[x_std == 0.0] = 1.0 + X /= x_std + y_std = Y.std(axis=0, ddof=1) + y_std[y_std == 0.0] = 1.0 + Y /= y_std + jnp_X = jnp.array(X) + jnp_Y = jnp.array(Y) + sk_pls = SkPLS(n_components=n_components, scale=False) # Do not rescale again. + np_pls_alg_1 = NpPLS(algorithm=1) + np_pls_alg_2 = NpPLS(algorithm=2) + jax_pls_alg_1 = JAX_Alg_1() + jax_pls_alg_2 = JAX_Alg_2() + + assert Y.shape[1] > 1 + assert Y.shape[1] == X.shape[1] + + sk_msg = "Y residual is constant at iteration" + with pytest.warns(UserWarning, match=sk_msg): + sk_pls.fit(X=X, Y=Y) + assert_allclose(sk_pls.x_rotations_, 0) + + msg = "Weight is close to zero." + with pytest.warns(UserWarning, match=msg): + np_pls_alg_1.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_1.R, 0) + with pytest.warns(UserWarning, match=msg): + np_pls_alg_2.fit(X=X, Y=Y, A=n_components) + assert_allclose(np_pls_alg_2.R, 0) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_1.fit(X=jnp_X, Y=jnp_Y, A=n_components) + with pytest.warns(UserWarning, match=msg): + jax_pls_alg_2.fit(X=jnp_X, Y=jnp_Y, A=n_components) + + def test_pls_2_m_greater_k_constant_y( + self, + ): # Taken from SkLearn's test suite and modified to include own algorithms. + """Checks warning when y is constant.""" + rng = np.random.RandomState(42) + X = rng.rand(100, 3) + Y = np.zeros(shape=(100, 4)) + n_components = 2 + + ## Taken from self.fit_models() to check each individual algorithm for early stopping. + x_mean = X.mean(axis=0) + X -= x_mean + y_mean = Y.mean(axis=0) + Y -= y_mean + x_std = X.std(axis=0, ddof=1) + x_std[x_std == 0.0] = 1.0 + X /= x_std + y_std = Y.std(axis=0, ddof=1) + y_std[y_std == 0.0] = 1.0 + Y /= y_std + jnp_X = jnp.array(X) + jnp_Y = jnp.array(Y) + sk_pls = SkPLS(n_components=n_components, scale=False) # Do not rescale again. + np_pls_alg_1 = NpPLS(algorithm=1) + np_pls_alg_2 = NpPLS(algorithm=2) + jax_pls_alg_1 = JAX_Alg_1() + jax_pls_alg_2 = JAX_Alg_2() + + assert Y.shape[1] > 1 + assert Y.shape[1] > X.shape[1] + + sk_msg = "Y residual is constant at iteration" + with pytest.warns(UserWarning, match=sk_msg): + sk_pls.fit(X=X, Y=Y) + assert_allclose(sk_pls.x_rotations_, 0) + + msg = "Weight is close to zero." 
+        with pytest.warns(UserWarning, match=msg):
+            np_pls_alg_1.fit(X=X, Y=Y, A=n_components)
+        assert_allclose(np_pls_alg_1.R, 0)
+        with pytest.warns(UserWarning, match=msg):
+            np_pls_alg_2.fit(X=X, Y=Y, A=n_components)
+        assert_allclose(np_pls_alg_2.R, 0)
+        with pytest.warns(UserWarning, match=msg):
+            jax_pls_alg_1.fit(X=jnp_X, Y=jnp_Y, A=n_components)
+        with pytest.warns(UserWarning, match=msg):
+            jax_pls_alg_2.fit(X=jnp_X, Y=jnp_Y, A=n_components)
+
+
+if __name__ == "__main__":
+    tc = TestClass()
+    tc.test_pls_1()
+    tc.test_pls_2_m_less_k()
+    tc.test_pls_2_m_eq_k()
+    tc.test_pls_2_m_greater_k()
+    tc.test_sanity_check_pls_regression()
+    tc.test_sanity_check_pls_regression_constant_column_Y()
+    tc.test_pls_1_constant_y()
+    tc.test_pls_2_m_less_k_constant_y()
+    tc.test_pls_2_m_eq_k_constant_y()
+    tc.test_pls_2_m_greater_k_constant_y()
+
 # TODO: Check that results are consistent across CPU and GPU implementations.
 # TODO: Check that cross validation results match those achieved by SkLearn.
 # TODO: Implement general purpose cross validation for GPU algorithms.
- # TODO: For this purpose, also implement general preprocessing where a user can pass a function that takes (X_train, Y_train, X_val, Y_val), peforms whatever operations and then returns processed arrays of the same type.
+# TODO: For this purpose, also implement general preprocessing where a user can pass a function that takes (X_train, Y_train, X_val, Y_val), performs whatever operations and then returns processed arrays of the same type
 # TODO: Use pytest.warns as context manager.
-# TODO: Implement constant Y test from SkLearn's test suite.
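The constant-Y tests above are what exercise the new "Weight is close to zero." warning: with centered, identically-zero targets the cross-product X^T Y vanishes, so the weight vector formed in _step_2 has zero norm. A short standalone illustration that mirrors the test setup but is not part of the test suite:

    import numpy as np

    rng = np.random.RandomState(42)
    X = rng.rand(100, 3)
    Y = np.zeros((100, 1))      # constant target, as in test_pls_1_constant_y above

    Xc = X - X.mean(axis=0)     # centering, as done in fit_models
    Yc = Y - Y.mean(axis=0)     # still exactly zero

    XTY = Xc.T @ Yc
    print(np.linalg.norm(XTY))  # 0.0 -> w = XTY / ||XTY|| is undefined, hence the warning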