
[skip ci] wip - fixing math rendering in documentation - almost there
frazane authored and thomaspinder committed Jul 9, 2024
1 parent c16e7ff commit 2d3951a
Showing 42 changed files with 326 additions and 297 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -152,3 +152,4 @@ package-lock.json
node_modules/

docs/api
docs/examples/*.md
26 changes: 13 additions & 13 deletions docs/design.md
@@ -9,24 +9,24 @@ in `GPJax` with its corresponding mathematical quantity.

| On paper | GPJax code | Description |
| ------------------------------------------- | ---------- | ------------------------------------------------------------------------------- |
| $`n`$ | n | Number of train inputs |
| $`\boldsymbol{x} = (x_1,\dotsc,x_{n})`$ | x | Train inputs |
| $`\boldsymbol{y} = (y_1,\dotsc,y_{n})`$ | y | Train labels |
| $`\boldsymbol{t}`$ | t | Test inputs |
| $`f(\cdot)`$ | f | Latent function modelled as a GP |
| $`f({\boldsymbol{x}})`$ | fx | Latent function at inputs $`\boldsymbol{x}`$ |
| $`\boldsymbol{\mu}_{\boldsymbol{x}}`$ | mux | Prior mean at inputs $`\boldsymbol{x}`$ |
| $`\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}`$ | Kxx | Kernel Gram matrix at inputs $`\boldsymbol{x}`$ |
| $`\mathbf{L}_{\boldsymbol{x}}`$ | Lx | Lower Cholesky decomposition of $`\boldsymbol{K}_{\boldsymbol{x}\boldsymbol{x}}`$ |
| $`\mathbf{K}_{\boldsymbol{t}\boldsymbol{x}}`$ | Ktx | Cross-covariance between inputs $`\boldsymbol{t}`$ and $`\boldsymbol{x}`$ |
| $n$ | n | Number of train inputs |
| $\boldsymbol{x} = (x_1,\dotsc,x_{n})$ | x | Train inputs |
| $\boldsymbol{y} = (y_1,\dotsc,y_{n})$ | y | Train labels |
| $\boldsymbol{t}$ | t | Test inputs |
| $f(\cdot)$ | f | Latent function modelled as a GP |
| $f({\boldsymbol{x}})$ | fx | Latent function at inputs $\boldsymbol{x}$ |
| $\boldsymbol{\mu}_{\boldsymbol{x}}$ | mux | Prior mean at inputs $\boldsymbol{x}$ |
| $\mathbf{K}_{\boldsymbol{x}\boldsymbol{x}}$ | Kxx | Kernel Gram matrix at inputs $\boldsymbol{x}$ |
| $\mathbf{L}_{\boldsymbol{x}}$ | Lx | Lower Cholesky decomposition of $\boldsymbol{K}_{\boldsymbol{x}\boldsymbol{x}}$ |
| $\mathbf{K}_{\boldsymbol{t}\boldsymbol{x}}$ | Ktx | Cross-covariance between inputs $\boldsymbol{t}$ and $\boldsymbol{x}$ |
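To make the mapping concrete, here is a minimal sketch in plain `jax.numpy`; the `rbf` helper, the hyperparameter values, and the jitter constant are illustrative assumptions rather than GPJax's own kernel API.

```python
import jax.numpy as jnp

def rbf(a, b, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel evaluated pairwise between two input sets.
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)

n = 50                                              # number of train inputs
x = jnp.linspace(-3.0, 3.0, n).reshape(-1, 1)       # train inputs
t = jnp.linspace(-3.5, 3.5, 20).reshape(-1, 1)      # test inputs

mux = jnp.zeros(n)                                  # prior mean at x (zero mean function)
Kxx = rbf(x, x)                                     # kernel Gram matrix at x
Lx = jnp.linalg.cholesky(Kxx + 1e-6 * jnp.eye(n))   # lower Cholesky factor of Kxx (jittered)
Ktx = rbf(t, x)                                     # cross-covariance between t and x
```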

## Sparse Gaussian process notation

| On paper | GPJax code | Description |
| ------------------------------------- | ---------- | ------------------------- |
| $`m`$ | m | Number of inducing inputs |
| $`\boldsymbol{z} = (z_1,\dotsc,z_{m})`$ | z | Inducing inputs |
| $`\boldsymbol{u} = (u_1,\dotsc,u_{m})`$ | u | Inducing outputs |
| $m$ | m | Number of inducing inputs |
| $\boldsymbol{z} = (z_1,\dotsc,z_{m})$ | z | Inducing inputs |
| $\boldsymbol{u} = (u_1,\dotsc,u_{m})$ | u | Inducing outputs |

## Package style

8 changes: 4 additions & 4 deletions docs/examples/README.md
@@ -56,9 +56,9 @@ class Prior(AbstractPrior):
[mean](https://docs.jaxgaussianprocesses.com/api/mean_functions/)
and [kernel](https://docs.jaxgaussianprocesses.com/api/kernels/base/) function.
A Gaussian process prior parameterised by a mean function $`m(\cdot)`$ and a kernel
function $`k(\cdot, \cdot)`$ is given by
$`p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))`$.
A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel
function $k(\cdot, \cdot)$ is given by
$p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))$.
To invoke a `Prior` distribution, a kernel and mean function must be specified.
@@ -91,4 +91,4 @@ We adopt the following convention when documenting objects:

!!! attention "Note"

Inline math in docstrings needs to be rendered within both `$` and `` symbols to be correctly rendered by MkDocs. For instance, where one would typically write `$k(x,y)$` in standard LaTeX, in docstrings you are required to write ``$`k(x,y)`$`` in order for the math to be correctly rendered by MkDocs.
Inline math in docstrings needs to be wrapped in `$` symbols to be rendered correctly by MkDocs. For instance, where one would previously have written ``$`k(x,y)`$``, it is now sufficient to write `$k(x,y)$`, exactly as in standard LaTeX.
20 changes: 16 additions & 4 deletions docs/examples/barycentres.py
@@ -48,7 +48,11 @@
# or vice-versa. Typically, computing this metric requires solving a linear program.
# However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian
# distributions, the solution is analytically given by
# $$W_2^2(\mu, \nu) = \lVert m_1- m_2 \rVert^2_2 + \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),$$
#
# $$
# W_2^2(\mu, \nu) = \lVert m_1- m_2 \rVert^2_2 + \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),
# $$
#
# where $\mu \sim \mathcal{N}(m_1, S_1)$ and $\nu\sim\mathcal{N}(m_2, S_2)$.
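As a reference for this closed form (not part of the notebook's own code), a minimal JAX sketch, where `sqrtm_psd` is an assumed helper computing a symmetric matrix square root via an eigendecomposition:

```python
import jax.numpy as jnp

def sqrtm_psd(S):
    # Matrix square root of a symmetric positive semi-definite matrix.
    vals, vecs = jnp.linalg.eigh(S)
    return (vecs * jnp.sqrt(jnp.clip(vals, 0.0))) @ vecs.T

def w2_squared_gaussians(m1, S1, m2, S2):
    # Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2).
    S1_half = sqrtm_psd(S1)
    cross = sqrtm_psd(S1_half @ S2 @ S1_half)
    return jnp.sum((m1 - m2) ** 2) + jnp.trace(S1 + S2 - 2.0 * cross)
```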
#
# ### Wasserstein barycentre
@@ -58,14 +62,22 @@
# $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all
# other measures in the set. More formally, the Wasserstein barycentre is the Fréchet
# mean on a Wasserstein space that we can write as
# $$\bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}\sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),$$
#
# $$
# \bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}\sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),
# $$
#
# where $\alpha\in\mathbb{R}^T$ is a weight vector that sums to 1.
#
# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$
# is often a computationally demanding optimisation problem. However, when all the
# measures admit a multivariate Gaussian density, the barycentre
# $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions
# $$\bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)$$
#
# $$
# \bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)
# $$
#
# Identifying $\bar{S}$ is achieved through a fixed-point iterative update.
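A naive sketch of that fixed-point update is given below; it iterates the defining equation $(\star)$ directly and reuses the assumed `sqrtm_psd` helper from the distance sketch above. It is not the notebook's own implementation, and refined updates with better convergence guarantees exist.

```python
import jax.numpy as jnp

def sqrtm_psd(S):
    # Symmetric PSD matrix square root (same assumed helper as above).
    vals, vecs = jnp.linalg.eigh(S)
    return (vecs * jnp.sqrt(jnp.clip(vals, 0.0))) @ vecs.T

def barycentre_covariance(covariances, weights, n_iters=100):
    # Iterate S <- sum_t alpha_t (S^{1/2} S_t S^{1/2})^{1/2}, starting from the first covariance.
    S_bar = covariances[0]
    for _ in range(n_iters):
        S_half = sqrtm_psd(S_bar)
        S_bar = sum(w * sqrtm_psd(S_half @ S @ S_half) for w, S in zip(weights, covariances))
    return S_bar
```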
#
# ## Barycentre of Gaussian processes
@@ -265,7 +277,7 @@ def plot(
# distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre
# $\bar{\mu}$.
#
# ![](barycentre_gp.gif)
# ![](barycentres/barycentre_gp.gif)

# %% [markdown]
# ## System configuration
4 changes: 3 additions & 1 deletion docs/examples/classification.py
@@ -64,7 +64,9 @@
# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs
# $\boldsymbol{x}$ sampled uniformly on $(-1, 1)$ and corresponding binary outputs
#
# $$\boldsymbol{y} = 0.5 * \text{sign}(\cos(2 * + \boldsymbol{\epsilon})) + 0.5, \quad \boldsymbol{\epsilon} \sim \mathcal{N} \left(\textbf{0}, \textbf{I} * (0.05)^{2} \right).$$
# $$
# \boldsymbol{y} = 0.5 * \text{sign}(\cos(2 * \boldsymbol{x} + \boldsymbol{\epsilon})) + 0.5, \quad \boldsymbol{\epsilon} \sim \mathcal{N} \left(\textbf{0}, \textbf{I} * (0.05)^{2} \right).
# $$
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for
# later.
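A sketch of how such a dataset could be generated; the random seed and the exact sampling calls are assumptions rather than the notebook's verbatim code.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(123)                                            # assumed seed
key, xkey, ekey = jr.split(key, 3)

x = jr.uniform(xkey, shape=(100, 1), minval=-1.0, maxval=1.0)    # inputs on (-1, 1)
eps = 0.05 * jr.normal(ekey, shape=x.shape)                      # N(0, 0.05^2) noise
y = 0.5 * jnp.sign(jnp.cos(2.0 * x + eps)) + 0.5                 # binary labels in {0, 1}
```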
6 changes: 4 additions & 2 deletions docs/examples/intro_to_gps.py
@@ -16,10 +16,12 @@
# $\mathbf{y}$ for which we construct a model. The parameters $\theta$ of our
# model are unknown, and our goal is to conduct inference to determine their
# range of likely values. To achieve this, we apply Bayes' theorem
#
# \begin{align}
# \label{eq:BayesTheorem}
# p(\theta\,|\, \mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta}\,,
# \end{align}
#
# where $p(\mathbf{y}\,|\,\theta)$ denotes the _likelihood_, or model, and
# quantifies how likely the observed dataset $\mathbf{y}$ is, given the
# parameter estimate $\theta$. The _prior_ distribution $p(\theta)$ reflects our
@@ -386,7 +388,7 @@
# the set of _test points_.
# This process is visualised below
#
# ![](generating_process.png)
# ![](intro_to_gps/generating_process.png)
#
# As we shall go on to see, GPs offer an appealing workflow for scenarios such
# as this, all under a Bayesian framework.
@@ -499,7 +501,7 @@
# Optimising with respect to the marginal log-likelihood balances these two
# objectives when identifying the optimal solution, as visualised below.
#
# ![](decomposed_mll.png)
# ![](intro_to_gps/decomposed_mll.png)
#
# ## Conclusions
#
2 changes: 1 addition & 1 deletion docs/examples/likelihoods_guide.py
@@ -203,7 +203,7 @@
#
# The final method that is associated with a likelihood function in GPJax is the
# expected log-likelihood. This term is evaluated in the
# [stochastic variational Gaussian process](uncollaped_vi.py) in the ELBO term. For a
# [stochastic variational Gaussian process](uncollapsed_vi.py) in the ELBO term. For a
# variational approximation $q(f)= \mathcal{N}(f\mid m, S)$, the ELBO can be written as
# $$
# \begin{align}
60 changes: 24 additions & 36 deletions docs/javascripts/katex.js
@@ -1,38 +1,26 @@
(function () {
'use strict';
document$.subscribe(({ body }) => {
renderMathInElement(body, {
delimiters: [
{ left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false },
{ left: "\\[", right: "\\]", display: true }
],
})
})

var katexMath = (function () {
var maths = document.querySelectorAll('.arithmatex'),
tex;

for (var i = 0; i < maths.length; i++) {
tex = maths[i].textContent || maths[i].innerText;
if (tex.startsWith('\\(') && tex.endsWith('\\)')) {
katex.render(tex.slice(2, -2), maths[i], {'displayMode': false});
} else if (tex.startsWith('\\[') && tex.endsWith('\\]')) {
katex.render(tex.slice(2, -2), maths[i], {'displayMode': true});
}
}
});

(function () {
var onReady = function onReady(fn) {
if (document.addEventListener) {
document.addEventListener("DOMContentLoaded", fn);
} else {
document.attachEvent("onreadystatechange", function () {
if (document.readyState === "interactive") {
fn();
}
});
}
};

onReady(function () {
if (typeof katex !== "undefined") {
katexMath();
}
});
})();

}());
// document.addEventListener("DOMContentLoaded", function() {
// renderMathInElement(document.body, {
// // customised options
// // • auto-render specific keys, e.g.:
// delimiters: [
// {left: '$$', right: '$$', display: true},
// {left: '$', right: '$', display: false},
// {left: '\\(', right: '\\)', display: false},
// {left: '\\[', right: '\\]', display: true}
// ],
// // • rendering keys, e.g.:
// throwOnError : false
// })
// })
20 changes: 20 additions & 0 deletions docs/scripts/gen_examples.py
@@ -0,0 +1,20 @@
from pathlib import Path
import subprocess

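# EXECUTE runs each example while converting it; ALLOW_ERRORS lets execution
# continue past failing cells; EXCLUDE lists scripts that should be skipped.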
EXECUTE = False
EXCLUDE = ["docs/examples/utils.py"]
ALLOW_ERRORS = False


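# Convert every example script under docs/examples/ into a Markdown file next to it.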
for file in Path("docs/").glob("examples/*.py"):
if file.as_posix() in EXCLUDE:
continue

out_file = file.with_suffix(".md")

command = "jupytext --to markdown "
command += f"{'--execute ' if EXECUTE else ''}"
command += f"{'--allow-errors ' if ALLOW_ERRORS else ''}"
command += f"{file} --output {out_file}"

subprocess.run(command, shell=True, check=False)
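The generated `.md` files land alongside the source scripts, which is why `docs/examples/*.md` is added to `.gitignore` above; presumably the script is invoked from the repository root, e.g. `python docs/scripts/gen_examples.py`.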
38 changes: 19 additions & 19 deletions docs/sharp_bits.md
@@ -53,9 +53,9 @@ Parameters such as the kernel's lengthscale or variance have their support defined on
a constrained subset of the real-line. During gradient-based optimisation, as we
approach the set's boundary, it becomes possible that we could step outside of the
set's support and introduce a numerical and mathematical error into our model. For
example, consider the lengthscale parameter $`\ell`$, which we know must be strictly
positive. If at $`t^{\text{th}}`$ iterate, our current estimate of $`\ell`$ was
0.02 and our derivative informed us that $`\ell`$ should decrease, then if our
example, consider the lengthscale parameter $\ell$, which we know must be strictly
positive. If at the $t^{\text{th}}$ iterate, our current estimate of $\ell$ was
0.02 and our derivative informed us that $\ell$ should decrease, then if our
learning rate is greater than 0.03, we would end up with a negative value for $\ell$.
We visualise this issue below where the red cross denotes the invalid lengthscale value
that would be obtained, were we to optimise in the unconstrained parameter space.
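The usual remedy, sketched below, is to take gradient steps on an unconstrained value and map it through a softplus-style bijection so that the constrained parameter stays positive. This is an illustrative stand-alone snippet, not GPJax's internal bijector machinery.

```python
import jax.numpy as jnp

def softplus(u):
    # Unconstrained real value -> strictly positive value.
    return jnp.log1p(jnp.exp(u))

def inv_softplus(v):
    # Strictly positive value -> unconstrained real value.
    return jnp.log(jnp.expm1(v))

u = inv_softplus(jnp.array(0.02))  # current lengthscale, mapped to the unconstrained space
u = u - 0.03 * 1.0                 # gradient step with learning rate 0.03 (illustrative gradient of 1)
lengthscale = softplus(u)          # remains strictly positive, unlike a raw step on 0.02
```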
@@ -95,23 +95,23 @@ The Gram matrix of a kernel, a concept that we explore more in our
symmetric positive definite matrix. As such, we
have a range of tools at our disposal to make subsequent operations on the covariance
matrix faster. One of these tools is the Cholesky factorisation that uniquely decomposes
any symmetric positive-definite matrix $`\mathbf{\Sigma}`$ by
any symmetric positive-definite matrix $\mathbf{\Sigma}$ by

```math
\begin{align}
\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}\,,
\end{align}
```
where $`\mathbf{L}`$ is a lower triangular matrix.
where $\mathbf{L}$ is a lower triangular matrix.

We make use of this result in GPJax when solving linear systems of equations of the
form $`\mathbf{A}\boldsymbol{x} = \boldsymbol{b}`$. Whilst seemingly abstract at first,
form $\mathbf{A}\boldsymbol{x} = \boldsymbol{b}$. Whilst seemingly abstract at first,
such problems are frequently encountered when constructing Gaussian process models. One
such example arises in the regression setting when learning Gaussian process kernel
hyperparameters. Here we have labels
$`\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \sigma^2\mathbf{I})`$ with $`f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})`$ arising from zero-mean
Gaussian process prior and Gram matrix $`\mathbf{K}_{\boldsymbol{xx}}`$ at the inputs
$`\boldsymbol{x}`$. Here the marginal log-likelihood comprises the following form
$\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \sigma^2\mathbf{I})$ with $f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})$ arising from a zero-mean
Gaussian process prior and Gram matrix $\mathbf{K}_{\boldsymbol{xx}}$ at the inputs
$\boldsymbol{x}$. Here the marginal log-likelihood comprises the following form

```math
\begin{align}
@@ -120,8 +120,8 @@ $`\boldsymbol{x}`$. Here the marginal log-likelihood comprises the following for
```

and the goal of inference is to maximise the marginal log-likelihood with respect to the kernel hyperparameters (contained in the Gram
matrix $`\mathbf{K}_{\boldsymbol{xx}}`$) and likelihood hyperparameters (contained in the
noise covariance $`\sigma^2\mathbf{I}`$). Computing the marginal log-likelihood (and its
matrix $\mathbf{K}_{\boldsymbol{xx}}$) and likelihood hyperparameters (contained in the
noise covariance $\sigma^2\mathbf{I}$). Computing the marginal log-likelihood (and its
gradients) draws our attention to the term

```math
@@ -131,13 +131,13 @@ gradients), draws our attention to the term
```

from which we can see that a solution can be obtained by solving the corresponding system of
equations. By working with $`\mathbf{L} = \operatorname{chol}{\mathbf{A}}`$ instead of
$`\mathbf{A}`$, we save a significant amount of floating-point operations (flops) by
solving two triangular systems of equations (one for $`\mathbf{L}`$ and another for
$`\mathbf{L}^{\top}`$) instead of one dense system of equations. Solving two triangular systems
of equations has complexity $`\mathcal{O}(n^3/6)`$; a vast improvement compared to
regular solvers that have $`\mathcal{O}(n^3)`$ complexity in the number of datapoints
$`n`$.
equations. By working with $\mathbf{L} = \operatorname{chol}{\mathbf{A}}$ instead of
$\mathbf{A}$, we save a significant amount of floating-point operations (flops) by
solving two triangular systems of equations (one for $\mathbf{L}$ and another for
$\mathbf{L}^{\top}$) instead of one dense system of equations. Once the factor is available,
each triangular solve costs only $\mathcal{O}(n^2)$ operations, with the factorisation itself
requiring $\mathcal{O}(n^3/3)$; a vast improvement compared to regular dense solvers that have
$\mathcal{O}(n^3)$ complexity in the number of datapoints $n$.
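For illustration, the two-triangular-solve pattern looks as follows in JAX; this is generic linear algebra rather than GPJax's internal solver.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def cholesky_solve(A, b, jitter=1e-6):
    # Solve A x = b for symmetric positive-definite A via its lower Cholesky factor.
    L = jnp.linalg.cholesky(A + jitter * jnp.eye(A.shape[0]))
    z = solve_triangular(L, b, lower=True)        # forward substitution: L z = b
    return solve_triangular(L.T, z, lower=False)  # back substitution: L^T x = z
```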

### The Cholesky drawback

@@ -152,7 +152,7 @@ factor since this requires that the input matrix is _numerically_ positive-definit
negative eigenvalues, this violates the requirements and results in a "Cholesky failure".

To resolve this, we apply some numerical _jitter_ to the diagonals of any Gram matrix.
Typically this is very small, with $`10^{-6}`$ being the system default. However,
Typically this is very small, with $10^{-6}$ being the system default. However,
for some problems, this amount may need to be increased.
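A sketch of what increasing the jitter can look like in user code; GPJax's own defaults and retry behaviour may differ.

```python
import jax.numpy as jnp

def cholesky_with_jitter(gram, jitters=(1e-6, 1e-5, 1e-4)):
    # Try successively larger jitter values; in JAX a failed Cholesky surfaces
    # as NaNs in the returned factor rather than as an exception.
    for jitter in jitters:
        L = jnp.linalg.cholesky(gram + jitter * jnp.eye(gram.shape[0]))
        if not jnp.any(jnp.isnan(L)):
            return L
    raise ValueError("Gram matrix is not numerically positive definite.")
```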

## Slow-to-evaluate
4 changes: 2 additions & 2 deletions docs/stylesheets/extra.css
@@ -89,6 +89,6 @@ div.doc-contents:not(.first) {
}

/* Maximum space for text block */
.md-grid {
/* .md-grid {
max-width: 65%; /* or 100%, if you want to stretch to full-width */
}
/* }
12 changes: 6 additions & 6 deletions gpjax/dataset.py
@@ -28,7 +28,7 @@
class Dataset:
r"""Base class for datasets.
Parameters:
Args:
X: input data.
y: output data.
"""
@@ -37,8 +37,8 @@ class Dataset:
y: Optional[Num[Array, "N Q"]] = None

def __post_init__(self) -> None:
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible,
and provides warnings regarding the precision of $`X`$ and $`y`$."""
r"""Checks that the shapes of $X$ and $y$ are compatible,
and provides warnings regarding the precision of $X$ and $y$."""
_check_shape(self.X, self.y)
_check_precision(self.X, self.y)

@@ -75,7 +75,7 @@ def n(self) -> int:

@property
def in_dim(self) -> int:
r"""Dimension of the inputs, $`X`$."""
r"""Dimension of the inputs, $X$."""
return self.X.shape[1]

def tree_flatten(self):
@@ -89,7 +89,7 @@ def tree_unflatten(cls, aux_data, children):
def _check_shape(
X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
) -> None:
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
r"""Checks that the shapes of $X$ and $y$ are compatible."""
if X is not None and y is not None and X.shape[0] != y.shape[0]:
raise ValueError(
"Inputs, X, and outputs, y, must have the same number of rows."
@@ -110,7 +110,7 @@ def _check_shape(
def _check_precision(
X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
) -> None:
r"""Checks the precision of $`X`$ and $`y`."""
r"""Checks the precision of $X$ and $y`."""
if X is not None and X.dtype != jnp.float64:
warnings.warn(
"X is not of type float64. "
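A short usage sketch of the `Dataset` container documented above; the top-level import and the double-precision switch follow the documented behaviour, but treat the snippet as an assumed example rather than code copied from the docs.

```python
from jax import config

config.update("jax_enable_x64", True)  # avoids the float64 precision warning above

import jax.numpy as jnp
from gpjax import Dataset

X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(2.0 * jnp.pi * X)
D = Dataset(X=X, y=y)

print(D.n, D.in_dim)  # number of observations and input dimension
```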
2 changes: 1 addition & 1 deletion gpjax/decision_making/search_space.py
@@ -56,7 +56,7 @@ def dimensionality(self) -> int:

@dataclass
class ContinuousSearchSpace(AbstractSearchSpace):
"""The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $`D`$."""
"""The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $D$."""

lower_bounds: Float[Array, " D"]
upper_bounds: Float[Array, " D"]
