Skip to content

Commit

Permalink
Adding compatibility for recent PyMC versions (#29)
Browse files Browse the repository at this point in the history
* starting to add aesara support

* updating pre-commit hooks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* a few missing theano imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adding fallback for tutorials

* fixing tutorial compat

* updating coverage workflow

* don't run coveralls on mac and windows

* matrix.os syntax

* adding python_requires

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] authored Mar 11, 2021
1 parent cee9a0f commit a094aaf
Show file tree
Hide file tree
Showing 23 changed files with 141 additions and 221 deletions.
71 changes: 18 additions & 53 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,11 @@ on:
branches: [main]

jobs:
style:
name: "style"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
submodules: true
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install isort black black_nbconvert
- name: Check the style
run: |
isort -c python
black --check python
black_nbconvert --check .
build:
name: "py${{ matrix.python-version }} / ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.7, 3.8]
python-version: ["3.8", "3.9"]
os: [ubuntu-latest, windows-latest, macos-latest]
steps:
- uses: actions/checkout@v2
Expand All @@ -48,20 +25,20 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install --use-feature=2020-resolver -e ".[test]"
python -m pip install -e ".[test]"
env:
DISTUTILS_USE_SDK: 1
MSSdk: 1
- name: Run the unit tests
run: python -m pytest --cov celerite2 python/test
- uses: actions/upload-artifact@v2
if: ${{ matrix.os != 'windows-latest' }}
- name: Coveralls
if: startsWith(matrix.os, 'ubuntu')
uses: AndreMiras/coveralls-python-action@v20201129
with:
name: cov-${{ matrix.os }}-${{ matrix.python-version }}
path: .coverage
parallel: true
flag-name: Unit Tests

theano:
name: "theano"
runs-on: "ubuntu-latest"
steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -95,13 +72,13 @@ jobs:
- name: Run the unit tests
shell: bash -l {0}
run: python -m pytest --cov celerite2 python/test/theano
- uses: actions/upload-artifact@v2
- name: Coveralls
uses: AndreMiras/coveralls-python-action@v20201129
with:
name: cov-theano
path: .coverage
parallel: true
flag-name: Unit Tests

jax:
name: "jax"
runs-on: "ubuntu-latest"
steps:
- uses: actions/checkout@v2
Expand All @@ -121,29 +98,17 @@ jobs:
- name: Run the unit tests
shell: bash -l {0}
run: python -m pytest --cov celerite2 python/test/jax
- uses: actions/upload-artifact@v2
- name: Coveralls
uses: AndreMiras/coveralls-python-action@v20201129
with:
name: cov-jax
path: .coverage
parallel: true
flag-name: Unit Tests

coverage:
name: coverage
needs: [build, theano, jax]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Coveralls Finished
uses: AndreMiras/coveralls-python-action@v20201129
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Download all artifacts
uses: actions/download-artifact@v2
- name: Merge and upload coverage
run: |
python -m pip install coveralls
find . -name \.coverage -exec coverage combine --append {} \;
coveralls
env:
GITHUB_TOKEN: ${{ secrets.github_token }}
parallel-finished: true
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install --use-feature=2020-resolver ".[tutorials]"
python -m pip install ".[tutorials]"
- name: Get theano compiledir
id: compiledir
Expand Down
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v3.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: debug-statements

- repo: https://github.com/timothycrosley/isort
rev: 5.0.4
- repo: https://github.com/PyCQA/isort
rev: "5.7.0"
hooks:
- id: isort
args: []
additional_dependencies: [toml]
exclude: docs/tutorials
- id: isort
args: []
additional_dependencies: [toml]
exclude: docs/tutorials

- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black

- repo: https://github.com/dfm/black_nbconvert
rev: stable
rev: v0.2.0
hooks:
- id: black_nbconvert
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ _celerite_ is an algorithm for fast and scalable Gaussian Process (GP)
Regression in one dimension and this library, _celerite2_ is a re-write of the
original [celerite project](https://celerite.readthedocs.io) to improve
numerical stability and integration with various machine learning frameworks. Documentation
for this version can be found [here](https://celerite2.readthedocs.io/en/latest/).
for this version can be found [here](https://celerite2.readthedocs.io/en/latest/).
This new implementation includes interfaces in Python and C++, with full support for
Theano/PyMC3 and JAX.

Expand Down
2 changes: 1 addition & 1 deletion c++/include/celerite2/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
#include "reverse.hpp"
#include "interface.hpp"

#endif // _CELERITE2_CORE_HPP_DEFINED_
#endif // _CELERITE2_CORE_HPP_DEFINED_
1 change: 0 additions & 1 deletion c++/test/catch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17594,4 +17594,3 @@ using Catch::Detail::Approx;
// end catch_reenable_warnings.h
// end catch.hpp
#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED

92 changes: 62 additions & 30 deletions docs/tutorials/first.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.6.0
# jupytext_version: 1.10.3
# kernelspec:
# display_name: Python 3
# language: python
Expand Down Expand Up @@ -248,28 +248,34 @@ def log_prob(params, gp):
with pm.Model() as model:

mean = pm.Normal("mean", mu=0.0, sigma=prior_sigma)
jitter = pm.Lognormal("jitter", mu=0.0, sigma=prior_sigma)

sigma1 = pm.Lognormal("sigma1", mu=0.0, sigma=prior_sigma)
rho1 = pm.Lognormal("rho1", mu=0.0, sigma=prior_sigma)
tau = pm.Lognormal("tau", mu=0.0, sigma=prior_sigma)
term1 = theano_terms.SHOTerm(sigma=sigma1, rho=rho1, tau=tau)
log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=prior_sigma)

log_sigma1 = pm.Normal("log_sigma1", mu=0.0, sigma=prior_sigma)
log_rho1 = pm.Normal("log_rho1", mu=0.0, sigma=prior_sigma)
log_tau = pm.Normal("log_tau", mu=0.0, sigma=prior_sigma)
term1 = theano_terms.SHOTerm(
sigma=pm.math.exp(log_sigma1),
rho=pm.math.exp(log_rho1),
tau=pm.math.exp(log_tau),
)

sigma2 = pm.Lognormal("sigma2", mu=0.0, sigma=prior_sigma)
rho2 = pm.Lognormal("rho2", mu=0.0, sigma=prior_sigma)
term2 = theano_terms.SHOTerm(sigma=sigma2, rho=rho2, Q=0.25)
log_sigma2 = pm.Normal("log_sigma2", mu=0.0, sigma=prior_sigma)
log_rho2 = pm.Normal("log_rho2", mu=0.0, sigma=prior_sigma)
term2 = theano_terms.SHOTerm(
sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25
)

kernel = term1 + term2
gp = celerite2.theano.GaussianProcess(kernel, mean=mean)
gp.compute(t, diag=yerr ** 2 + jitter, quiet=True)
gp.compute(t, diag=yerr ** 2 + pm.math.exp(log_jitter), quiet=True)
gp.marginal("obs", observed=y)

pm.Deterministic("psd", kernel.get_psd(omega))

trace = pm.sample(
tune=1000,
draws=1000,
target_accept=0.8,
target_accept=0.9,
init="adapt_full",
cores=2,
chains=2,
Expand Down Expand Up @@ -304,6 +310,7 @@ def log_prob(params, gp):
config.update("jax_enable_x64", True)

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
Expand All @@ -315,20 +322,24 @@ def log_prob(params, gp):

def numpyro_model(t, yerr, y=None):
mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
jitter = numpyro.sample("jitter", dist.LogNormal(0.0, prior_sigma))
log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))

sigma1 = numpyro.sample("sigma1", dist.LogNormal(0.0, prior_sigma))
rho1 = numpyro.sample("rho1", dist.LogNormal(0.0, prior_sigma))
tau = numpyro.sample("tau", dist.LogNormal(0.0, prior_sigma))
term1 = jax_terms.UnderdampedSHOTerm(sigma=sigma1, rho=rho1, tau=tau)
log_sigma1 = numpyro.sample("log_sigma1", dist.Normal(0.0, prior_sigma))
log_rho1 = numpyro.sample("log_rho1", dist.Normal(0.0, prior_sigma))
log_tau = numpyro.sample("log_tau", dist.Normal(0.0, prior_sigma))
term1 = jax_terms.UnderdampedSHOTerm(
sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
)

sigma2 = numpyro.sample("sigma2", dist.LogNormal(0.0, prior_sigma))
rho2 = numpyro.sample("rho2", dist.LogNormal(0.0, prior_sigma))
term2 = jax_terms.OverdampedSHOTerm(sigma=sigma2, rho=rho2, Q=0.25)
log_sigma2 = numpyro.sample("log_sigma2", dist.Normal(0.0, prior_sigma))
log_rho2 = numpyro.sample("log_rho2", dist.Normal(0.0, prior_sigma))
term2 = jax_terms.OverdampedSHOTerm(
sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25
)

kernel = term1 + term2
gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
gp.compute(t, diag=yerr ** 2 + jitter, check_sorted=False)
gp.compute(t, diag=yerr ** 2 + jnp.exp(log_jitter), check_sorted=False)

numpyro.sample("obs", gp.numpyro_dist(), obs=y)
numpyro.deterministic("psd", kernel.get_psd(omega))
Expand Down Expand Up @@ -379,9 +390,6 @@ def numpyro_model(t, yerr, y=None):
"log_jitter",
],
)
for k in emcee_data.posterior.data_vars:
if k.startswith("log_"):
emcee_data.posterior[k[4:]] = np.exp(emcee_data.posterior[k])

with model:
pm_data = az.from_pymc3(trace)
Expand All @@ -390,21 +398,21 @@ def numpyro_model(t, yerr, y=None):

bins = np.linspace(1.5, 2.75, 25)
plt.hist(
np.asarray((emcee_data.posterior["rho1"].T)).flatten(),
np.exp(np.asarray((emcee_data.posterior["log_rho1"].T)).flatten()),
bins,
histtype="step",
density=True,
label="emcee",
)
plt.hist(
np.asarray((pm_data.posterior["rho1"].T)).flatten(),
np.exp(np.asarray((pm_data.posterior["log_rho1"].T)).flatten()),
bins,
histtype="step",
density=True,
label="PyMC3",
)
plt.hist(
np.asarray((numpyro_data.posterior["rho1"].T)).flatten(),
np.exp(np.asarray((numpyro_data.posterior["log_rho1"].T)).flatten()),
bins,
histtype="step",
density=True,
Expand All @@ -422,17 +430,41 @@ def numpyro_model(t, yerr, y=None):

az.summary(
emcee_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
var_names=[
"mean",
"log_sigma1",
"log_rho1",
"log_tau",
"log_sigma2",
"log_rho2",
"log_jitter",
],
)

az.summary(
pm_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
var_names=[
"mean",
"log_sigma1",
"log_rho1",
"log_tau",
"log_sigma2",
"log_rho2",
"log_jitter",
],
)

az.summary(
numpyro_data,
var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
var_names=[
"mean",
"log_sigma1",
"log_rho1",
"log_tau",
"log_sigma2",
"log_rho2",
"log_jitter",
],
)

# Overall these results are consistent, but the $\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.
Expand Down
16 changes: 16 additions & 0 deletions docs/tutorials/notebook_setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.6.0
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# +
"""isort:skip_file"""

get_ipython().magic('config InlineBackend.figure_format = "retina"')
Expand Down
2 changes: 1 addition & 1 deletion python/celerite2/theano/celerite2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

__all__ = ["GaussianProcess", "ConditionalDistribution"]
import aesara_theano_fallback.tensor as tt
import numpy as np
from theano import tensor as tt

from ..core import BaseConditionalDistribution, BaseGaussianProcess
from . import ops
Expand Down
2 changes: 1 addition & 1 deletion python/celerite2/theano/distribution.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

__all__ = ["CeleriteNormal"]
import aesara_theano_fallback.tensor as tt
import numpy as np
from theano import tensor as tt

try:
import pymc3 # noqa
Expand Down
Loading

0 comments on commit a094aaf

Please sign in to comment.