diff --git a/algorithms/jax_ikpls_alg_1.py b/algorithms/jax_ikpls_alg_1.py index 4d8835e..59d9bf8 100644 --- a/algorithms/jax_ikpls_alg_1.py +++ b/algorithms/jax_ikpls_alg_1.py @@ -65,31 +65,6 @@ def _main_loop_body( XTY = self._step_5(XTY, p, q, tTt) return XTY, w, p, q, r, t - # @partial(jax.jit, static_argnums=(0, 4, 5)) - # def _main_loop_body( - # self, - # i: int, - # X: jnp.ndarray, - # XTY: jnp.ndarray, - # M: int, - # K: int, - # P: jnp.ndarray, - # R: jnp.ndarray, - # ) -> Tuple[ - # jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray - # ]: - # print("Tracing loop body...") - # # step 2 - # w, norm = self._step_2(XTY, M, K) - # host_callback.id_tap(self.weight_warning, [i, norm]) - # # step 3 - # r = self._step_3(i, w, P, R) - # # step 4 - # tTt, p, q, t = self._step_4(X, XTY, r) - # # step 5 - # XTY = self._step_5(XTY, p, q, tTt) - # return XTY, w, p, q, r, t - def fit(self, X: jnp.ndarray, Y: jnp.ndarray, A: int) -> None: self.B, _W, _P, _Q, _R, _T = self.stateless_fit(X, Y, A) self.W = _W.T diff --git a/algorithms/jax_ikpls_alg_2.py b/algorithms/jax_ikpls_alg_2.py index 8034bbf..d7d0816 100644 --- a/algorithms/jax_ikpls_alg_2.py +++ b/algorithms/jax_ikpls_alg_2.py @@ -41,29 +41,6 @@ def _step_4( q = (r.T @ XTY).T / tTt return tTt, p, q - # @partial(jax.jit, static_argnums=(0, 4, 5)) - # def _main_loop_body( - # self, - # i: int, - # XTX: jnp.ndarray, - # XTY: jnp.ndarray, - # M: int, - # K: int, - # P: jnp.ndarray, - # R: jnp.ndarray, - # ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: - # print("Tracing loop body...") - # # step 2 - # w, norm = self._step_2(XTY, M, K) - # host_callback.id_tap(self.weight_warning, [i, norm]) - # # step 3 - # r = self._step_3(i, w, P, R) - # # step 4 - # tTt, p, q = self._step_4(XTX, XTY, r) - # # step 5 - # XTY = self._step_5(XTY, p, q, tTt) - # return XTY, w, p, q, r - @partial(jax.jit, static_argnums=(0, 1, 5, 6)) def _main_loop_body( self, diff --git a/tests/load_data.py b/tests/load_data.py index 946191e..9f3493d 100644 --- a/tests/load_data.py +++ b/tests/load_data.py @@ -7,6 +7,9 @@ def load_csv(): + """ + Loads a csv-file with 26617 rows and 11 columns. The columns represent ground truth values for 8 different grain varieties, protein, moisture, and an assignment to a dataset split. + """ csv_url = GITHUB_DATADIR + "dataset.csv" columns = [ "Rye_Midsummer", @@ -27,6 +30,9 @@ def load_csv(): def load_spectra(): + """ + Loads 26617 near-infrared (NIR) spectra with 102 wavelength channels and transforms them from reflectance to pseudo absorbance. + """ spectra_url = GITHUB_DATADIR + "spectra.npz" resp = urlopen(spectra_url) resp_byte_array = resp.read() diff --git a/tests/test_consistency.py b/tests/test_consistency.py index 46e1630..6f162f5 100644 --- a/tests/test_consistency.py +++ b/tests/test_consistency.py @@ -1352,7 +1352,7 @@ def jax_rmse_per_component(y_true, y_pred): ] + sk_models[i].intercept_ sk_preds[i] = sk_pred assert_allclose( - sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=0 + sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=1e-14 ) # Sanity check. SkPLS also uses the maximum number of components in its predict method. # Compute RMSE on the validation predictions @@ -1609,4 +1609,6 @@ def test_cross_val_pls_2_m_greater_k(self): tc.test_cross_val_pls_1() tc.test_cross_val_pls_2_m_less_k() tc.test_cross_val_pls_2_m_eq_k() - tc.test_cross_val_pls_2_m_greater_k() \ No newline at end of file + tc.test_cross_val_pls_2_m_greater_k() + + # TODO: Doc strings for tests and algorithms. \ No newline at end of file