diff --git a/src/py21cmmc/core.py b/src/py21cmmc/core.py index 3a7c7934..3e8b7c80 100644 --- a/src/py21cmmc/core.py +++ b/src/py21cmmc/core.py @@ -1317,9 +1317,9 @@ def build_model_data(self, ctx): astro_params = self._update_params(astro_params).defining_dict astro_params = {k: astro_params[k] for k in self.astro_param_keys} if ( - all(isinstance(v, (np.ndarray, list, int, float)) for v in values) + all(isinstance(v, (np.ndarray, list)) for v in values) and len(values) > 0 - ): + ): lengths = [len(v) for v in values] if lengths.count(lengths[0]) != len(lengths): raise ValueError( @@ -1329,9 +1329,16 @@ def build_model_data(self, ctx): for t in zip(*values): ap.append(dict(zip(keys, t))) astro_params = np.array(ap, dtype=object) + if ( + all(isinstance(v, (float, int)) for v in values) + and len(values) > 0 + ): + astro_params = dict(zip(keys, values)) + astro_params = np.array([astro_params], dtype=object) logger.debug(f"AstroParams: {astro_params}") - + n = len(astro_params) theta, outputs, errors = self.emulator.predict(astro_params=astro_params) + if self.io_options["cache_dir"] is not None: if len(astro_params.shape) == 2: pars = astro_params[0] @@ -1347,7 +1354,7 @@ def build_model_data(self, ctx): logger.debug(f"Adding {self.ctx_variables} to context data") for key in self.ctx_variables: try: - ctx.add(key + self.name, getattr(outputs, key)) + ctx.add(key + self.name, getattr(outputs, key) if n > 1 else getattr(outputs, key)[np.newaxis,...]) except AttributeError: try: ctx.add(key + self.name, errors[key]) diff --git a/src/py21cmmc/likelihood.py b/src/py21cmmc/likelihood.py index fc1603d2..3f001718 100644 --- a/src/py21cmmc/likelihood.py +++ b/src/py21cmmc/likelihood.py @@ -1436,8 +1436,8 @@ def reduce_data(self, ctx): def computeLikelihood(self, model): """Compute the likelihood.""" - n = model["xHI"].shape[0] xHI = np.atleast_2d(model["xHI"]) + n = xHI.shape[0] lnprob = np.zeros(n) for i in range(n): if self._require_spline: @@ -1462,7 +1462,7 @@ def computeLikelihood(self, model): lnprob[i] += self.lnprob(model_spline(z), data, sigma_t) logger.debug(f"Neutral fraction Likelihood computed: {lnprob}") - return lnprob + return lnprob.squeeze() def lnprob(self, model, data, sigma): """Compute the log prob given a model, data and error."""