Skip to content

Commit

Permalink
Fix non-vectorized case w 21cmEMU
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielaBreitman committed Apr 10, 2024
1 parent 31b6e0e commit 6ea1fe9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
15 changes: 11 additions & 4 deletions src/py21cmmc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,9 +1317,9 @@ def build_model_data(self, ctx):
astro_params = self._update_params(astro_params).defining_dict
astro_params = {k: astro_params[k] for k in self.astro_param_keys}
if (
all(isinstance(v, (np.ndarray, list, int, float)) for v in values)
all(isinstance(v, (np.ndarray, list)) for v in values)
and len(values) > 0
):
):
lengths = [len(v) for v in values]
if lengths.count(lengths[0]) != len(lengths):
raise ValueError(
Expand All @@ -1329,9 +1329,16 @@ def build_model_data(self, ctx):
for t in zip(*values):
ap.append(dict(zip(keys, t)))
astro_params = np.array(ap, dtype=object)
if (
all(isinstance(v, (float, int)) for v in values)
and len(values) > 0
):
astro_params = dict(zip(keys, values))
astro_params = np.array([astro_params], dtype=object)

Check warning on line 1337 in src/py21cmmc/core.py

View check run for this annotation

Codecov / codecov/patch

src/py21cmmc/core.py#L1336-L1337

Added lines #L1336 - L1337 were not covered by tests
logger.debug(f"AstroParams: {astro_params}")

n = len(astro_params)
theta, outputs, errors = self.emulator.predict(astro_params=astro_params)

if self.io_options["cache_dir"] is not None:
if len(astro_params.shape) == 2:
pars = astro_params[0]
Expand All @@ -1347,7 +1354,7 @@ def build_model_data(self, ctx):
logger.debug(f"Adding {self.ctx_variables} to context data")
for key in self.ctx_variables:
try:
ctx.add(key + self.name, getattr(outputs, key))
ctx.add(key + self.name, getattr(outputs, key) if n > 1 else getattr(outputs, key)[np.newaxis,...])
except AttributeError:
try:
ctx.add(key + self.name, errors[key])
Expand Down
4 changes: 2 additions & 2 deletions src/py21cmmc/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,8 @@ def reduce_data(self, ctx):

def computeLikelihood(self, model):
"""Compute the likelihood."""
n = model["xHI"].shape[0]
xHI = np.atleast_2d(model["xHI"])
n = xHI.shape[0]

Check warning on line 1440 in src/py21cmmc/likelihood.py

View check run for this annotation

Codecov / codecov/patch

src/py21cmmc/likelihood.py#L1440

Added line #L1440 was not covered by tests
lnprob = np.zeros(n)
for i in range(n):
if self._require_spline:
Expand All @@ -1462,7 +1462,7 @@ def computeLikelihood(self, model):
lnprob[i] += self.lnprob(model_spline(z), data, sigma_t)

logger.debug(f"Neutral fraction Likelihood computed: {lnprob}")
return lnprob
return lnprob.squeeze()

Check warning on line 1465 in src/py21cmmc/likelihood.py

View check run for this annotation

Codecov / codecov/patch

src/py21cmmc/likelihood.py#L1465

Added line #L1465 was not covered by tests

def lnprob(self, model, data, sigma):
"""Compute the log prob given a model, data and error."""
Expand Down

0 comments on commit 6ea1fe9

Please sign in to comment.