Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
Merge pull request #91 from YosefLab/jhong/zuprior
Browse files Browse the repository at this point in the history
z_u_prior false should remove the prior
  • Loading branch information
justjhong authored Mar 18, 2024
2 parents ed0b8f9 + d773a07 commit 6150961
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions src/mrvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,12 @@ def loss(
inference_outputs["qu"], generative_outputs["pu"]
).sum(-1)
inference_outputs["qeps"]

kl_z = 0.0
eps = inference_outputs["z"] - inference_outputs["z_base"]
if self.z_u_prior:
peps = dist.Normal(0, jnp.exp(self.pz_scale))
kl_z = -peps.log_prob(eps).sum(-1)
else:
kl_z = (
-dist.Normal(inference_outputs["z_base"], jnp.exp(self.z_u_prior_scale))
.log_prob(inference_outputs["z"])
.sum(-1)
if self.z_u_prior_scale is not None
else 0
)

weighted_kl_local = kl_weight * (kl_u + kl_z)
loss = reconstruction_loss + weighted_kl_local
Expand Down

0 comments on commit 6150961

Please sign in to comment.