# Improved Kernel PLS

Fast CPU, GPU, and TPU Python implementations of Improved Kernel PLS Algorithm #1 and Algorithm #2 by Dayal and MacGregor[^1]. Improved Kernel PLS has been shown to be both fast[^2] and numerically stable[^3].

The CPU implementations are made using NumPy[^4] and subclass BaseEstimator from scikit-learn[^5], allowing integration into scikit-learn's ecosystem of machine learning algorithms and pipelines. For example, the CPU implementations can be used with scikit-learn's [`cross_validate`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html).

The GPU and TPU implementations are made using Google's JAX[^6]. In addition to enabling CPU, GPU, and TPU execution, JAX supports automatic differentiation. This means the JAX implementations can be used together with deep learning approaches, as the PLS fit is differentiable.
[^1]: [Dayal, B. S., & MacGregor, J. F. (1997). Improved PLS algorithms. Journal of Chemometrics: A Journal of the Chemometrics Society, 11(1), 73-85.](https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23)
[^2]: [Alin, A. (2009). Comparison of PLS algorithms when number of objects is much larger than number of variables. Statistical papers, 50, 711-720.](https://link.springer.com/content/pdf/10.1007/s00362-009-0251-7.pdf)
[^3]: [Andersson, M. (2009). A comparison of nine PLS1 algorithms. Journal of Chemometrics: A Journal of the Chemometrics Society, 23(10), 518-529.](https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/pdf/10.1002/cem.1248?)
[^4]: [NumPy.](https://numpy.org/)
[^5]: [scikit-learn.](https://scikit-learn.org/stable/)
[^6]: [JAX.](https://jax.readthedocs.io/en/latest/)
## Pre-requisites

The JAX implementations support running on CPU, GPU, and TPU. To use the GPU or TPU, follow the instructions from the [JAX Installation Guide](https://jax.readthedocs.io/en/latest/installation.html).

To ensure that the JAX implementations use float64, set the environment variable `JAX_ENABLE_X64=True` as per the [Current Gotchas](https://github.com/google/jax#current-gotchas).
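As a minimal sketch, here are the two standard JAX mechanisms for enabling float64; whichever you choose must run before any JAX arrays are created:

```python
import os

# Option 1: set the environment variable before importing jax.
os.environ["JAX_ENABLE_X64"] = "True"

import jax

# Option 2: flip the config flag at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)
```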
## Installation

* Install the package for Python3 using the following command:
  `$ pip3 install ikpls`
* Now you can import the NumPy and JAX implementations with:
  ```python
  from ikpls.numpy_ikpls import PLS as NpPLS
  from ikpls.jax_ikpls_alg_1 import PLS as JAXPLS_Alg_1
  from ikpls.jax_ikpls_alg_2 import PLS as JAXPLS_Alg_2
  ```
## Quick Start

### Use the ikpls package for PLS modelling

```python
from ikpls.numpy_ikpls import PLS
import numpy as np

N = 100  # Number of samples.
K = 50  # Number of features.
M = 10  # Number of targets.
A = 20  # Number of latent variables (PLS components).

# Using float64 is important for numerical stability.
X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

# The other PLS algorithms and implementations have the same interface for fit() and predict().
np_ikpls_alg_1 = PLS(algorithm=1)
np_ikpls_alg_1.fit(X, Y, A)

# Has shape (A, N, M) = (20, 100, 10). Contains a prediction for every number of components up to and including A.
y_pred = np_ikpls_alg_1.predict(X)

# Has shape (N, M) = (100, 10).
y_pred_20_components = np_ikpls_alg_1.predict(X, n_components=20)
(y_pred_20_components == y_pred[19]).all()  # True

# The internal model parameters can be accessed as follows:
np_ikpls_alg_1.B  # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
np_ikpls_alg_1.W  # X weights matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.P  # X loadings matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.Q  # Y loadings matrix of shape (M, A) = (10, 20).
np_ikpls_alg_1.R  # X rotations matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.T  # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
```
## Examples

In [examples](examples/) you will find:

* [Example](examples/fit_predict_numpy.py) of fitting and predicting with the NumPy implementations.
* [Example](examples/fit_predict_jax.py) of fitting and predicting with the JAX implementations.
* [Example](examples/cross_val_numpy.py) of cross-validating with the NumPy implementations.
* [Example](examples/cross_val_jax.py) of cross-validating with the JAX implementations.
* [Example](examples/gradient_jax.py) of computing the gradient of the RMSE between the target values and the PLS predictions with respect to a preprocessing filter. A rough sketch of the mechanics follows below.
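The gradient example file is currently a placeholder (see `examples/gradient_jax.py` further down). As a rough, hedged sketch of the mechanics only, the snippet below differentiates an RMSE loss through a preprocessing filter. A closed-form least-squares fit stands in for the PLS fit (which is differentiable in the same way); none of this is the package's own API.

```python
# Hypothetical sketch: gradient of RMSE w.r.t. a preprocessing filter.
# NOTE: the ridge-regularized least-squares fit below is a stand-in for the
# differentiable PLS fit; it is not the ikpls API.
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # Float64 for numerical stability.

N, K, M, F = 100, 50, 10, 7  # Samples, features, targets, filter length.
X = jnp.asarray(np.random.uniform(size=(N, K)))
Y = jnp.asarray(np.random.uniform(size=(N, M)))


def rmse_after_fit(filt: jnp.ndarray) -> jnp.ndarray:
    # Convolve every sample (row) of X with the filter.
    X_filt = jax.vmap(lambda row: jnp.convolve(row, filt, mode="same"))(X)
    # Stand-in fit: ridge-regularized least squares on the filtered data.
    B = jnp.linalg.solve(X_filt.T @ X_filt + 1e-8 * jnp.eye(K), X_filt.T @ Y)
    Y_pred = X_filt @ B
    return jnp.sqrt(jnp.mean((Y - Y_pred) ** 2))


filt = jnp.ones(F) / F  # Initial moving-average filter.
grad_rmse = jax.grad(rmse_after_fit)(filt)  # Gradient w.r.t. the filter taps. Shape (F,) = (7,).
```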
`examples/cross_val_jax.py`:
```python
from ikpls.jax_ikpls_alg_1 import (
    PLS,
)  # For this example, we will use IKPLS Algorithm #1. The interface for IKPLS Algorithm #2 is identical.
import jax.numpy as jnp
import numpy as np
import jax
from typing import Tuple


# Function to apply mean centering to X and Y based on training data.
def cross_val_preprocessing(
    X_train: jnp.ndarray,
    Y_train: jnp.ndarray,
    X_val: jnp.ndarray,
    Y_val: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    print(
        "Preprocessing function will be JIT compiled..."
    )  # The internals of .cv() in JAX are JIT compiled. That includes the preprocessing function.
    x_mean = X_train.mean(axis=0, keepdims=True)
    X_train -= x_mean
    X_val -= x_mean
    y_mean = Y_train.mean(axis=0, keepdims=True)
    Y_train -= y_mean
    Y_val -= y_mean
    return X_train, Y_train, X_val, Y_val


def mse_per_component_and_best_components(
    Y_true: jnp.ndarray, Y_pred: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    print(
        "Metric function will be JIT compiled..."
    )  # The internals of .cv() in JAX are JIT compiled. That includes the metric function.
    # Y_true has shape (N, M), Y_pred has shape (A, N, M).
    e = Y_true - Y_pred  # Shape (A, N, M).
    se = e**2  # Shape (A, N, M).
    mse = jnp.mean(se, axis=-2)  # Shape (A, M).
    best_num_components = jnp.argmin(mse, axis=0) + 1  # Shape (M,).
    return (mse, best_num_components)


if __name__ == "__main__":
    # NOTE: Every time a training or validation split has a different size from the
    # previously encountered one, recompilation will occur. This is because the JIT
    # compiler must generate a new function for each unique input shape. Thus, if all
    # splits have the same shape, JIT compilation happens only once.
    N = 100  # Number of samples.
    K = 50  # Number of features.
    M = 10  # Number of targets.
    A = 20  # Number of latent variables (PLS components).
    splits = np.arange(100) % 5  # Assign each sample to one of 5 splits in round-robin fashion.

    # Using float64 is important for numerical stability.
    X = np.random.uniform(size=(N, K)).astype(np.float64)
    Y = np.random.uniform(size=(N, M)).astype(np.float64)

    jax_pls_alg_1 = PLS(verbose=True)

    metric_names = ["mse", "best_num_components"]
    metric_values_dict = jax_pls_alg_1.cv(
        X,
        Y,
        A,
        cv_splits=splits,
        preprocessing_function=cross_val_preprocessing,
        metric_function=mse_per_component_and_best_components,
        metric_names=metric_names,
    )

    # List of length 5 where each element is an array of shape (A, M) = (20, 10)
    # corresponding to the mse output of mse_per_component_and_best_components for each split.
    mse_for_each_split = metric_values_dict["mse"]
    mse_for_each_split = np.array(
        mse_for_each_split
    )  # Shape (n_splits, A, M) = (5, 20, 10).

    # List of length 5 where each element is an array of shape (M,) = (10,)
    # corresponding to the best_num_components output of mse_per_component_and_best_components for each split.
    best_num_components_for_each_split = metric_values_dict["best_num_components"]
    best_num_components_for_each_split = np.array(
        best_num_components_for_each_split
    )  # Shape (n_splits, M) = (5, 10).

    # The -1 in the index accounts for the fact that mse_for_each_split is 0-indexed
    # while the number of components runs from 1 to A.
    # This could also have been implemented with jax.numpy directly inside
    # mse_per_component_and_best_components as part of the metric function.
    best_mse_for_each_split = np.amin(mse_for_each_split, axis=-2)  # Shape (n_splits, M) = (5, 10).
    equivalent_best_mse_for_each_split = np.array(
        [
            mse_for_each_split[
                i, best_num_components_for_each_split[i] - 1, np.arange(M)
            ]
            for i in range(len(best_num_components_for_each_split))
        ]
    )  # Shape (n_splits, M) = (5, 10).
    (best_mse_for_each_split == equivalent_best_mse_for_each_split).all()  # True
```
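A natural follow-up (not part of the original example) is to aggregate the per-split MSE curves into a single choice of the number of components per target. Continuing with the variables from the script above:

```python
# Hypothetical follow-up, reusing mse_for_each_split from the script above.
mean_mse = mse_for_each_split.mean(axis=0)  # Average over splits. Shape (A, M) = (20, 10).
global_best_num_components = np.argmin(mean_mse, axis=0) + 1  # Shape (M,) = (10,).
```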
`examples/cross_val_numpy.py`:
```python
from typing import Union

from ikpls.numpy_ikpls import PLS
import numpy as np
import numpy.typing as npt
from sklearn.model_selection import cross_validate


class PLSWithPreprocessing(
    PLS
):  # We can simply inherit from the NumPy implementation and override the fit and predict methods to include preprocessing.
    def __init__(self, algorithm: int = 1, dtype: np.float_ = np.float64) -> None:
        super().__init__(algorithm, dtype)

    def fit(
        self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int
    ) -> None:  # Override the fit method to include mean centering of X and Y.
        self.X_mean = np.mean(X, axis=0)
        self.Y_mean = np.mean(Y, axis=0)
        # Subtract out-of-place so the caller's arrays are not mutated.
        X = X - self.X_mean
        Y = Y - self.Y_mean
        return super().fit(X, Y, A)

    def predict(  # Override the predict method to include mean centering of X and Y based on values encountered in fit.
        self, X: npt.ArrayLike, A: Union[None, int] = None
    ) -> npt.NDArray[np.float_]:
        return super().predict(X - self.X_mean, A) + self.Y_mean


def cv_splitter(
    splits: npt.NDArray,
):  # splits is a 1D array of integers indicating the split number for each sample.
    uniq_splits = np.unique(splits)
    for split in uniq_splits:
        train_idxs = np.nonzero(splits != split)[0]
        val_idxs = np.nonzero(splits == split)[0]
        yield train_idxs, val_idxs


def mse_for_each_target(estimator, X, Y_true, **kwargs):
    # We must return a dict of scalar values. Let's choose the number of components that
    # achieves the lowest MSE value for each target and return both the MSE and the number of components.
    # Y_true has shape (N, M).
    Y_pred = estimator.predict(X, **kwargs)  # Shape (A, N, M).
    e = Y_true - Y_pred  # Shape (A, N, M).
    se = e**2  # Shape (A, N, M).
    mse = np.mean(se, axis=-2)  # Compute the mean over samples. Shape (A, M).
    row_idxs = np.argmin(
        mse, axis=0
    )  # The number of components that minimizes the MSE for each target. Shape (M,).
    lowest_mses = mse[
        row_idxs, np.arange(mse.shape[1])
    ]  # The lowest MSE for each target. Shape (M,).
    num_components = (
        row_idxs + 1
    )  # Indices are 0-indexed but the number of components is 1-indexed.
    mse_names = [
        f"lowest_mse_target_{i}" for i in range(lowest_mses.shape[0])
    ]  # List of names for the lowest MSE values.
    num_components_names = [  # List of names for the number of components that achieves the lowest MSE for each target.
        f"num_components_lowest_mse_target_{i}" for i in range(lowest_mses.shape[0])
    ]
    all_names = mse_names + num_components_names  # List of all names.
    all_values = np.concatenate((lowest_mses, num_components))  # Array of all values.
    return dict(zip(all_names, all_values))


if __name__ == "__main__":
    N = 100  # Number of samples.
    K = 50  # Number of features.
    M = 10  # Number of targets.
    A = 20  # Number of latent variables (PLS components).
    splits = np.random.randint(
        0, 5, size=N
    )  # Randomly assign each sample to one of 5 splits.

    # Using float64 is important for numerical stability.
    X = np.random.uniform(size=(N, K)).astype(np.float64)
    Y = np.random.uniform(size=(N, M)).astype(np.float64)

    np_pls_alg_1 = PLSWithPreprocessing(
        algorithm=1
    )  # For this example, we will use IKPLS Algorithm #1. The interface for IKPLS Algorithm #2 is identical.
    fit_params = {"A": A}
    np_pls_alg_1_results = cross_validate(
        np_pls_alg_1,
        X,
        Y,
        cv=cv_splitter(splits),
        scoring=mse_for_each_target,  # We want to return the MSE for each target and the number of components that achieves the lowest MSE for each target.
        fit_params=fit_params,  # We want to pass the number of components to the fit method.
        return_estimator=False,  # We don't need the estimators themselves, just the MSEs and the best number of components.
        n_jobs=-1,  # Use all available CPU cores.
    )

    lowest_val_mses = np.array(
        [np_pls_alg_1_results[f"test_lowest_mse_target_{i}"] for i in range(M)]
    )  # Shape (M, n_splits) = (10, 5). Lowest MSE for each target for each split.

    best_num_components = np.array(
        [
            np_pls_alg_1_results[f"test_num_components_lowest_mse_target_{i}"]
            for i in range(M)
        ]
    )  # Shape (M, n_splits) = (10, 5). Number of components that achieves the lowest MSE for each target for each split.
```
`examples/fit_predict_jax.py`:
```python
from ikpls.jax_ikpls_alg_1 import PLS
import numpy as np

if __name__ == "__main__":
    N = 100  # Number of samples.
    K = 50  # Number of features.
    M = 10  # Number of targets.
    A = 20  # Number of latent variables (PLS components).

    X = np.random.uniform(size=(N, K)).astype(np.float64)
    Y = np.random.uniform(size=(N, M)).astype(np.float64)

    jax_ikpls_alg_1 = PLS()
    jax_ikpls_alg_1.fit(X, Y, A)

    # Has shape (A, N, M) = (20, 100, 10). Contains a prediction for every number of components up to and including A.
    y_pred = jax_ikpls_alg_1.predict(X)

    # Has shape (N, M) = (100, 10).
    y_pred_20_components = jax_ikpls_alg_1.predict(X, n_components=20)
    np.allclose(
        y_pred_20_components, y_pred[19], atol=0, rtol=1e-14
    )  # True. Exact equality might not hold due to numerical differences.

    # The internal model parameters can be accessed as follows:
    jax_ikpls_alg_1.B  # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
    jax_ikpls_alg_1.W  # X weights matrix of shape (K, A) = (50, 20).
    jax_ikpls_alg_1.P  # X loadings matrix of shape (K, A) = (50, 20).
    jax_ikpls_alg_1.Q  # Y loadings matrix of shape (M, A) = (10, 20).
    jax_ikpls_alg_1.R  # X rotations matrix of shape (K, A) = (50, 20).
    jax_ikpls_alg_1.T  # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
```
`examples/fit_predict_numpy.py`:
```python
from ikpls.numpy_ikpls import PLS
import numpy as np

if __name__ == "__main__":
    N = 100  # Number of samples.
    K = 50  # Number of features.
    M = 10  # Number of targets.
    A = 20  # Number of latent variables (PLS components).

    # Using float64 is important for numerical stability.
    X = np.random.uniform(size=(N, K)).astype(np.float64)
    Y = np.random.uniform(size=(N, M)).astype(np.float64)

    # The other PLS algorithms and implementations have the same interface for fit() and predict().
    np_ikpls_alg_1 = PLS(algorithm=1)
    np_ikpls_alg_1.fit(X, Y, A)

    y_pred = np_ikpls_alg_1.predict(
        X
    )  # Has shape (A, N, M) = (20, 100, 10). Contains a prediction for every number of components up to and including A.
    y_pred_20_components = np_ikpls_alg_1.predict(
        X, n_components=20
    )  # Has shape (N, M) = (100, 10).
    (y_pred_20_components == y_pred[19]).all()  # True

    # The internal model parameters can be accessed as follows:
    np_ikpls_alg_1.B  # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
    np_ikpls_alg_1.W  # X weights matrix of shape (K, A) = (50, 20).
    np_ikpls_alg_1.P  # X loadings matrix of shape (K, A) = (50, 20).
    np_ikpls_alg_1.Q  # Y loadings matrix of shape (M, A) = (10, 20).
    np_ikpls_alg_1.R  # X rotations matrix of shape (K, A) = (50, 20).
    np_ikpls_alg_1.T  # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
```
`examples/gradient_jax.py`:
```python
# TODO: TBA. For now, see check_gradient_pls() in ../tests/test_ikpls.py if you cannot wait to use this feature.
```
Removed from the package's `__init__.py`:
```python
from . import jax_ikpls_alg_1, jax_ikpls_alg_2, jax_ikpls_base, numpy_ikpls

__all__ = ["jax_ikpls_alg_1", "jax_ikpls_alg_2", "jax_ikpls_base", "numpy_ikpls"]
```