Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MonteCarlo to be a StatefulInterpretation #369

Merged
merged 7 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import funsor
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo


def main(args):
Expand All @@ -37,7 +37,7 @@ def guide(data):

# Because the API in minipyro matches that of Pyro proper,
# training code works with generic Pyro implementations.
with pyro_backend(args.backend), interpretation(monte_carlo):
with pyro_backend(args.backend), interpretation(MonteCarlo()):
# Construct an SVI object so we can do variational inference on our
# model/guide pair.
Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(args):
encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

@funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
@funsor.interpretation(funsor.montecarlo.MonteCarlo())
def loss_function(data, subsample_scale):
# Lazily sample from the guide.
loc, scale = encode(data)
Expand Down
48 changes: 47 additions & 1 deletion funsor/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import re
import types
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from functools import singledispatch

Expand Down Expand Up @@ -111,6 +111,7 @@ def interpretation(new):
assert callable(new)
global _INTERPRETATION
old = _INTERPRETATION
new = InterpreterStack(new, old)
try:
_INTERPRETATION = new
yield
Expand Down Expand Up @@ -311,6 +312,14 @@ def reinterpret(x):
return recursion_reinterpret(x)


class InterpreterStack(namedtuple("InterpreterStack", ["default", "fallback"])):
    """
    A pair of interpretations tried in order: ``default`` first, then
    ``fallback``. The first non-None result wins; if neither interpretation
    matches, None is returned implicitly.
    """
    def __call__(self, cls, *args):
        # Try the default interpretation; if it declines (returns None),
        # defer to the fallback (the previously installed interpretation).
        for interp in (self.default, self.fallback):
            value = interp(cls, *args)
            if value is not None:
                return value


