Skip to content

Commit

Permalink
change log prob to logdensity
Browse files Browse the repository at this point in the history
  • Loading branch information
xidulu committed Jan 13, 2023
1 parent 2c1a31e commit 439be68
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions blackjax/vi/meanfield_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from blackjax.types import PRNGKey, PyTree

__all__ = ["MFVIState", "MFVIInfo", "sample", "logprob", "step"]
__all__ = ["MFVIState", "MFVIInfo", "sample", "generate_meanfield_logdensity_fn", "step"]


class MFVIState(NamedTuple):
Expand Down Expand Up @@ -93,7 +93,7 @@ def kl_divergence_fn(parameters):
if stl_estimator:
mu = jax.lax.stop_gradient(mu)
rho = jax.lax.stop_gradient(rho)
logq = jax.vmap(generate_meanfield_logprob(mu, rho))(z)
logq = jax.vmap(generate_meanfield_logdensity_fn(mu, rho))(z)
logp = jax.vmap(logdensity_fn)(z)
return (logq - logp).mean()

Expand All @@ -120,12 +120,12 @@ def _sample(rng_key, mu, rho, num_samples):
return jax.vmap(unravel_fn)(flatten_sample)


def generate_meanfield_logprob(mu, rho):
def generate_meanfield_logdensity_fn(mu, rho):
sigma_param = jax.tree_map(jnp.exp, rho)

def meanfield_logprob(position):
def meanfield_logdensity_fn(position):
logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
logq = jax.tree_map(jnp.sum, logq_pytree)
return jax.tree_util.tree_reduce(jnp.add, logq)

return meanfield_logprob
return meanfield_logdensity_fn

0 comments on commit 439be68

Please sign in to comment.