Skip to content

Commit

Permalink
fix: correct dimension handling and stabilize sparse GP posterior
Browse files Browse the repository at this point in the history
- Fixed incorrect dimension usage for `y_cov_factor` by ensuring it matches the number of observations (`n_obs`) instead of landmarks (`n_landmarks`).
- Added a new helper function `add_projected_variance` to properly project the observation noise covariance and stabilize the covariance matrix.
- Replaced direct variance addition in `_LandmarksConditional` with the new `add_projected_variance` function to maintain numerical stability during Cholesky decomposition.
- Ensured the covariance matrix remains positive definite by adjusting diagonal elements based on a `jitter` threshold.

This commit resolves the dimension mismatch issue and preserves the stabilizing effect, ensuring correct posterior GP computations with uncertainty.
  • Loading branch information
katosh committed Oct 16, 2024
1 parent d6b62ea commit 344c2d8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- bugfix DimensionalityEstimator dimensionality initialization
- implement 'fixed' gaussian proces type to allow more inducing points than datapoints
- implement `copy()` method for `Predictor` class
- fix: uncertainty computation of sparse GP

# v1.4.3

Expand Down
6 changes: 3 additions & 3 deletions mellon/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax.numpy import diag as diagonal
from jax.numpy.linalg import cholesky
from jax.scipy.linalg import solve_triangular
from .util import ensure_2d, stabilize, DEFAULT_JITTER, add_variance
from .util import ensure_2d, stabilize, DEFAULT_JITTER, add_variance, add_projected_variance
from .base_predictor import Predictor, ExpPredictor, PredictorTime
from .decomposition import DEFAULT_SIGMA

Expand Down Expand Up @@ -278,9 +278,9 @@ def __init__(
LLB = stabilize(LLB, jitter)
else:
logger.debug("Assuming y is not the mean of the GP.")
y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, xu.shape[0])
y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, x.shape[0])
sigma = None
LLB = add_variance(LLB, y_cov_factor, jitter=jitter)
LLB = add_projected_variance(LLB, A, y_cov_factor, jitter=jitter)

L_B = cholesky(LLB)
r = y - mu
Expand Down
27 changes: 27 additions & 0 deletions mellon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,33 @@ def add_variance(K, M=None, jitter=DEFAULT_JITTER):
return K


def add_projected_variance(K, A, y_cov_factor, jitter=DEFAULT_JITTER):
"""
Adds the projected observation noise covariance to K and stabilizes it.
Parameters
----------
K : array_like, shape (n_landmarks, n_landmarks)
The initial covariance matrix.
A : array_like, shape (n_landmarks, n_obs)
The projection matrix from observations to inducing points.
y_cov_factor : array_like, shape (n_obs, n_obs)
The observation noise covariance matrix.
jitter : float, optional
A small number to stabilize the covariance matrix. Defaults to 1e-6.
Returns
-------
stabilized_K : array_like, shape (n_landmarks, n_landmarks)
The stabilized covariance matrix with added projected variance.
"""
noise = A @ y_cov_factor @ A.T
noise_diag = np.diag(noise)
diff = where(noise_diag < jitter, jitter - noise_diag, 0)
K = K + noise + np.diag(diff)
return K


def mle(nn_distances, d):
R"""
Nearest Neighbor distribution maximum likelihood estimate for log density
Expand Down

0 comments on commit 344c2d8

Please sign in to comment.