-
-
Notifications
You must be signed in to change notification settings - Fork 988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New Trace_ELBO
that generalizes Trace_ELBO
, TraceEnum_ELBO
, and TraceGraph_ELBO
#2893
base: dev
Are you sure you want to change the base?
Conversation
|
||
elbo = to_funsor(0.0) | ||
for cost in costs: | ||
elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this missing Dice factors included in log_measures
? IIRC that was the reason for using Integrate
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the test from #2894 which has a simple model/guide pair. When running that model (Elbo=Trace_ELBO, backend=contrib.funsor, reparam-False
) both guide_terms["log_measures"]
and model_terms["log_measures"]
are empty. I can't find Dice factors anywhere in model_terms
or guide_terms
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess they're not included because Funsor.sample
isn't used in the evaluation of Trace_ELBO
. I don't think contrib.funsor.infer.Trace_ELBO
is tested extensively outside the pyro-api
tests in tests/contrib/funsor/test_pyroapi_funsor.py
, which is why this wasn't noticed before.
A more general Funsor-based implementation of Trace_ELBO
is certainly possible and would look very similar to the guide-side enumeration handling logic in TraceEnum_ELBO
. We might even be able to write a custom "enumeration" strategy that just called Funsor.sample
and reuse TraceEnum_ELBO
as the Trace_ELBO
implementation.
I believe a completely general version might require variable elimination logic beyond what's currently in funsor.sum_product
handling cases where the guide had plate structure incompatible with the restrictions there, although I can't immediately think of existing tests or examples where that would be the case.
- df_a * logqa | ||
- df_a * (qb * logqb).sum() | ||
- df_a * (qb * df_c * logqc).sum() | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 can you check the math here? This is an example with b
enumerated in the guide. Trace_ELBO
works correctly here.
# +-----------+ | ||
# a -|-> b --> c | | ||
# | \--> d | | ||
# +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
e
is observed and b
is enumerated.
# guide (c is enumerated) | ||
# +-----------+ | ||
# a -|-> b --> c | | ||
# +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d
is observed and c
is enumerated.
# guide (b is enumerated) | ||
# +-----------+ | ||
# a -|-> b --> c | | ||
# +-----------+ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d
is observed and b
is enumerated.
Design Doc
New version of
Trace_ELBO
that extendsTraceEnum_ELBO
:dice_factor
s (as importance weights) (Dispatch toIntegrate(Delta, ...)
innormalize_integrate_contraction
funsor#551)I get wrong values forelbo
(much larger absolute value compared topyro.infer.trace_elbo.Trace_ELBO
), presumably becausesum(log_factors, to_funsor(0.0))
in line 41 broadcasts terms inlog_factors
and then that leads to large absolute values afterelbo.reduce(funsor.ops.add, plate_vars)
summation.Here I try to fix it by reducing each cost term individually similar toTraceEnum_ELBO
. I'm also not sure if integration is needed here.