Skip to content

Commit

Permalink
Added small tolerance to account for numerical errors in check_cross_…
Browse files Browse the repository at this point in the history
…val_pls
  • Loading branch information
Sm00thix committed Nov 1, 2023
1 parent 35415bc commit f7e8590
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 50 deletions.
25 changes: 0 additions & 25 deletions algorithms/jax_ikpls_alg_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,31 +65,6 @@ def _main_loop_body(
XTY = self._step_5(XTY, p, q, tTt)
return XTY, w, p, q, r, t

# @partial(jax.jit, static_argnums=(0, 4, 5))
# def _main_loop_body(
# self,
# i: int,
# X: jnp.ndarray,
# XTY: jnp.ndarray,
# M: int,
# K: int,
# P: jnp.ndarray,
# R: jnp.ndarray,
# ) -> Tuple[
# jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
# ]:
# print("Tracing loop body...")
# # step 2
# w, norm = self._step_2(XTY, M, K)
# host_callback.id_tap(self.weight_warning, [i, norm])
# # step 3
# r = self._step_3(i, w, P, R)
# # step 4
# tTt, p, q, t = self._step_4(X, XTY, r)
# # step 5
# XTY = self._step_5(XTY, p, q, tTt)
# return XTY, w, p, q, r, t

def fit(self, X: jnp.ndarray, Y: jnp.ndarray, A: int) -> None:
self.B, _W, _P, _Q, _R, _T = self.stateless_fit(X, Y, A)
self.W = _W.T
Expand Down
23 changes: 0 additions & 23 deletions algorithms/jax_ikpls_alg_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,6 @@ def _step_4(
q = (r.T @ XTY).T / tTt
return tTt, p, q

# @partial(jax.jit, static_argnums=(0, 4, 5))
# def _main_loop_body(
# self,
# i: int,
# XTX: jnp.ndarray,
# XTY: jnp.ndarray,
# M: int,
# K: int,
# P: jnp.ndarray,
# R: jnp.ndarray,
# ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
# print("Tracing loop body...")
# # step 2
# w, norm = self._step_2(XTY, M, K)
# host_callback.id_tap(self.weight_warning, [i, norm])
# # step 3
# r = self._step_3(i, w, P, R)
# # step 4
# tTt, p, q = self._step_4(XTX, XTY, r)
# # step 5
# XTY = self._step_5(XTY, p, q, tTt)
# return XTY, w, p, q, r

@partial(jax.jit, static_argnums=(0, 1, 5, 6))
def _main_loop_body(
self,
Expand Down
6 changes: 6 additions & 0 deletions tests/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


def load_csv():
"""
Loads a csv-file with 26617 rows and 11 columns. The columns represent ground truth values for 8 different grain varieties, protein, moisture, and an assignment to a dataset split.
"""
csv_url = GITHUB_DATADIR + "dataset.csv"
columns = [
"Rye_Midsummer",
Expand All @@ -27,6 +30,9 @@ def load_csv():


def load_spectra():
"""
Loads 26617 near-infrared (NIR) spectra with 102 wavelength channels and transforms them from reflectance to pseudo absorbance.
"""
spectra_url = GITHUB_DATADIR + "spectra.npz"
resp = urlopen(spectra_url)
resp_byte_array = resp.read()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def jax_rmse_per_component(y_true, y_pred):
] + sk_models[i].intercept_
sk_preds[i] = sk_pred
assert_allclose(
sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=0
sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=1e-14
) # Sanity check. SkPLS also uses the maximum number of components in its predict method.

# Compute RMSE on the validation predictions
Expand Down Expand Up @@ -1609,4 +1609,6 @@ def test_cross_val_pls_2_m_greater_k(self):
tc.test_cross_val_pls_1()
tc.test_cross_val_pls_2_m_less_k()
tc.test_cross_val_pls_2_m_eq_k()
tc.test_cross_val_pls_2_m_greater_k()
tc.test_cross_val_pls_2_m_greater_k()

# TODO: Doc strings for tests and algorithms.

0 comments on commit f7e8590

Please sign in to comment.