Skip to content

Commit

Permalink
Dev (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sm00thix authored Nov 14, 2023
1 parent 960e6b2 commit 9bf5e5e
Show file tree
Hide file tree
Showing 25 changed files with 846 additions and 163 deletions.
68 changes: 64 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,66 @@
# IKPLS
Fast CPU and GPU Python implementations of Improved Kernel PLS by Dayal and MacGregor (1997).
# Improved Kernel PLS
Fast CPU, GPU, and TPU Python implementations of Improved Kernel PLS Algorithm #1 and Algorithm #2 by Dayal and MacGregor[^1]. Improved Kernel PLS has been shown to be both fast[^2] and numerically stable[^3].
The CPU implementations are made using NumPy[^4] and subclass BaseEstimator from scikit-learn[^5] allowing integration into scikit-learn's ecosystem of machine learning algorithms and pipelines. For example, the CPU implementations can be used with scikit-learn's [`cross_validate`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html).
The GPU and TPU implementations are made using Google's JAX[^6]. While allowing CPU, GPU, and TPU execution, automatic differentiation is also supported by JAX. This implies that the JAX implementations can be used together with deep learning approaches as the PLS fit is differentiable.

[^1]: [Dayal, B. S., & MacGregor, J. F. (1997). Improved PLS algorithms. Journal of Chemometrics: A Journal of the Chemometrics Society, 11(1), 73-85.](https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23)
[^2]: [Alin, A. (2009). Comparison of PLS algorithms when number of objects is much larger than number of variables. Statistical papers, 50, 711-720.](https://link.springer.com/content/pdf/10.1007/s00362-009-0251-7.pdf)
[^3]: [Andersson, M. (2009). A comparison of nine PLS1 algorithms. Journal of Chemometrics: A Journal of the Chemometrics Society, 23(10), 518-529.](https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/pdf/10.1002/cem.1248?)
[^4]: [NumPy.](https://numpy.org/)
[^5]: [scikit-learn.](https://scikit-learn.org/stable/)
[^6]: [JAX.](https://jax.readthedocs.io/en/latest/)


## Pre-requisites
The JAX implementations support running on both CPU and GPU. To use the GPU, follow the instructions from the [JAX Installation Guide](https://jax.readthedocs.io/en/latest/installation.html).
To ensure that JAX implementations use Float64, set the environment variable JAX_ENABLE_X64=True as per the [Current Gotchas](https://github.com/google/jax#current-gotchas).
The JAX implementations support running on both CPU, GPU, and TPU. To use the GPU or TPU, follow the instructions from the [JAX Installation Guide](https://jax.readthedocs.io/en/latest/installation.html).
To ensure that JAX implementations use Float64, set the environment variable JAX_ENABLE_X64=True as per the [Current Gotchas](https://github.com/google/jax#current-gotchas).

## Installation
* Install the package for Python3 using the following command:
`$ pip3 install ikpls`
* Now you can import the NumPy and JAX implementations with:
```python
from ikpls.numpy_ikpls import PLS as NpPLS
from ikpls.jax_ikpls_alg_1 import PLS as JAXPLS_Alg_1
from ikpls.jax_ikpls_alg_2 import PLS as JAXPLS_Alg_2
```

## Quick Start
### Use the ikpls package for PLS modelling
```python
from ikpls.numpy_ikpls import PLS
import numpy as np

N = 100 # Number of samples.
K = 50 # Number of features.
M = 10 # Number of targets.
A = 20 # Number of latent variables (PLS components).

# Using float64 is important for numerical stability.
X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

# The other PLS algorithms and implementations have the same interface for fit() and predict().
np_ikpls_alg_1 = PLS(algorithm=1)
np_ikpls_alg_1.fit(X, Y, A)

y_pred = np_ikpls_alg_1.predict(X) # Has shape (A, N, M) = (20, 100, 10). Contains a prediction for all possible number of components up to and including A.
y_pred_20_components = np_ikpls_alg_1.predict(X, n_components=20) # Has shape (N, M) = (100, 10).
(y_pred_20_components == y_pred[19]).all() # True

# The internal model parameters can be accessed as follows:
np_ikpls_alg_1.B # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
np_ikpls_alg_1.W # X weights matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.P # X loadings matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.Q # Y loadings matrix of shape (M, A) = (10, 20).
np_ikpls_alg_1.R # X rotations matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.T # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
```

## Examples
In [examples](examples/) you will find:
* [Example](examples/fit_predict_numpy.py) of fitting and predicting with the NumPy implementations.
* [Example](examples/fit_predict_jax.py) of fitting and predicting with the JAX implementations.
* [Example](examples/cross_val_numpy.py) of cross validating with the NumPy implementations.
* [Example](examples/cross_val_jax.py) of cross validating with the JAX implementations.
* [Example](examples/gradient_jax.py) of computing the gradient of a preprocessing filter with respect to the RMSE between the target value and the value predicted by PLS after fitting.
Empty file added examples/__init__.py
Empty file.
104 changes: 104 additions & 0 deletions examples/cross_val_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from ikpls.jax_ikpls_alg_1 import (
PLS,
) # For this example, we will use IKPLS Algorithm #1. The interface for IKPLS Algorithm #2 is identical.
import jax.numpy as jnp
import numpy as np
import jax
from typing import Tuple


# Function to apply mean centering to X and Y based on training data.
def cross_val_preprocessing(
X_train: jnp.ndarray,
Y_train: jnp.ndarray,
X_val: jnp.ndarray,
Y_val: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
print(
"Preprocessing function will be JIT compiled..."
) # The internals of .cv() in JAX are JIT compiled. That includes the preprocessing function.
x_mean = X_train.mean(axis=0, keepdims=True)
X_train -= x_mean
X_val -= x_mean
y_mean = Y_train.mean(axis=0, keepdims=True)
Y_train -= y_mean
Y_val -= y_mean
return X_train, Y_train, X_val, Y_val


def mse_per_component_and_best_components(
Y_true: jnp.ndarray, Y_pred: jnp.ndarray
) -> jnp.ndarray:
print(
"Metric function will be JIT compiled..."
) # The internals of .cv() in JAX are JIT compiled. That includes the metric function.
# Y_true has shape (N, M), Y_pred has shape (A, N, M).
e = Y_true - Y_pred # Shape (A, N, M)
se = e**2 # Shape (A, N, M)
mse = jnp.mean(se, axis=-2) # Shape (A, M)
best_num_components = jnp.argmin(mse, axis=0) + 1 # Shape (M,)
return (mse, best_num_components)


if __name__ == "__main__":
"""
NOTE: Every time a training or validation split has a different size from the previously encountered one, recompilation will occur.
This is because the JIT compiler must generate a new function for each unique input shape.
Thus, if all splits have the same shape, JIT compilation happens only once.
"""
N = 100 # Number of samples.
K = 50 # Number of features.
M = 10 # Number of targets.
A = 20 # Number of latent variables (PLS components).
splits = np.arange(100) % 5 # Randomly assign each sample to one of 5 splits.

# Using float64 is important for numerical stability.
X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

jax_pls_alg_1 = PLS(verbose=True)

metric_names = ["mse", "best_num_components"]
metric_values_dict = jax_pls_alg_1.cv(
X,
Y,
A,
cv_splits=splits,
preprocessing_function=cross_val_preprocessing,
metric_function=mse_per_component_and_best_components,
metric_names=metric_names,
)

"""
list of length 5 where each element is an array of shape (A, M) = (20, 10)
corresponding to the mse output of mse_per_component_and_best_components for each split.
"""
mse_for_each_split = metric_values_dict["mse"]
mse_for_each_split = np.array(
mse_for_each_split
) # shape (n_splits, A, M) = (5, 20, 10)

"""
list of length 5 where each element is an array of shape (M,) = (10,)
corresponding to the best_num_components output of mse_per_component_and_best_components for each split.
"""
best_num_components_for_each_split = metric_values_dict["best_num_components"]
best_num_components_for_each_split = np.array(
best_num_components_for_each_split
) # shape (n_splits, M) = (5, 10)

"""
# The -1 in the index is due to the fact that mse_for_each_split is 0-indexed but the number of components go from 1 to A.
This could also have been implemented using jax.numpy in mse_per_component_and_best_components directly as part of the metric function.
"""
best_mse_for_each_split = np.amin(mse_for_each_split, axis=-2) # (n_splits, M) shape (5, 10)
equivalent_best_mse_for_each_split = np.array(
[
mse_for_each_split[
i, best_num_components_for_each_split[i] - 1, np.arange(M)
]
for i in range(len(best_num_components_for_each_split))
]
) # shape (n_splits, M) = (5, 10)
(best_mse_for_each_split == equivalent_best_mse_for_each_split).all() # True
pass
101 changes: 101 additions & 0 deletions examples/cross_val_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Union
from ikpls.numpy_ikpls import PLS
import numpy as np
import numpy.typing as npt
from sklearn.model_selection import cross_validate


class PLSWithPreprocessing(
PLS
): # We can simply inherit from the numpy implementation to override the fit and predict methods to include preprocessing.
def __init__(self, algorithm: int = 1, dtype: np.float_ = np.float64) -> None:
super().__init__(algorithm, dtype)

def fit(
self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int
) -> None: # Override the fit method to include mean centering of X and Y.
self.X_mean = np.mean(X, axis=0)
self.Y_mean = np.mean(Y, axis=0)
X -= self.X_mean
Y -= self.Y_mean
return super().fit(X, Y, A)

def predict( # Override the predict method to include mean centering of X and Y based on values encountered in fit.
self, X: npt.ArrayLike, A: Union[None, int] = None
) -> npt.NDArray[np.float_]:
return super().predict(X - self.X_mean, A) + self.Y_mean


def cv_splitter(
splits: npt.NDArray,
): # Splits is a 1D array of integers indicating the split number for each sample.
uniq_splits = np.unique(splits)
for split in uniq_splits:
train_idxs = np.nonzero(splits != split)[0]
val_idxs = np.nonzero(splits == split)[0]
yield train_idxs, val_idxs


def mse_for_each_target(estimator, X, Y_true, **kwargs):
# We must return a dict of singular values. Let's choose the number of components that achieves the lowest MSE value for each target and return both MSE and the number of components.
# Y_true has shape (N, M)
Y_pred = estimator.predict(X, **kwargs) # Shape (A, N, M)
e = Y_true - Y_pred # Shape (A, N, M)
se = e**2 # Shape (A, N, M)
mse = np.mean(se, axis=-2) # Compute the mean over samples. Shape (A, M).
row_idxs = np.argmin(
mse, axis=0
) # The number of components that minimizes the MSE for each target. Shape (M,).
lowest_mses = mse[
row_idxs, np.arange(mse.shape[1])
] # The lowest MSE for each target. Shape (M,).
num_components = (
row_idxs + 1
) # Indices are 0-indexed but number of components is 1-indexed.
mse_names = [
f"lowest_mse_target_{i}" for i in range(lowest_mses.shape[0])
] # List of names for the lowest MSE values.
num_components_names = [ # List of names for the number of components that achieves the lowest MSE for each target.
f"num_components_lowest_mse_target_{i}" for i in range(lowest_mses.shape[0])
]
all_names = mse_names + num_components_names # List of all names.
all_values = np.concatenate((lowest_mses, num_components)) # Array of all values.
return dict(zip(all_names, all_values))


if __name__ == "__main__":
N = 100 # Number of samples.
K = 50 # Number of features.
M = 10 # Number of targets.
A = 20 # Number of latent variables (PLS components).
splits = np.random.randint(
0, 5, size=N
) # Randomly assign each sample to one of 5 splits.

# Using float64 is important for numerical stability.
X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

np_pls_alg_1 = PLSWithPreprocessing(algorithm=1) # For this example, we will use IKPLS Algorithm #1. The interface for IKPLS Algorithm #2 is identical.
fit_params = {"A": A}
np_pls_alg_1_results = cross_validate(
np_pls_alg_1,
X,
Y,
cv=cv_splitter(splits),
scoring=mse_for_each_target, # We want to return the MSE for each target and the number of components that achieves the lowest MSE for each target.
fit_params=fit_params, # We want to pass the number of components to the fit method.
return_estimator=False, # We don't need the estimators themselves, just the MSEs and the best number of components.
n_jobs=-1, # Use all available CPU cores.
)

lowest_val_mses = np.array(
[np_pls_alg_1_results[f"test_lowest_mse_target_{i}"] for i in range(M)]
) # Shape (M, splits) = (10, 5). Lowest MSE for each target for each split.

best_num_components = np.array(
[
np_pls_alg_1_results[f"test_num_components_lowest_mse_target_{i}"]
for i in range(M)
]
) # Shape (M, splits) = (10, 5). Number of components that achieves the lowest MSE for each target for each split.
26 changes: 26 additions & 0 deletions examples/fit_predict_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ikpls.jax_ikpls_alg_1 import PLS
import numpy as np

if __name__ == '__main__':
N = 100 # Number of samples.
K = 50 # Number of features.
M = 10 # Number of targets.
A = 20 # Number of latent variables (PLS components).

X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

jax_ikpls_alg_1 = PLS()
jax_ikpls_alg_1.fit(X, Y, A)

y_pred = jax_ikpls_alg_1.predict(X) # Has shape (A, N, M) = (20, 100, 10). Contains a prediction for all possible number of components up to and including A.
y_pred_20_components = jax_ikpls_alg_1.predict(X, n_components=20) # Has shape (N, M) = (100, 10).
np.allclose(y_pred_20_components, y_pred[19], atol=0, rtol=1e14) # True. Exact equality might not hold due to numerical differences.

# The internal model parameters can be accessed as follows:
jax_ikpls_alg_1.B # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
jax_ikpls_alg_1.W # X weights matrix of shape (K, A) = (50, 20).
jax_ikpls_alg_1.P # X loadings matrix of shape (K, A) = (50, 20).
jax_ikpls_alg_1.Q # Y loadings matrix of shape (M, A) = (10, 20).
jax_ikpls_alg_1.R # X rotations matrix of shape (K, A) = (50, 20).
jax_ikpls_alg_1.T # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
32 changes: 32 additions & 0 deletions examples/fit_predict_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from ikpls.numpy_ikpls import PLS
import numpy as np

if __name__ == "__main__":
N = 100 # Number of samples.
K = 50 # Number of features.
M = 10 # Number of targets.
A = 20 # Number of latent variables (PLS components).

# Using float64 is important for numerical stability.
X = np.random.uniform(size=(N, K)).astype(np.float64)
Y = np.random.uniform(size=(N, M)).astype(np.float64)

# The other PLS algorithms and implementations have the same interface for fit() and predict().
np_ikpls_alg_1 = PLS(algorithm=1)
np_ikpls_alg_1.fit(X, Y, A)

y_pred = np_ikpls_alg_1.predict(
X
) # Has shape (A, N, M) = (20, 100, 10). Contains a prediction for all possible number of components up to and including A.
y_pred_20_components = np_ikpls_alg_1.predict(
X, n_components=20
) # Has shape (N, M) = (100, 10).
(y_pred_20_components == y_pred[19]).all() # True

# The internal model parameters can be accessed as follows:
np_ikpls_alg_1.B # Regression coefficients tensor of shape (A, K, M) = (20, 50, 10).
np_ikpls_alg_1.W # X weights matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.P # X loadings matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.Q # Y loadings matrix of shape (M, A) = (10, 20).
np_ikpls_alg_1.R # X rotations matrix of shape (K, A) = (50, 20).
np_ikpls_alg_1.T # X scores matrix of shape (N, A) = (100, 20). This is only computed for IKPLS Algorithm #1.
1 change: 1 addition & 0 deletions examples/gradient_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO: TBA For now, look for check_gradient_pls() in ../tests/test_ikpls.py if you can not wait to use this feature.
3 changes: 0 additions & 3 deletions ikpls/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from . import jax_ikpls_alg_1, jax_ikpls_alg_2, jax_ikpls_base, numpy_ikpls

__all__ = ["jax_ikpls_alg_1", "jax_ikpls_alg_2", "jax_ikpls_base", "numpy_ikpls"]
Loading

0 comments on commit 9bf5e5e

Please sign in to comment.