diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py
index 37cd67bd0..a9dd65f2c 100644
--- a/examples/eeg_slds.py
+++ b/examples/eeg_slds.py
@@ -4,7 +4,7 @@
 """
 We use a switching linear dynamical system [1] to model an EEG time series dataset.
 For inference we use a moment-matching approximation enabled by
-`funsor.interpreter.interpretation(funsor.terms.moment_matching)`.
+`funsor.interpretation(funsor.terms.moment_matching)`.
 
 References
 
@@ -105,7 +105,7 @@ def get_tensors_and_dists(self):
         return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
 
     # compute the marginal log probability of the observed data using a moment-matching approximation
-    @funsor.interpreter.interpretation(funsor.terms.moment_matching)
+    @funsor.interpretation(funsor.terms.moment_matching)
     def log_prob(self, data):
         trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()
 
@@ -153,7 +153,7 @@ def log_prob(self, data):
     # here we implicitly use a moment matching lag of L = 1. the general logic follows
     # the logic in the log_prob method.
     @torch.no_grad()
-    @funsor.interpreter.interpretation(funsor.terms.moment_matching)
+    @funsor.interpretation(funsor.terms.moment_matching)
     def filter_and_predict(self, data, smoothing=False):
         trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()
 
diff --git a/examples/minipyro.py b/examples/minipyro.py
index 16779761c..8665af35e 100644
--- a/examples/minipyro.py
+++ b/examples/minipyro.py
@@ -10,7 +10,7 @@
 
 import funsor
 from funsor.interpreter import interpretation
-from funsor.montecarlo import monte_carlo
+from funsor.montecarlo import MonteCarlo
 
 
 def main(args):
@@ -37,7 +37,7 @@ def guide(data):
 
     # Because the API in minipyro matches that of Pyro proper,
     # training code works with generic Pyro implementations.
-    with pyro_backend(args.backend), interpretation(monte_carlo):
+    with pyro_backend(args.backend), interpretation(MonteCarlo()):
         # Construct an SVI object so we can do variational inference on our
         # model/guide pair.
         Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
diff --git a/examples/slds.py b/examples/slds.py
index 70203c2a0..9ade1fd88 100644
--- a/examples/slds.py
+++ b/examples/slds.py
@@ -26,7 +26,7 @@ def main(args):
              emit_noise.data]
 
     # A Gaussian HMM model.
-    @funsor.interpreter.interpretation(funsor.terms.moment_matching)
+    @funsor.interpretation(funsor.terms.moment_matching)
     def model(data):
         log_prob = funsor.Number(0.)
 
diff --git a/examples/vae.py b/examples/vae.py
index b1d8a1f5f..e91b74e47 100644
--- a/examples/vae.py
+++ b/examples/vae.py
@@ -56,7 +56,7 @@ def main(args):
     encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
     decode = funsor.function(Reals[20], Reals[28, 28])(decoder)
 
-    @funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
+    @funsor.interpretation(funsor.montecarlo.MonteCarlo())
     def loss_function(data, subsample_scale):
         # Lazily sample from the guide.
         loc, scale = encode(data)
diff --git a/funsor/__init__.py b/funsor/__init__.py
index 3d521a330..6b2c1e4f1 100644
--- a/funsor/__init__.py
+++ b/funsor/__init__.py
@@ -3,7 +3,7 @@
 
 from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals
 from funsor.integrate import Integrate
-from funsor.interpreter import reinterpret
+from funsor.interpreter import reinterpret, interpretation
 from funsor.sum_product import MarkovProduct
 from funsor.tensor import Tensor, function
 from funsor.terms import (
@@ -74,6 +74,7 @@
     'gaussian',
     'get_backend',
     'integrate',
+    'interpretation',
     'interpreter',
     'joint',
     'memoize',
diff --git a/funsor/cnf.py b/funsor/cnf.py
index d2b6f9879..15ac6fbae 100644
--- a/funsor/cnf.py
+++ b/funsor/cnf.py
@@ -86,7 +86,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
         if self.red_op in (ops.logaddexp, nullop):
             if self.bin_op in (ops.nullop, ops.logaddexp):
-                if rng_key is not None:
+                if rng_key is not None and get_backend() == "jax":
                     import jax
                     rng_keys = jax.random.split(rng_key, len(self.terms))
                 else:
@@ -100,7 +100,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
                 return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
 
         if self.bin_op is ops.add:
-            if rng_key is not None:
+            if rng_key is not None and get_backend() == "jax":
                 import jax
                 rng_keys = jax.random.split(rng_key)
             else:
diff --git a/funsor/interpreter.py b/funsor/interpreter.py
index dcde198f5..bc9e5b5d6 100644
--- a/funsor/interpreter.py
+++ b/funsor/interpreter.py
@@ -6,7 +6,7 @@
 import os
 import re
 import types
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
 from contextlib import contextmanager
 from functools import singledispatch
 
@@ -111,6 +111,7 @@ def interpretation(new):
     assert callable(new)
     global _INTERPRETATION
     old = _INTERPRETATION
+    new = InterpreterStack(new, old)
     try:
         _INTERPRETATION = new
         yield
@@ -311,6 +312,14 @@ def reinterpret(x):
     return recursion_reinterpret(x)
 
 
+class InterpreterStack(namedtuple("InterpreterStack", ["default", "fallback"])):
+    def __call__(self, cls, *args):
+        for interpreter in self:
+            result = interpreter(cls, *args)
+            if result is not None:
+                return result
+
+
 def dispatched_interpretation(fn):
     """
     Decorator to create a dispatched interpretation function.
@@ -324,6 +333,46 @@ def dispatched_interpretation(fn):
     return fn
 
 
+class StatefulInterpretationMeta(type):
+    def __init__(cls, name, bases, dct):
+        super().__init__(name, bases, dct)
+        cls.registry = KeyedRegistry(default=lambda *args: None)
+        cls.dispatch = cls.registry.dispatch
+
+
+class StatefulInterpretation(metaclass=StatefulInterpretationMeta):
+    """
+    Base class for interpreters with instance-dependent state or parameters.
+
+    Example usage::
+
+        class MyInterpretation(StatefulInterpretation):
+
+            def __init__(self, my_param):
+                self.my_param = my_param
+
+        @MyInterpretation.register(...)
+        def my_impl(interpreter_state, *args):
+            my_param = interpreter_state.my_param
+            ...
+
+        with interpretation(MyInterpretation(my_param=0.1)):
+            ...
+    """
+
+    def __call__(self, cls, *args):
+        return self.dispatch(cls, *args)(self, *args)
+
+    if _DEBUG:
+        @classmethod
+        def register(cls, *args):
+            return lambda fn: cls.registry.register(*args)(debug_logged(fn))
+    else:
+        @classmethod
+        def register(cls, *args):
+            return cls.registry.register(*args)
+
+
 class PatternMissingError(NotImplementedError):
     def __str__(self):
         return "{}\nThis is most likely due to a missing pattern.".format(super().__str__())
@@ -331,6 +380,7 @@ def __str__(self):
 
 __all__ = [
     'PatternMissingError',
+    'StatefulInterpretation',
    'dispatched_interpretation',
     'interpret',
     'interpretation',
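
Note: the StatefulInterpretation docstring above leaves its example abstract
("..." placeholders). The following is a minimal, self-contained sketch of the
intended pattern; the CountIntegrals class and its registration are
hypothetical, but the register() call signature and the None-fallback behavior
(via InterpreterStack) are taken from this diff:

    from funsor.integrate import Integrate
    from funsor.interpreter import StatefulInterpretation, interpretation
    from funsor.terms import Funsor

    class CountIntegrals(StatefulInterpretation):
        """Counts Integrate terms as they are built, then defers."""
        def __init__(self):
            self.count = 0

    @CountIntegrals.register(Integrate, Funsor, Funsor, frozenset)
    def count_integrate(state, log_measure, integrand, reduced_vars):
        state.count += 1
        return None  # defer to the next interpreter on the stack

    counter = CountIntegrals()
    with interpretation(counter):
        pass  # any Integrate built here increments counter.count
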
+ """ + + def __call__(self, cls, *args): + return self.dispatch(cls, *args)(self, *args) + + if _DEBUG: + @classmethod + def register(cls, *args): + return lambda fn: cls.registry.register(*args)(debug_logged(fn)) + else: + @classmethod + def register(cls, *args): + return cls.registry.register(*args) + + class PatternMissingError(NotImplementedError): def __str__(self): return "{}\nThis is most likely due to a missing pattern.".format(super().__str__()) @@ -331,6 +380,7 @@ def __str__(self): __all__ = [ 'PatternMissingError', + 'StatefulInterpretation', 'dispatched_interpretation', 'interpret', 'interpretation', diff --git a/funsor/minipyro.py b/funsor/minipyro.py index 35371381c..caa6ecf4b 100644 --- a/funsor/minipyro.py +++ b/funsor/minipyro.py @@ -49,7 +49,7 @@ def log_prob(self, value): # Draw a sample. def __call__(self): - with funsor.interpreter.interpretation(funsor.terms.eager): + with funsor.interpretation(funsor.terms.eager): dist = self.funsor_dist(value='value') delta = dist.sample(frozenset(['value']), sample_inputs=self.sample_inputs) if isinstance(delta, funsor.cnf.Contraction): @@ -508,7 +508,7 @@ def __call__(self, model, guide, *args, **kwargs): # This is a wrapper for compatibility with full Pyro. class Trace_ELBO(ELBO): def __call__(self, model, guide, *args, **kwargs): - with funsor.montecarlo.monte_carlo_interpretation(): + with funsor.interpretation(funsor.montecarlo.MonteCarlo()): return elbo(model, guide, *args, **kwargs) @@ -521,7 +521,7 @@ class TraceEnum_ELBO(ELBO): # TODO allow mixing of sampling and exact integration def __call__(self, model, guide, *args, **kwargs): if self.options.get("optimize", None): - with funsor.interpreter.interpretation(funsor.optimizer.optimize): + with funsor.interpretation(funsor.optimizer.optimize): elbo_expr = elbo(model, guide, *args, **kwargs) return funsor.reinterpret(elbo_expr) return elbo(model, guide, *args, **kwargs) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index a26967fb6..f5ca4c0cb 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -2,57 +2,40 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict -from contextlib import contextmanager from funsor.integrate import Integrate -from funsor.interpreter import dispatched_interpretation, interpretation -from funsor.terms import Funsor, eager +from funsor.interpreter import StatefulInterpretation +from funsor.terms import Funsor +from funsor.util import get_backend -@dispatched_interpretation -def monte_carlo(cls, *args): +class MonteCarlo(StatefulInterpretation): """ A Monte Carlo interpretation of :class:`~funsor.integrate.Integrate` - expressions. This falls back to :class:`~funsor.terms.eager` in other - cases. + expressions. This falls back to the previous interpreter in other cases. + + :param rng_key: """ - # TODO Memoize sample statements in a context manager. - result = monte_carlo.dispatch(cls, *args)(*args) - if result is None: - result = eager(cls, *args) - return result + def __init__(self, *, rng_key=None, **sample_inputs): + self.rng_key = rng_key + self.sample_inputs = OrderedDict(sample_inputs) -# This is a globally configurable parameter to draw multiple samples. 
diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py
index fe618667f..36d4c0979 100644
--- a/test/examples/test_bart.py
+++ b/test/examples/test_bart.py
@@ -14,7 +14,7 @@
 from funsor.gaussian import Gaussian
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation
-from funsor.montecarlo import monte_carlo
+from funsor.montecarlo import MonteCarlo
 from funsor.pyro.convert import AffineNormal
 from funsor.sum_product import MarkovProduct
 from funsor.tensor import Function, Tensor
@@ -201,12 +201,12 @@ def test_bart(analytic_kl):
 
     if analytic_kl:
         exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
-        with interpretation(monte_carlo):
+        with interpretation(MonteCarlo()):
             approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
         elbo = exact_part + approx_part
     else:
         p = p_prior + p_likelihood
-        with interpretation(monte_carlo):
+        with interpretation(MonteCarlo()):
             elbo = Integrate(q, p - q, "gate_rate_t")
 
     assert isinstance(elbo, Tensor), elbo.pretty()
diff --git a/test/test_integrate.py b/test/test_integrate.py
index 26075d0a4..f6939731d 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -3,22 +3,21 @@
 
 from collections import OrderedDict
 
+import numpy as np
 import pytest
 
 from funsor import ops
 from funsor.domains import Bint
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation
-from funsor.montecarlo import monte_carlo
+from funsor.montecarlo import MonteCarlo
 from funsor.terms import Variable, eager, lazy, moment_matching, normalize, reflect
 from funsor.testing import assert_close, random_tensor
-from funsor.util import get_backend
 
 
 @pytest.mark.parametrize('interp', [
     reflect, lazy, normalize, eager, moment_matching,
-    pytest.param(monte_carlo, marks=pytest.mark.xfail(
-        get_backend() == "jax", reason="Lacking pattern to pass rng_key"))
+    MonteCarlo(rng_key=np.array([0, 0], dtype=np.uint32)),
 ])
 def test_integrate(interp):
     log_measure = random_tensor(OrderedDict([('i', Bint[2]), ('j', Bint[3])]))
diff --git a/test/test_joint.py b/test/test_joint.py
index 907008b55..83ce3889c 100644
--- a/test/test_joint.py
+++ b/test/test_joint.py
@@ -14,7 +14,7 @@
 from funsor.gaussian import Gaussian
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation
-from funsor.montecarlo import monte_carlo_interpretation
+from funsor.montecarlo import MonteCarlo
 from funsor.tensor import Tensor, numeric_array
 from funsor.terms import Number, Variable, eager, moment_matching
 from funsor.testing import (assert_close, randn, random_gaussian, random_tensor,
@@ -316,7 +316,7 @@ def test_reduce_moment_matching_moments():
         [('i', Bint[2]), ('j', Bint[3]), ('x', Reals[2])]))
     with interpretation(moment_matching):
         approx = gaussian.reduce(ops.logaddexp, 'j')
-    with monte_carlo_interpretation(s=Bint[100000]):
+    with interpretation(MonteCarlo(s=Bint[100000])):
         actual = Integrate(approx, Number(1.), 'x')
     expected = Integrate(gaussian, Number(1.), {'j', 'x'})
     assert_close(actual, expected, atol=1e-3, rtol=1e-3)
diff --git a/test/test_samplers.py b/test/test_samplers.py
index f86f6c214..71873af65 100644
--- a/test/test_samplers.py
+++ b/test/test_samplers.py
@@ -14,7 +14,8 @@
 from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
 from funsor.domains import Bint, Real, Reals
 from funsor.integrate import Integrate
-from funsor.montecarlo import monte_carlo_interpretation
+from funsor.interpreter import interpretation
+from funsor.montecarlo import MonteCarlo
 from funsor.tensor import Tensor, align_tensors
 from funsor.terms import Variable
 from funsor.testing import assert_close, id_from_inputs, randn, random_gaussian, random_tensor, xfail_if_not_implemented
@@ -341,7 +342,7 @@ def test_lognormal_distribution(moment):
     log_measure = dist.LogNormal(loc, scale)(value='x')
     probe = Variable('x', Real) ** moment
 
-    with monte_carlo_interpretation(particle=Bint[num_samples]):
+    with interpretation(MonteCarlo(particle=Bint[num_samples])):
         with xfail_if_not_implemented():
             actual = Integrate(log_measure, probe, frozenset(['x']))
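
Note: on the JAX backend, randomness now threads explicitly through
MonteCarlo's rng_key, which monte_carlo_integrate splits on each use (see
funsor/montecarlo.py above). A sketch of the calling convention, following
test_integrate.py (the key value is arbitrary; other backends ignore rng_key):

    import numpy as np

    from funsor.interpreter import interpretation
    from funsor.montecarlo import MonteCarlo

    # A uint32 pair is a valid JAX PRNG key.
    rng_key = np.array([0, 0], dtype=np.uint32)
    with interpretation(MonteCarlo(rng_key=rng_key)):
        pass  # Integrate expressions built here are estimated by sampling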