-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding XLA translation rules for JAX extension (#9)
* starting to implement XLA ops * full XLA implementation of factor * testing jax jit * abstracting the primitive build process * adding implementation of solve * adding norm op * adding dot_tril * adding matmul * adding conditional_mean * dealing with multiple outputs issue * adding numpyro tutorial * adding numpyro distribution and more tutorial details * removing extra import
- Loading branch information
Showing
11 changed files
with
1,056 additions
and
158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,22 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
__all__ = ["terms", "GaussianProcess"] | ||
import logging | ||
|
||
from . import terms | ||
from .celerite2 import GaussianProcess | ||
logger = logging.getLogger(__name__) | ||
|
||
from jax.config import config # noqa isort:skip | ||
|
||
if not config.read("jax_enable_x64"): | ||
logger.warning( | ||
"celerite2.jax only works with dtype float64. " | ||
"To enable, run (before importing jax or celerite2.jax):\n" | ||
">>> from jax.config import config\n" | ||
">>> config.update('jax_enable_x64', True)" | ||
) | ||
|
||
|
||
__all__ = ["terms", "GaussianProcess", "CeleriteNormal"] | ||
|
||
from . import terms # noqa isort:skip | ||
from .celerite2 import GaussianProcess # noqa isort:skip | ||
from .distribution import CeleriteNormal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
__all__ = ["CeleriteNormal"] | ||
from jax import numpy as jnp | ||
from jax import random as random | ||
|
||
try: | ||
import numpyro # noqa | ||
except ImportError: | ||
|
||
class CeleriteNormal: | ||
def __init__(self, *args, **kwargs): | ||
raise ImportError( | ||
"pymc3 is required to use the CeleriteNormal distribution" | ||
) | ||
|
||
|
||
else: | ||
from numpyro import distributions as dist | ||
|
||
class CeleriteNormal(dist.Distribution): | ||
support = dist.constraints.real_vector | ||
|
||
def __init__(self, gp, validate_args=None): | ||
self.gp = gp | ||
super().__init__( | ||
batch_shape=(), | ||
event_shape=jnp.shape(self.gp._t), | ||
validate_args=validate_args, | ||
) | ||
|
||
@dist.util.validate_sample | ||
def log_prob(self, value): | ||
return self.gp.log_likelihood(value) | ||
|
||
def sample(self, key, sample_shape=()): | ||
eps = random.normal(key, shape=self.event_shape + sample_shape) | ||
return jnp.moveaxis(self.gp.dot_tril(eps), 0, -1) |
Oops, something went wrong.