Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MonteCarlo to be a StatefulInterpretation #369

Merged
merged 7 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/eeg_slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
We use a switching linear dynamical system [1] to model a EEG time series dataset.
For inference we use a moment-matching approximation enabled by
`funsor.interpreter.interpretation(funsor.terms.moment_matching)`.
`funsor.interpretation(funsor.terms.moment_matching)`.

References

Expand Down Expand Up @@ -105,7 +105,7 @@ def get_tensors_and_dists(self):
return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist

# compute the marginal log probability of the observed data using a moment-matching approximation
@funsor.interpreter.interpretation(funsor.terms.moment_matching)
@funsor.interpretation(funsor.terms.moment_matching)
def log_prob(self, data):
trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()

Expand Down Expand Up @@ -153,7 +153,7 @@ def log_prob(self, data):
# here we implicitly use a moment matching lag of L = 1. the general logic follows
# the logic in the log_prob method.
@torch.no_grad()
@funsor.interpreter.interpretation(funsor.terms.moment_matching)
@funsor.interpretation(funsor.terms.moment_matching)
def filter_and_predict(self, data, smoothing=False):
trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()

Expand Down
4 changes: 2 additions & 2 deletions examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import funsor
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo


def main(args):
Expand All @@ -37,7 +37,7 @@ def guide(data):

# Because the API in minipyro matches that of Pyro proper,
# training code works with generic Pyro implementations.
with pyro_backend(args.backend), interpretation(monte_carlo):
with pyro_backend(args.backend), interpretation(MonteCarlo()):
# Construct an SVI object so we can do variational inference on our
# model/guide pair.
Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
Expand Down
2 changes: 1 addition & 1 deletion examples/slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def main(args):
emit_noise.data]

# A Gaussian HMM model.
@funsor.interpreter.interpretation(funsor.terms.moment_matching)
@funsor.interpretation(funsor.terms.moment_matching)
def model(data):
log_prob = funsor.Number(0.)

Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(args):
encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

@funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
@funsor.interpretation(funsor.montecarlo.MonteCarlo())
def loss_function(data, subsample_scale):
# Lazily sample from the guide.
loc, scale = encode(data)
Expand Down
3 changes: 2 additions & 1 deletion funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals
from funsor.integrate import Integrate
from funsor.interpreter import reinterpret
from funsor.interpreter import reinterpret, interpretation
from funsor.sum_product import MarkovProduct
from funsor.tensor import Tensor, function
from funsor.terms import (
Expand Down Expand Up @@ -74,6 +74,7 @@
'gaussian',
'get_backend',
'integrate',
'interpretation',
'interpreter',
'joint',
'memoize',
Expand Down
4 changes: 2 additions & 2 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

if self.red_op in (ops.logaddexp, nullop):
if self.bin_op in (ops.nullop, ops.logaddexp):
if rng_key is not None:
if rng_key is not None and get_backend() == "jax":
import jax
rng_keys = jax.random.split(rng_key, len(self.terms))
else:
Expand All @@ -100,7 +100,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)

if self.bin_op is ops.add:
if rng_key is not None:
if rng_key is not None and get_backend() == "jax":
import jax
rng_keys = jax.random.split(rng_key)
else:
Expand Down
52 changes: 51 additions & 1 deletion funsor/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import re
import types
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from functools import singledispatch

Expand Down Expand Up @@ -111,6 +111,7 @@ def interpretation(new):
assert callable(new)
global _INTERPRETATION
old = _INTERPRETATION
new = InterpreterStack(new, old)
try:
_INTERPRETATION = new
yield
Expand Down Expand Up @@ -311,6 +312,14 @@ def reinterpret(x):
return recursion_reinterpret(x)


class InterpreterStack(namedtuple("InterpreterStack", ["default", "fallback"])):
def __call__(self, cls, *args):
for interpreter in self:
result = interpreter(cls, *args)
if result is not None:
return result


def dispatched_interpretation(fn):
"""
Decorator to create a dispatched interpretation function.
Expand All @@ -324,13 +333,54 @@ def dispatched_interpretation(fn):
return fn


class StatefulInterpretationMeta(type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is required for Python 3.5? Once we drop Python 3.5 support, we should remove this and other metaclasses where possible in favor of __init_subclass__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I was almost tempted to drop support for Python 3.5 in this PR, but that seemed like overkill.

def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)
cls.registry = KeyedRegistry(default=lambda *args: None)
cls.dispatch = cls.registry.dispatch


class StatefulInterpretation(metaclass=StatefulInterpretationMeta):
"""
Base class for interpreters with instance-dependent state or parameters.

Example usage::

class MyInterpretation(StatefulInterpretation):

def __init__(self, my_param):
self.my_param = my_param

@MyInterpretation.register(...)
def my_impl(interpreter_state, cls, *args):
my_param = interpreter_state.my_param
...

with interpretation(MyInterpretation(my_param=0.1)):
...
"""

def __call__(self, cls, *args):
return self.dispatch(cls, *args)(self, *args)
Comment on lines +363 to +364
Copy link
Member Author

@fritzo fritzo Sep 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This default implementation of __call__() is a partial function, and relies on InterpreterStack to fall back to the previous interpretation, which is by assumption complete.


if _DEBUG:
@classmethod
def register(cls, *args):
return lambda fn: cls.registry.register(*args)(debug_logged(fn))
else:
@classmethod
def register(cls, *args):
return cls.registry.register(*args)


class PatternMissingError(NotImplementedError):
def __str__(self):
return "{}\nThis is most likely due to a missing pattern.".format(super().__str__())


__all__ = [
'PatternMissingError',
'StatefulInterpretation',
'dispatched_interpretation',
'interpret',
'interpretation',
Expand Down
6 changes: 3 additions & 3 deletions funsor/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def log_prob(self, value):

# Draw a sample.
def __call__(self):
with funsor.interpreter.interpretation(funsor.terms.eager):
with funsor.interpretation(funsor.terms.eager):
dist = self.funsor_dist(value='value')
delta = dist.sample(frozenset(['value']), sample_inputs=self.sample_inputs)
if isinstance(delta, funsor.cnf.Contraction):
Expand Down Expand Up @@ -508,7 +508,7 @@ def __call__(self, model, guide, *args, **kwargs):
# This is a wrapper for compatibility with full Pyro.
class Trace_ELBO(ELBO):
def __call__(self, model, guide, *args, **kwargs):
with funsor.montecarlo.monte_carlo_interpretation():
with funsor.interpretation(funsor.montecarlo.MonteCarlo()):
return elbo(model, guide, *args, **kwargs)


Expand All @@ -521,7 +521,7 @@ class TraceEnum_ELBO(ELBO):
# TODO allow mixing of sampling and exact integration
def __call__(self, model, guide, *args, **kwargs):
if self.options.get("optimize", None):
with funsor.interpreter.interpretation(funsor.optimizer.optimize):
with funsor.interpretation(funsor.optimizer.optimize):
elbo_expr = elbo(model, guide, *args, **kwargs)
return funsor.reinterpret(elbo_expr)
return elbo(model, guide, *args, **kwargs)
Expand Down
55 changes: 19 additions & 36 deletions funsor/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,40 @@
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from contextlib import contextmanager

from funsor.integrate import Integrate
from funsor.interpreter import dispatched_interpretation, interpretation
from funsor.terms import Funsor, eager
from funsor.interpreter import StatefulInterpretation
from funsor.terms import Funsor
from funsor.util import get_backend


@dispatched_interpretation
def monte_carlo(cls, *args):
class MonteCarlo(StatefulInterpretation):
"""
A Monte Carlo interpretation of :class:`~funsor.integrate.Integrate`
expressions. This falls back to :class:`~funsor.terms.eager` in other
cases.
expressions. This falls back to the previous interpreter in other cases.

:param rng_key:
"""
# TODO Memoize sample statements in a context manager.
result = monte_carlo.dispatch(cls, *args)(*args)
if result is None:
result = eager(cls, *args)
return result
def __init__(self, *, rng_key=None, **sample_inputs):
self.rng_key = rng_key
fritzo marked this conversation as resolved.
Show resolved Hide resolved
self.sample_inputs = OrderedDict(sample_inputs)


# This is a globally configurable parameter to draw multiple samples.
monte_carlo.sample_inputs = OrderedDict()
@MonteCarlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
sample_options = {}
if state.rng_key is not None and get_backend() == "jax":
import jax

sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key)

@contextmanager
def monte_carlo_interpretation(**sample_inputs):
"""
Context manager to set ``monte_carlo.sample_inputs`` and
install the :func:`monte_carlo` interpretation.
"""
old = monte_carlo.sample_inputs
monte_carlo.sample_inputs = OrderedDict(sample_inputs)
try:
with interpretation(monte_carlo):
yield
finally:
monte_carlo.sample_inputs = old


@monte_carlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
# FIXME: how to pass rng_key to here?
sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs)
sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options)
if sample is log_measure:
return None # cannot progress
reduced_vars |= frozenset(monte_carlo.sample_inputs).intersection(sample.inputs)
reduced_vars |= frozenset(state.sample_inputs).intersection(sample.inputs)
return Integrate(sample, integrand, reduced_vars)


__all__ = [
'monte_carlo',
'monte_carlo_interpretation'
'MonteCarlo',
]
6 changes: 3 additions & 3 deletions test/examples/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from funsor.gaussian import Gaussian
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo
from funsor.pyro.convert import AffineNormal
from funsor.sum_product import MarkovProduct
from funsor.tensor import Function, Tensor
Expand Down Expand Up @@ -201,12 +201,12 @@ def test_bart(analytic_kl):

if analytic_kl:
exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
with interpretation(monte_carlo):
with interpretation(MonteCarlo()):
approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
elbo = exact_part + approx_part
else:
p = p_prior + p_likelihood
with interpretation(monte_carlo):
with interpretation(MonteCarlo()):
elbo = Integrate(q, p - q, "gate_rate_t")

assert isinstance(elbo, Tensor), elbo.pretty()
Expand Down
7 changes: 3 additions & 4 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@

from collections import OrderedDict

import numpy as np
import pytest

from funsor import ops
from funsor.domains import Bint
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo
from funsor.montecarlo import MonteCarlo
from funsor.terms import Variable, eager, lazy, moment_matching, normalize, reflect
from funsor.testing import assert_close, random_tensor
from funsor.util import get_backend


@pytest.mark.parametrize('interp', [
reflect, lazy, normalize, eager, moment_matching,
pytest.param(monte_carlo, marks=pytest.mark.xfail(
get_backend() == "jax", reason="Lacking pattern to pass rng_key"))
MonteCarlo(rng_key=np.array([0, 0], dtype=np.uint32)),
])
def test_integrate(interp):
log_measure = random_tensor(OrderedDict([('i', Bint[2]), ('j', Bint[3])]))
Expand Down
4 changes: 2 additions & 2 deletions test/test_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from funsor.gaussian import Gaussian
from funsor.integrate import Integrate
from funsor.interpreter import interpretation
from funsor.montecarlo import monte_carlo_interpretation
from funsor.montecarlo import MonteCarlo
from funsor.tensor import Tensor, numeric_array
from funsor.terms import Number, Variable, eager, moment_matching
from funsor.testing import (assert_close, randn, random_gaussian, random_tensor,
Expand Down Expand Up @@ -316,7 +316,7 @@ def test_reduce_moment_matching_moments():
[('i', Bint[2]), ('j', Bint[3]), ('x', Reals[2])]))
with interpretation(moment_matching):
approx = gaussian.reduce(ops.logaddexp, 'j')
with monte_carlo_interpretation(s=Bint[100000]):
with interpretation(MonteCarlo(s=Bint[100000])):
actual = Integrate(approx, Number(1.), 'x')
expected = Integrate(gaussian, Number(1.), {'j', 'x'})
assert_close(actual, expected, atol=1e-3, rtol=1e-3)
Expand Down
5 changes: 3 additions & 2 deletions test/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
from funsor.domains import Bint, Real, Reals
from funsor.integrate import Integrate
from funsor.montecarlo import monte_carlo_interpretation
from funsor.interpreter import interpretation
from funsor.montecarlo import MonteCarlo
from funsor.tensor import Tensor, align_tensors
from funsor.terms import Variable
from funsor.testing import assert_close, id_from_inputs, randn, random_gaussian, random_tensor, xfail_if_not_implemented
Expand Down Expand Up @@ -341,7 +342,7 @@ def test_lognormal_distribution(moment):

log_measure = dist.LogNormal(loc, scale)(value='x')
probe = Variable('x', Real) ** moment
with monte_carlo_interpretation(particle=Bint[num_samples]):
with interpretation(MonteCarlo(particle=Bint[num_samples])):
with xfail_if_not_implemented():
actual = Integrate(log_measure, probe, frozenset(['x']))

Expand Down