Skip to content

Commit

Permalink
tests pass new version of jax and chainconsumer
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminpope committed Apr 2, 2024
1 parent a87e80b commit 936f513
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 60 deletions.
107 changes: 75 additions & 32 deletions notebooks/01_Injection_Recovery_ControlPoints.ipynb

Large diffs are not rendered by default.

186 changes: 163 additions & 23 deletions notebooks/02_Fitting.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/ticktack/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def load_data(self, file_name, oversample=1008, burnin_oversample=1, burnin_time
self.burnin_oversample = burnin_oversample
self.offset = jnp.mean(self.d14c_data[:num_offset])
self.annual = jnp.arange(self.start, self.end + 1)
self.mask = jnp.in1d(self.annual, self.time_data)
self.mask = jnp.isin(self.annual, self.time_data)
self.time_data_fine = jnp.linspace(jnp.min(self.annual), jnp.max(self.annual) + 2,
(self.annual.size + 1) * self.oversample)
try:
Expand Down Expand Up @@ -1149,7 +1149,7 @@ def compile(self):
self.time_data_fine = jnp.linspace(jnp.min(self.annual), jnp.max(self.annual) + 2,
(self.annual.size + 1) * self.oversample)
for sf in self.MultiFitter:
sf.multi_mask = jnp.in1d(self.annual, sf.time_data)
sf.multi_mask = jnp.isin(self.annual, sf.time_data)
if self.production_model == 'control points':
self.control_points_time = jnp.arange(self.start, self.end)
self.production = self.multi_interp_gp
Expand Down
3 changes: 1 addition & 2 deletions src/ticktack/ticktack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from jax import jit
import jax
from functools import partial
from jax.config import config

from jax.lax import cond, dynamic_update_slice, fori_loop, dynamic_slice
import diffrax
Expand All @@ -17,7 +16,7 @@
import pkg_resources
from typing import Union

config.update("jax_enable_x64", True) # run in 64 bit by default or else you will lack the dynamic range required
jax.config.update("jax_enable_x64", True) # run in 64 bit by default or else you will lack the dynamic range required



Expand Down
2 changes: 1 addition & 1 deletion tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def SingleFitter_creation():
sf.annual = jnp.arange(sf.start, sf.end + 1)
sf.time_data_fine = jnp.linspace(jnp.min(sf.annual), jnp.max(sf.annual) + 2, (sf.annual.size + 1) * sf.oversample)
sf.offset = jnp.mean(sf.d14c_data[:4])
sf.mask = jnp.in1d(sf.annual, sf.time_data)
sf.mask = jnp.isin(sf.annual, sf.time_data)
sf.growth = sf.get_growth_vector("april-september")
return sf

Expand Down

0 comments on commit 936f513

Please sign in to comment.