def dispatched_interpretation(fn):
"""
Decorator to create a dispatched interpretation function.
Expand All @@ -324,13 +333,50 @@ def dispatched_interpretation(fn):
return fn


class StatefulInterpretation:
    """
    Base class for interpreters with instance-dependent state or parameters.

    Example usage::

        class MyInterpretation(StatefulInterpretation):

            def __init__(self, my_param):
                self.my_param = my_param

        @MyInterpretation.register(...)
        def my_impl(interpreter_state, cls, *args):
            my_param = interpreter_state.my_param
            ...

        with interpretation(MyInterpretation(my_param=0.1)):
            ...
    """
    def __init_subclass__(cls):
        # Each subclass gets its own pattern registry. The registry default
        # returns None, signaling "no pattern matched" so that the caller
        # (e.g. an InterpreterStack) can fall back to the previous,
        # by-assumption-complete interpretation.
        cls.registry = KeyedRegistry(default=lambda *args: None)
        cls.dispatch = cls.registry.dispatch

    def __call__(self, cls, *args):
        # This is a partial function: it may return None when no registered
        # pattern matches, relying on the surrounding interpreter stack to
        # fall back to the previous interpretation.
        # Note the instance (self) is prepended, so registered implementations
        # receive the interpreter state as their first argument.
        return self.dispatch(cls, *args)(self, *args)

    # The branch is chosen once at class-definition time based on _DEBUG.
    if _DEBUG:
        @classmethod
        def register(cls, *args):
            # In debug mode, wrap each registered implementation with logging.
            return lambda fn: cls.registry.register(*args)(debug_logged(fn))
    else:
        @classmethod
        def register(cls, *args):
            return cls.registry.register(*args)


class PatternMissingError(NotImplementedError):
    """Raised when no interpretation pattern matches an expression."""

    def __str__(self):
        base_message = super().__str__()
        return "{}\nThis is most likely due to a missing pattern.".format(base_message)


__all__ = [
'PatternMissingError',
'StatefulInterpretation',
'dispatched_interpretation',
'interpret',
'interpretation',
Expand Down
2 changes: 1 addition & 1 deletion funsor/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def __call__(self, model, guide, *args, **kwargs):
# This is a wrapper for compatibility with full Pyro.
class Trace_ELBO(ELBO):
def __call__(self, model, guide, *args, **kwargs):
with funsor.montecarlo.monte_carlo_interpretation():
with funsor.interpretation(funsor.montecarlo.MonteCarlo()):
return elbo(model, guide, *args, **kwargs)


Expand Down
53 changes: 15 additions & 38 deletions funsor/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,34 @@
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from contextlib import contextmanager

from funsor.integrate import Integrate
from funsor.interpreter import dispatched_interpretation, interpretation
from funsor.terms import Funsor, eager
from funsor.interpreter import StatefulInterpretation
from funsor.terms import Funsor


@dispatched_interpretation
def monte_carlo(cls, *args):
class MonteCarlo(StatefulInterpretation):
    """
    A Monte Carlo interpretation of :class:`~funsor.integrate.Integrate`
    expressions. This falls back to the previous interpreter in other cases.

    :param rng_key: an optional random-number-generator key, stored for use
        by registered implementations — NOTE(review): not yet threaded into
        sampling (see the FIXME in ``monte_carlo_integrate``); presumably a
        JAX PRNG key, confirm intended type.
    :param \*\*sample_inputs: optional named sample dimensions; stored in
        keyword order and passed to ``Funsor.sample`` to draw multiple
        samples.
    """
    def __init__(self, *, rng_key=None, **sample_inputs):
        self.rng_key = rng_key
        # Preserve the keyword order of the requested sample dimensions.
        self.sample_inputs = OrderedDict(sample_inputs)


@contextmanager
def monte_carlo_interpretation(**sample_inputs):
"""
Context manager to set ``monte_carlo.sample_inputs`` and
install the :func:`monte_carlo` interpretation.
"""
old = monte_carlo.sample_inputs
monte_carlo.sample_inputs = OrderedDict(sample_inputs)
try:
with interpretation(monte_carlo):
yield
finally:
monte_carlo.sample_inputs = old


@monte_carlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
# FIXME: how to pass rng_key to here?
sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs)
@MonteCarlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
    """
    Approximate an Integrate expression by sampling from ``log_measure``
    and integrating the sample against ``integrand``.

    Returns None (declining to interpret) when sampling makes no progress,
    so the interpreter stack falls back to the previous interpretation.
    """
    # FIXME: thread state.rng_key through to .sample() (needed for the jax backend).
    sample = log_measure.sample(reduced_vars, state.sample_inputs)
    if sample is log_measure:
        return None  # cannot progress
    # Also reduce over any sample dimensions the draw introduced.
    reduced_vars |= frozenset(state.sample_inputs).intersection(sample.inputs)
    return Integrate(sample, integrand, reduced_vars)


__all__ = [
'monte_carlo',
'monte_carlo_interpretation'
'MonteCarlo',
]
6 changes: 3 additions & 3 deletions test/examples/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from funsor.gaussian import Gaussian
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo
from funsor.pyro.convert import AffineNormal
from funsor.sum_product import MarkovProduct
from funsor.tensor import Function, Tensor
Expand Down Expand Up @@ -201,12 +201,12 @@ def test_bart(analytic_kl):

if analytic_kl:
exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
with interpretation(monte_carlo):
with interpretation(MonteCarlo()):
approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
elbo = exact_part + approx_part
else:
p = p_prior + p_likelihood
with interpretation(monte_carlo):
with interpretation(MonteCarlo()):
elbo = Integrate(q, p - q, "gate_rate_t")

assert isinstance(elbo, Tensor), elbo.pretty()
Expand Down
6 changes: 2 additions & 4 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
from funsor.domains import Bint
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo
from funsor.terms import Variable, eager, lazy, moment_matching, normalize, reflect
from funsor.testing import assert_close, random_tensor
from funsor.util import get_backend


@pytest.mark.parametrize('interp', [
reflect, lazy, normalize, eager, moment_matching,
pytest.param(monte_carlo, marks=pytest.mark.xfail(
get_backend() == "jax", reason="Lacking pattern to pass rng_key"))
MonteCarlo(),
])
def test_integrate(interp):
log_measure = random_tensor(OrderedDict([('i', Bint[2]), ('j', Bint[3])]))
Expand Down
4 changes: 2 additions & 2 deletions test/test_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from funsor.gaussian import Gaussian
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo_interpretation
from funsor.montecarlo import MonteCarlo
from funsor.tensor import Tensor, numeric_array
from funsor.terms import Number, Variable, eager, moment_matching
from funsor.testing import (assert_close, randn, random_gaussian, random_tensor,
Expand Down Expand Up @@ -316,7 +316,7 @@ def test_reduce_moment_matching_moments():
[('i', Bint[2]), ('j', Bint[3]), ('x', Reals[2])]))
with interpretation(moment_matching):
approx = gaussian.reduce(ops.logaddexp, 'j')
with monte_carlo_interpretation(s=Bint[100000]):
with interpretation(MonteCarlo(s=Bint[100000])):
actual = Integrate(approx, Number(1.), 'x')
expected = Integrate(gaussian, Number(1.), {'j', 'x'})
assert_close(actual, expected, atol=1e-3, rtol=1e-3)
Expand Down
5 changes: 3 additions & 2 deletions test/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
from funsor.domains import Bint, Real, Reals
from funsor.integrate import Integrate
from funsor.montecarlo import monte_carlo_interpretation
from funsor.interpreter import interpretation
from funsor.montecarlo import MonteCarlo
from funsor.tensor import Tensor, align_tensors
from funsor.terms import Variable
from funsor.testing import assert_close, id_from_inputs, randn, random_gaussian, random_tensor, xfail_if_not_implemented
Expand Down Expand Up @@ -341,7 +342,7 @@ def test_lognormal_distribution(moment):

log_measure = dist.LogNormal(loc, scale)(value='x')
probe = Variable('x', Real) ** moment
with monte_carlo_interpretation(particle=Bint[num_samples]):
with interpretation(MonteCarlo(particle=Bint[num_samples])):
with xfail_if_not_implemented():
actual = Integrate(log_measure, probe, frozenset(['x']))

Expand Down