Skip to content

Commit

Permalink
Adding XLA translation rules for JAX extension (#9)
Browse files Browse the repository at this point in the history
* starting to implement XLA ops

* full XLA implementation of factor

* testing jax jit

* abstracting the primitive build process

* adding implementation of solve

* adding norm op

* adding dot_tril

* adding matmul

* adding conditional_mean

* dealing with multiple outputs issue

* adding numpyro tutorial

* adding numpyro distribution and more tutorial details

* removing extra import
  • Loading branch information
dfm authored Oct 29, 2020
1 parent 795a21c commit 476a442
Show file tree
Hide file tree
Showing 11 changed files with 1,056 additions and 158 deletions.
178 changes: 164 additions & 14 deletions docs/tutorials/first.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.5.2
# jupytext_version: 1.6.0
# kernelspec:
# display_name: Python 3
# language: python
Expand Down Expand Up @@ -36,7 +36,10 @@
np.random.seed(42)

t = np.sort(
np.append(np.random.uniform(0, 3.8, 57), np.random.uniform(5.5, 10, 68),)
np.append(
np.random.uniform(0, 3.8, 57),
np.random.uniform(5.5, 10, 68),
)
) # The input coordinates must be sorted
yerr = np.random.uniform(0.08, 0.22, len(t))
y = (
Expand Down Expand Up @@ -180,11 +183,13 @@ def neg_log_like(params, gp):
# +
import emcee

prior_sigma = 2.0


def log_prob(params, gp):
gp = set_params(params, gp)
return (
gp.log_likelihood(y) - 0.5 * np.sum((params / 5.0) ** 2),
gp.log_likelihood(y) - 0.5 * np.sum((params / prior_sigma) ** 2),
gp.kernel.get_psd(omega),
)

Expand All @@ -195,6 +200,8 @@ def log_prob(params, gp):
coords.shape[0], coords.shape[1], log_prob, args=(gp,)
)
state = sampler.run_mcmc(coords, 2000, progress=True)
sampler.reset()
state = sampler.run_mcmc(state, 5000, progress=True)
# -

# After running our MCMC, we can plot the predictions that the model makes for a handful of samples from the chain.
Expand Down Expand Up @@ -239,16 +246,16 @@ def log_prob(params, gp):

with pm.Model() as model:

mean = pm.Normal("mean", mu=0.0, sigma=5.0)
jitter = pm.Lognormal("jitter", mu=0.0, sigma=5.0)
mean = pm.Normal("mean", mu=0.0, sigma=prior_sigma)
jitter = pm.Lognormal("jitter", mu=0.0, sigma=prior_sigma)

sigma1 = pm.Lognormal("sigma1", mu=0.0, sigma=5.0)
rho1 = pm.Lognormal("rho1", mu=0.0, sigma=5.0, testval=np.exp(soln.x[2]))
tau = pm.Lognormal("tau", mu=0.0, sigma=5.0)
sigma1 = pm.Lognormal("sigma1", mu=0.0, sigma=prior_sigma)
rho1 = pm.Lognormal("rho1", mu=0.0, sigma=prior_sigma)
tau = pm.Lognormal("tau", mu=0.0, sigma=prior_sigma)
term1 = theano_terms.SHOTerm(sigma=sigma1, rho=rho1, tau=tau)

sigma2 = pm.Lognormal("sigma2", mu=0.0, sigma=5.0)
rho2 = pm.Lognormal("rho2", mu=0.0, sigma=5.0)
sigma2 = pm.Lognormal("sigma2", mu=0.0, sigma=prior_sigma)
rho2 = pm.Lognormal("rho2", mu=0.0, sigma=prior_sigma)
term2 = theano_terms.SHOTerm(sigma=sigma2, rho=rho2, Q=0.25)

kernel = term1 + term2
Expand All @@ -261,9 +268,10 @@ def log_prob(params, gp):
trace = pm.sample(
tune=1000,
draws=1000,
target_accept=0.95,
target_accept=0.8,
init="adapt_full",
cores=1,
cores=2,
chains=2,
random_seed=34923,
)
# -
Expand All @@ -284,5 +292,147 @@ def log_prob(params, gp):
_ = plt.title("posterior psd using PyMC3")
# -

# In this particular case, the runtime with PyMC3 is somewhat longer than with emcee, but it also produced more effective samples.
# If we were to run a higher dimensional model (with more parameters) then PyMC3 will generally be substantially faster.
# ## Posterior inference using numpyro
#
# Since celerite2 includes support for JAX as well as Theano, you can also use tools like [numpyro](https://github.com/pyro-ppl/numpyro) for inference.
# The following is similar to previous PyMC3 example, but the main difference is that (for technical reasons related to how JAX works) `SHOTerm`s cannot be used in combination with `jax.jit`, so we need to explicitly specify the terms as "underdamped" (`UnderdampedSHOTerm`) or "overdamped" (`OverdampedSHOTerm`).

# +
from jax.config import config

config.update("jax_enable_x64", True)

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import celerite2.jax
from celerite2.jax import terms as jax_terms


def numpyro_model(t, yerr, y=None):
mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
jitter = numpyro.sample("jitter", dist.LogNormal(0.0, prior_sigma))

sigma1 = numpyro.sample("sigma1", dist.LogNormal(0.0, prior_sigma))
rho1 = numpyro.sample("rho1", dist.LogNormal(0.0, prior_sigma))
tau = numpyro.sample("tau", dist.LogNormal(0.0, prior_sigma))
term1 = jax_terms.UnderdampedSHOTerm(sigma=sigma1, rho=rho1, tau=tau)

sigma2 = numpyro.sample("sigma2", dist.LogNormal(0.0, prior_sigma))
rho2 = numpyro.sample("rho2", dist.LogNormal(0.0, prior_sigma))
term2 = jax_terms.OverdampedSHOTerm(sigma=sigma2, rho=rho2, Q=0.25)

kernel = term1 + term2
gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
gp.compute(t, diag=yerr ** 2 + jitter, check_sorted=False)

numpyro.sample("obs", gp.numpyro_dist(), obs=y)
numpyro.deterministic("psd", kernel.get_psd(omega))


nuts_kernel = NUTS(numpyro_model, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2)
rng_key = random.PRNGKey(34923)
# %time mcmc.run(rng_key, t, yerr, y=y)
# -

# This runtime was similar to the PyMC3 result from above, and (as we'll see below) the convergence is also similar.
# Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!
#
# As above, we can plot the posterior expectations for the power spectrum:

# +
psds = np.asarray(mcmc.get_samples()["psd"])

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using numpyro")
# -

# ## Comparison
#
# Finally, let's compare the results of these different inference methods a bit more quantitaively.
# First, let's look at the posterior constraint on the period of the underdamped harmonic oscillator, the effective period of the oscillatory signal.

# +
import arviz as az

emcee_data = az.from_emcee(
sampler,
var_names=[
"mean",
"log_sigma1",
"log_rho1",
"log_tau",
"log_sigma2",
"log_rho2",
"log_jitter",
],
)
for k in emcee_data.posterior.data_vars:
if k.startswith("log_"):
emcee_data.posterior[k[4:]] = np.exp(emcee_data.posterior[k])

with model:
pm_data = az.from_pymc3(trace)

numpyro_data = az.from_numpyro(mcmc)

bins = np.linspace(1.5, 2.75, 25)
plt.hist(
np.asarray((emcee_data.posterior["rho1"].T)).flatten(),
bins,
histtype="step",
density=True,
label="emcee",
)
plt.hist(
np.asarray((pm_data.posterior["rho1"].T)).flatten(),
bins,
histtype="step",
density=True,
label="PyMC3",
)
plt.hist(
np.asarray((numpyro_data.posterior["rho1"].T)).flatten(),
bins,
histtype="step",
density=True,
label="numpyro",
)
plt.legend()
plt.yticks([])
plt.xlabel(r"$\rho_1$")
_ = plt.ylabel(r"$p(\rho_1)$")
# -

# That looks pretty consistent.
#
# Next we can look at the [ArviZ](https://arviz-devs.github.io/arviz/) summary for each method to see how the posterior expectations and convergence diagnostics look.

az.summary(
emcee_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)

az.summary(
pm_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)

az.summary(
numpyro_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)

# Overall these results are consistent, but the $\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.
# Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework.
18 changes: 9 additions & 9 deletions python/celerite2/backprop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ auto solve_fwd(py::array_t<double, py::array::c_style> U, py::array_t<double, py
CONST_MATRIX(Eigen::Dynamic, Y_, Ybuf, N, nrhs); \
MATRIX(Eigen::Dynamic, X_, Xbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, Z_, Zbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, J *nrhs); \
MATRIX(Eigen::Dynamic, G_, Gbuf, N, J *nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, (J * nrhs)); \
MATRIX(Eigen::Dynamic, G_, Gbuf, N, (J * nrhs)); \
celerite2::core::solve(U_, P_, d_, W_, Y_, X_, Z_, F_, G_); \
} \
}
Expand Down Expand Up @@ -162,8 +162,8 @@ auto solve_rev(py::array_t<double, py::array::c_style> U, py::array_t<double, py
CONST_MATRIX(Eigen::Dynamic, Y_, Ybuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, X_, Xbuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, Z_, Zbuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, F_, Fbuf, N, J *nrhs); \
CONST_MATRIX(Eigen::Dynamic, G_, Gbuf, N, J *nrhs); \
CONST_MATRIX(Eigen::Dynamic, F_, Fbuf, N, (J * nrhs)); \
CONST_MATRIX(Eigen::Dynamic, G_, Gbuf, N, (J * nrhs)); \
\
CONST_MATRIX(Eigen::Dynamic, bX_, bXbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, bY_, bYbuf, N, nrhs); \
Expand Down Expand Up @@ -200,7 +200,7 @@ auto dot_tril_fwd(py::array_t<double, py::array::c_style> U, py::array_t<double,
} else { \
CONST_MATRIX(Eigen::Dynamic, Y_, Ybuf, N, nrhs); \
MATRIX(Eigen::Dynamic, Z_, Zbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, J *nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, (J * nrhs)); \
celerite2::core::dot_tril(U_, P_, d_, W_, Y_, Z_, F_); \
} \
}
Expand Down Expand Up @@ -370,8 +370,8 @@ auto matmul_fwd(py::array_t<double, py::array::c_style> d, py::array_t<double, p
CONST_MATRIX(Eigen::Dynamic, Y_, Ybuf, N, nrhs); \
MATRIX(Eigen::Dynamic, X_, Xbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, Z_, Zbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, J *nrhs); \
MATRIX(Eigen::Dynamic, G_, Gbuf, N, J *nrhs); \
MATRIX(Eigen::Dynamic, F_, Fbuf, N, (J * nrhs)); \
MATRIX(Eigen::Dynamic, G_, Gbuf, N, (J * nrhs)); \
celerite2::core::matmul(d_, U_, W_, P_, Y_, X_, Z_, F_, G_); \
} \
}
Expand Down Expand Up @@ -429,8 +429,8 @@ auto matmul_rev(py::array_t<double, py::array::c_style> d, py::array_t<double, p
CONST_MATRIX(Eigen::Dynamic, Y_, Ybuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, X_, Xbuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, Z_, Zbuf, N, nrhs); \
CONST_MATRIX(Eigen::Dynamic, F_, Fbuf, N, J *nrhs); \
CONST_MATRIX(Eigen::Dynamic, G_, Gbuf, N, J *nrhs); \
CONST_MATRIX(Eigen::Dynamic, F_, Fbuf, N, (J * nrhs)); \
CONST_MATRIX(Eigen::Dynamic, G_, Gbuf, N, (J * nrhs)); \
\
CONST_MATRIX(Eigen::Dynamic, bX_, bXbuf, N, nrhs); \
MATRIX(Eigen::Dynamic, bY_, bYbuf, N, nrhs); \
Expand Down
22 changes: 19 additions & 3 deletions python/celerite2/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
# -*- coding: utf-8 -*-

__all__ = ["terms", "GaussianProcess"]
import logging

from . import terms
from .celerite2 import GaussianProcess
logger = logging.getLogger(__name__)

from jax.config import config # noqa isort:skip

if not config.read("jax_enable_x64"):
logger.warning(
"celerite2.jax only works with dtype float64. "
"To enable, run (before importing jax or celerite2.jax):\n"
">>> from jax.config import config\n"
">>> config.update('jax_enable_x64', True)"
)


__all__ = ["terms", "GaussianProcess", "CeleriteNormal"]

from . import terms # noqa isort:skip
from .celerite2 import GaussianProcess # noqa isort:skip
from .distribution import CeleriteNormal
7 changes: 4 additions & 3 deletions python/celerite2/jax/celerite2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .. import backprop, driver
from ..ext import BaseGaussianProcess
from . import ops
from . import distribution, ops


class GaussianProcess(BaseGaussianProcess):
Expand Down Expand Up @@ -33,8 +33,6 @@ def do_compute(self, quiet):
)

