diff --git a/src/mrvi/_module.py b/src/mrvi/_module.py index 82da02a..c405ff9 100644 --- a/src/mrvi/_module.py +++ b/src/mrvi/_module.py @@ -407,18 +407,12 @@ def loss( inference_outputs["qu"], generative_outputs["pu"] ).sum(-1) inference_outputs["qeps"] + + kl_z = 0.0 eps = inference_outputs["z"] - inference_outputs["z_base"] if self.z_u_prior: peps = dist.Normal(0, jnp.exp(self.pz_scale)) kl_z = -peps.log_prob(eps).sum(-1) - else: - kl_z = ( - -dist.Normal(inference_outputs["z_base"], jnp.exp(self.z_u_prior_scale)) - .log_prob(inference_outputs["z"]) - .sum(-1) - if self.z_u_prior_scale is not None - else 0 - ) weighted_kl_local = kl_weight * (kl_u + kl_z) loss = reconstruction_loss + weighted_kl_local