Commit 0ddea86
Compare std deviation from 0 with float epsilon instead of exact equality between std and 0 (#35)
Sm00thix authored Jul 7, 2024
1 parent b27201b commit 0ddea86
Showing 8 changed files with 22 additions and 16 deletions.
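The recurring change across the NumPy and JAX backends is that standard deviations (and weight norms) are no longer compared to zero with exact equality, but against the machine epsilon of the configured dtype. A minimal sketch of the pattern with made-up numbers (illustrative, not library code):

import numpy as np

eps = np.finfo(np.float64).eps          # ~2.22e-16

std = np.array([1e-17, 1.0])            # a numerically "zero" std and a normal one

# Old guard: exact equality misses the tiny-but-nonzero value.
old = std.copy()
old[old == 0] = 1                        # std[0] stays 1e-17

# New guard: anything at or below machine epsilon is treated as zero variance.
new = std.copy()
new[np.abs(new) <= eps] = 1              # std[0] becomes 1

print(1.0 / old)                         # huge scaling factor: [1e+17, 1.]
print(1.0 / new)                         # safe:                [1., 1.]

The same threshold is applied to the X and Y standard deviations before the divisions in the files below.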
2 changes: 1 addition & 1 deletion .github/workflows/pull_request_test_workflow.yml
@@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest, windows-latest, macos-12]
+ os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion .github/workflows/test_workflow.yml
@@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest, windows-latest, macos-12]
+ os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion ikpls/__init__.py
@@ -1 +1 @@
__version__ = "1.2.3"
__version__ = "1.2.4"
11 changes: 6 additions & 5 deletions ikpls/fast_cross_validation/numpy_ikpls.py
@@ -92,6 +92,7 @@ def __init__(
self.scale_Y = scale_Y
self.algorithm = algorithm
self.dtype = dtype
+ self.eps = np.finfo(dtype).eps
self.name = f"Improved Kernel PLS Algorithm #{algorithm}"
if self.algorithm not in [1, 2]:
raise ValueError(
@@ -269,7 +270,7 @@ def _stateless_fit(
+ train_sum_sq_X
)
)
- training_X_std[training_X_std == 0] = 1
+ training_X_std[np.abs(training_X_std) <= self.eps] = 1

# Compute the training set standard deviations for Y
if self.scale_Y:
@@ -289,7 +290,7 @@ def _stateless_fit(
+ train_sum_sq_Y
)
)
- training_Y_std[training_Y_std == 0] = 1
+ training_Y_std[np.abs(training_Y_std) <= self.eps] = 1

# Subtract the validation set's contribution from the total XTY
training_XTY = self.XTY - validation_X.T @ validation_Y
@@ -341,7 +342,7 @@ def _stateless_fit(
# Step 2
if self.M == 1:
norm = la.norm(training_XTY, ord=2)
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = training_XTY / norm
@@ -352,15 +353,15 @@
q = eig_vecs[:, -1:]
w = training_XTY @ q
norm = la.norm(w)
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = w / norm
else:
training_XTYYTX = training_XTY @ training_XTY.T
eig_vals, eig_vecs = la.eigh(training_XTYYTX)
norm = eig_vals[-1]
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = eig_vecs[:, -1:]
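The norm checks in this file (and in ikpls/numpy_ikpls.py below) previously hard-coded np.finfo(np.float64).eps as the absolute tolerance; they now pass self.eps, the epsilon of the configured dtype, so the tolerance matches the precision the model actually computes in. A small illustration of the difference, using a hypothetical norm value rather than library code:

import numpy as np

print(np.finfo(np.float64).eps)  # ~2.22e-16
print(np.finfo(np.float32).eps)  # ~1.19e-07

norm = np.float32(5e-8)  # a norm that is effectively zero at float32 precision

# Old check: the hard-coded float64 epsilon does not flag it.
print(np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0))  # False

# New check: the dtype-aware epsilon does.
print(np.isclose(norm, 0, atol=np.finfo(np.float32).eps, rtol=0))  # True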
3 changes: 2 additions & 1 deletion ikpls/jax_ikpls_base.py
@@ -101,6 +101,7 @@ def __init__(
self.scale_Y = scale_Y
self.copy = copy
self.dtype = dtype
+ self.eps = jnp.finfo(self.dtype).eps
self.reverse_differentiable = reverse_differentiable
self.verbose = verbose
self.name = "Improved Kernel PLS Algorithm"
@@ -236,7 +237,7 @@ def get_std(self, A: ArrayLike):
print(f"get_stds for {self.name} will be JIT compiled...")

A_std = jnp.std(A, axis=0, dtype=self.dtype, keepdims=True, ddof=1)
- A_std = jnp.where(A_std == 0, 1, A_std)
+ A_std = jnp.where(jnp.abs(A_std) <= self.eps, 1, A_std)
return A_std

@partial(jax.jit, static_argnums=(0, 3, 4, 5, 6, 7))
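On the JAX side the guard is expressed functionally with jnp.where, which keeps get_std JIT-compilable. A self-contained sketch of the same behavior (standalone code with a made-up array, not the class method itself):

import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [1.0, 3.0], [1.0, 4.0]])  # first column is constant
eps = jnp.finfo(A.dtype).eps

A_std = jnp.std(A, axis=0, ddof=1, keepdims=True)
A_std = jnp.where(jnp.abs(A_std) <= eps, 1, A_std)  # epsilon guard from the diff above

A_scaled = A / A_std  # the constant column passes through unchanged
print(A_scaled)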
11 changes: 6 additions & 5 deletions ikpls/numpy_ikpls.py
@@ -88,6 +88,7 @@ def __init__(
self.scale_Y = scale_Y
self.copy = copy
self.dtype = dtype
+ self.eps = np.finfo(dtype).eps
self.name = f"Improved Kernel PLS Algorithm #{algorithm}"
if self.algorithm not in [1, 2]:
raise ValueError(
@@ -207,12 +208,12 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None:

if self.scale_X:
self.X_std = X.std(axis=0, ddof=1, dtype=self.dtype, keepdims=True)
- self.X_std[self.X_std == 0] = 1
+ self.X_std[np.abs(self.X_std) <= self.eps] = 1
X /= self.X_std

if self.scale_Y:
self.Y_std = Y.std(axis=0, ddof=1, dtype=self.dtype, keepdims=True)
- self.Y_std[self.Y_std == 0] = 1
+ self.Y_std[np.abs(self.Y_std) <= self.eps] = 1
Y /= self.Y_std

N, K = X.shape
@@ -246,7 +247,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None:
# Step 2
if M == 1:
norm = la.norm(XTY, ord=2)
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = XTY / norm
@@ -257,15 +258,15 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None:
q = eig_vecs[:, -1:]
w = XTY @ q
norm = la.norm(w)
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = w / norm
else:
XTYYTX = XTY @ XTY.T
eig_vals, eig_vecs = la.eigh(XTYYTX)
norm = eig_vals[-1]
- if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0):
+ if np.isclose(norm, 0, atol=self.eps, rtol=0):
self._weight_warning(i)
break
w = eig_vecs[:, -1:]
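For context, a hypothetical end-to-end use of the NumPy estimator after this change: fitting data that contains a numerically constant column no longer divides by a near-zero std when scaling is enabled. The import path and constructor keywords below are assumptions based on the file path and attribute names visible in the diff, not verified API; only fit(X, Y, A) is shown there directly.

import numpy as np
from ikpls.numpy_ikpls import PLS  # assumed class name and import path

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
X[:, 0] = 3.14                      # numerically constant feature
Y = rng.normal(size=(20, 2))

pls = PLS(algorithm=1, scale_X=True, scale_Y=True)  # keyword names as they appear in the diff
pls.fit(X, Y, A=3)                  # the constant column's std is replaced by 1 before X /= self.X_std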
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ikpls"
version = "1.2.3"
version = "1.2.4"
description = "Improved Kernel PLS and Fast Cross-Validation."
authors = ["Sm00thix <oleemail@icloud.com>"]
maintainers = ["Sm00thix <oleemail@icloud.com>"]
5 changes: 4 additions & 1 deletion tests/test_ikpls.py
@@ -677,7 +677,10 @@ def check_cpu_gpu_equality(
if n_good_components == -1:
n_good_components = np_pls_alg_1.A

- atol = 0
+ try:
+     atol = np.finfo(np_pls_alg_1.dtype).eps
+ except AttributeError:
+     atol = 0
rtol = 1e-4
# Regression matrices
assert_allclose(
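The comparison tolerance in the test helper is now tied to the estimator's dtype, falling back to 0 when the object exposes no dtype attribute. The same pattern as a tiny hypothetical helper (dtype_atol is not part of the test suite, just an isolated sketch):

import numpy as np

def dtype_atol(model) -> float:
    """Absolute tolerance matching the model's floating-point precision, 0 if unknown."""
    try:
        return np.finfo(model.dtype).eps
    except AttributeError:
        return 0.0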
