Commit: WIP

henrymoss committed Sep 18, 2023
2 parents 76af7d8 + 6275d83, commit 7edd162

Showing 90 changed files with 9,011 additions and 3,134 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
@@ -19,7 +19,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
# Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
2 changes: 1 addition & 1 deletion .github/workflows/test_docs.yml
@@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
# Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
2 changes: 1 addition & 1 deletion .gitignore
@@ -151,4 +151,4 @@ package.json
package-lock.json
node_modules/

-docs/api
+docs/api
4 changes: 4 additions & 0 deletions README.md
@@ -90,6 +90,10 @@ jupytext --to py:percent example.ipynb
Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.

```python
from jax import config

config.update("jax_enable_x64", True)

import gpjax as gpx
import jax
from jax import grad, jit
File renamed without changes.
25 changes: 25 additions & 0 deletions benchmarks/asv.conf.json
@@ -0,0 +1,25 @@
{
"version": 1,
"project": "gpjax",
"project_url": "https://jaxgaussianprocesses.com/",
"repo": "..",
"install_command": ["python -mpip install {wheel_file}"],
"build_command": [
"PIP_NO_BUILD_ISOLATION=false python -m pip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
],
"branches": ["main"],
"matrix": {
"req": {
"poetry": [""]
}
},
"dvcs": "git",
"environment_type": "virtualenv",
"show_commit_url": "https://github.com/jaxgaussianprocesses/gpjax/commit/main",
"pythons": ["3.8"],
"benchmark_dir": ".",
"env_dir": ".asv/env",
"results_dir": ".asv/results",
"html_dir": ".asv/html",
"build_cache_size": 2
}
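The benchmark files that follow use the standard airspeed-velocity (asv) layout configured above: each suite is a class whose `param_names`/`params` define the parameter grid, whose `setup` receives one value per parameter and prepares state outside the timed region, and whose `time_*` methods are the timed bodies. A minimal sketch of that contract, with illustrative names that are not part of this commit:

```python
import jax.numpy as jnp


class ExampleSuite:
    # asv takes the cross-product of these lists and calls setup() once per
    # combination before timing each time_* method.
    param_names = ["n_data"]
    params = [[10, 100]]

    def setup(self, n_data: int):
        # State built here is excluded from the timed region.
        self.x = jnp.linspace(0.0, 1.0, n_data)

    def time_cumsum(self, n_data: int):
        # Only the body of a time_* method is measured.
        jnp.cumsum(self.x).block_until_ready()
```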
99 changes: 99 additions & 0 deletions benchmarks/kernels.py
@@ -0,0 +1,99 @@
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr

from gpjax import kernels


class Kernels:
param_names = ["n_data", "dimensionality"]
params = [[10, 100, 500, 1000, 2000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.uniform(
key=key, minval=-3.0, maxval=3.0, shape=(n_datapoints, n_dims)
)


class RBF(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.RBF(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class Matern12(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.Matern12(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class Matern32(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.Matern32(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class Matern52(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.Matern52(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class PoweredExponential(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.PoweredExponential(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class RationalQuadratic(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.RationalQuadratic(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class Polynomial(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.Polynomial(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class Linear(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.Linear(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)


class ArcCosine(Kernels):
def setup(self, n_datapoints: int, n_dims: int):
super().setup(n_datapoints, n_dims)
self.kernel = kernels.ArcCosine(active_dims=list(range(n_dims)))

def time_covfunc_call(self, n_datapoints: int, n_dims: int):
self.kernel.gram(self.X)
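Each suite above times a single Gram-matrix evaluation. A minimal standalone sketch of that call, using only the API exercised by the benchmarks (the sizes are illustrative):

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr

from gpjax import kernels

n_points, n_dims = 100, 2
X = jr.uniform(
    key=jr.PRNGKey(0), minval=-3.0, maxval=3.0, shape=(n_points, n_dims)
)

kernel = kernels.RBF(active_dims=list(range(n_dims)))
K = kernel.gram(X)  # covariance of X with itself, as timed in time_covfunc_call
```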
32 changes: 32 additions & 0 deletions benchmarks/linops.py
@@ -0,0 +1,32 @@
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
from sklearn.datasets import make_spd_matrix

from gpjax.linops import DenseLinearOperator


class LinOps:
param_names = ["n_data"]
params = [[10, 100, 200, 500, 1000]]

def setup(self, n_datapoints: int):
key = jr.PRNGKey(123)
self.X = jnp.asarray(make_spd_matrix(n_dim=n_datapoints, random_state=123))
self.y = jr.normal(key=key, shape=(n_datapoints, 1))
self.linop = DenseLinearOperator(matrix=self.X)

def time_root(self, n_datapoints: int):
self.linop.to_root()

def time_inverse(self, n_datapoints: int):
self.linop.inverse()

def time_logdet(self, n_datapoints: int):
self.linop.log_det()

def time_solve(self, n_datapoints: int):
self.linop.solve(self.y)
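A standalone sketch of the four operations this suite measures, mirroring `setup` and assuming only the `DenseLinearOperator` methods already used above:

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr
from sklearn.datasets import make_spd_matrix

from gpjax.linops import DenseLinearOperator

n = 100
A = jnp.asarray(make_spd_matrix(n_dim=n, random_state=123))
y = jr.normal(key=jr.PRNGKey(123), shape=(n, 1))

linop = DenseLinearOperator(matrix=A)
root = linop.to_root()    # matrix root of A (time_root)
x = linop.solve(y)        # solution of A x = y (time_solve)
logdet = linop.log_det()  # log-determinant of A (time_logdet)
inv = linop.inverse()     # inverse of A (time_inverse)
```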
87 changes: 87 additions & 0 deletions benchmarks/objectives.py
@@ -0,0 +1,87 @@
from jax import config

config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx


class Gaussian:
param_names = [
"n_data",
"n_dims",
]
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
self.objective.step(self.posterior, self.data).block_until_ready()

def time_grad(self, n_datapoints: int, n_dims: int):
jax.block_until_ready(jax.grad(self.objective.step)(self.posterior, self.data))


class Bernoulli:
param_names = [
"n_data",
"n_dims",
]
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
self.y = jnp.where(jnp.sin(self.X[:, :1]) > 0, 1, 0)
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
self.objective.step(self.posterior, self.data).block_until_ready()

def time_grad(self, n_datapoints: int, n_dims: int):
jax.block_until_ready(jax.grad(self.objective.step)(self.posterior, self.data))


class Poisson:
param_names = [
"n_data",
"n_dims",
]
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
self.y = jr.poisson(key, jnp.exp(f(self.X)))
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
self.objective.step(self.posterior, self.data).block_until_ready()

def time_grad(self, n_datapoints: int, n_dims: int):
jax.block_until_ready(jax.grad(self.objective.step)(self.posterior, self.data))
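The two timed paths in each suite are one objective evaluation and one gradient. A minimal sketch of those calls outside asv, mirroring the Gaussian suite's construction:

```python
from jax import config

config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(123)
X = jr.normal(key=key, shape=(200, 1))
data = gpx.Dataset(X=X, y=jnp.sin(X[:, :1]))

prior = gpx.Prior(
    kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=data.n)

objective = gpx.ConjugateMLL()
value = objective.step(posterior, data)            # what time_eval measures
grads = jax.grad(objective.step)(posterior, data)  # what time_grad measures
```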
81 changes: 81 additions & 0 deletions benchmarks/predictions.py
@@ -0,0 +1,81 @@
from jax import config

config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx


class Gaussian:
param_names = [
"n_test",
"n_dims",
]
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
self.xtest = jr.normal(key=subkey, shape=(n_test, n_dims))

def time_predict(self, n_test: int, n_dims: int):
self.posterior.predict(test_inputs=self.xtest, train_data=self.data)


class Bernoulli:
param_names = [
"n_test",
"n_dims",
]
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.y = jnp.array(jnp.where(self.y > 0, 1, 0), dtype=jnp.float64)
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
self.xtest = jr.normal(key=subkey, shape=(n_test, n_dims))

def time_predict(self, n_test: int, n_dims: int):
self.posterior.predict(test_inputs=self.xtest, train_data=self.data)


class Poisson:
param_names = [
"n_test",
"n_dims",
]
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
self.y = jnp.array(jr.poisson(key, jnp.exp(f(self.X))), dtype=jnp.float64)
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
self.xtest = jr.normal(key=subkey, shape=(n_test, n_dims))

def time_predict(self, n_test: int, n_dims: int):
self.posterior.predict(test_inputs=self.xtest, train_data=self.data)
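Each suite above times one call to `posterior.predict` over a batch of test inputs. A minimal sketch of that call for the conjugate (Gaussian) case, mirroring the setup above; the final `.mean()` call is an assumption about the returned distribution, made for illustration only:

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(123)
X = jr.normal(key=key, shape=(100, 1))
data = gpx.Dataset(X=X, y=jnp.sin(X[:, :1]))

prior = gpx.Prior(
    kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=data.n)

key, subkey = jr.split(key)
xtest = jr.normal(key=subkey, shape=(500, 1))

latent_dist = posterior.predict(test_inputs=xtest, train_data=data)  # timed call
mean = latent_dist.mean()  # assumed distribution interface, illustration only
```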