def check_sorted(self, t):
if np.any(np.diff(t) < 0.0):
raise ValueError("the input coordinates must be sorted")
return t

def do_solve(self, y):
Expand All @@ -54,3 +52,6 @@ def tensordot(self, a, b):

def diagdot(self, a, b):
return np.einsum("ij,ij->j", a, b)

def numpyro_dist(self):
return distribution.CeleriteNormal(self)
38 changes: 38 additions & 0 deletions python/celerite2/jax/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

__all__ = ["CeleriteNormal"]
from jax import numpy as jnp
from jax import random as random

try:
import numpyro # noqa
except ImportError:

class CeleriteNormal:
def __init__(self, *args, **kwargs):
raise ImportError(
"pymc3 is required to use the CeleriteNormal distribution"
)


else:
from numpyro import distributions as dist

class CeleriteNormal(dist.Distribution):
support = dist.constraints.real_vector

def __init__(self, gp, validate_args=None):
self.gp = gp
super().__init__(
batch_shape=(),
event_shape=jnp.shape(self.gp._t),
validate_args=validate_args,
)

@dist.util.validate_sample
def log_prob(self, value):
return self.gp.log_likelihood(value)

def sample(self, key, sample_shape=()):
eps = random.normal(key, shape=self.event_shape + sample_shape)
return jnp.moveaxis(self.gp.dot_tril(eps), 0, -1)
Loading

0 comments on commit 476a442

Please sign in to comment.