Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MonteCarlo to be a StatefulInterpretation #369

Merged
merged 7 commits into from
Sep 28, 2020

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Sep 27, 2020

Pair coded with @eb8680

@eb8680 I have adapted StatefuleInterpretation from https://github.com/pyro-ppl/funsor/tree/adam-interpreter, making it a namedtuple as you had originally suggested, and turning on fallback logic by default as you suggested.

This PR:

  • Adds a StatefulInterpretation base class for interpretations that need to pass instance-specific data such as num_samples and rng_key.
  • Adds an InterpreterStack and makes the MonteCarlo interpretation fall back not to eager but the previous interpretation when it was installed.
  • Refactors monte_carlo to a stateful MonteCarlo.
  • Fixes the JAX implementation of MonteCarlo by tracking rng_key
    @fehiepsi can you help me out with this?

@fritzo fritzo added the WIP label Sep 27, 2020
Comment on lines +359 to +360
def __call__(self, cls, *args):
return self.dispatch(cls, *args)(self, *args)
Copy link
Member Author

@fritzo fritzo Sep 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This default implementation of __call__() is a partial function, and relies on InterpreterStack to fall back to the previous interpretation, which is by assumption complete.

funsor/montecarlo.py Outdated Show resolved Hide resolved
Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, MonteCarlo is much cleaner.

@@ -324,13 +333,54 @@ def dispatched_interpretation(fn):
return fn


class StatefulInterpretationMeta(type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is required for Python 3.5? Once we drop Python 3.5 support, we should remove this and other metaclasses where possible in favor of __init_subclass__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I was almost tempted to drop support for Python 3.5 in this PR, but that seemed like overkill.

@eb8680 eb8680 merged commit 6542efa into master Sep 28, 2020
@eb8680 eb8680 deleted the stateful-interpreter branch September 28, 2020 15:13
@fritzo
Copy link
Member Author

fritzo commented Sep 28, 2020

@eb8680 Thanks for reviewing!
@fehiepsi Thanks for helping!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants