diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index d49d7dd..7d6fc40 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -2,7 +2,13 @@ from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.db import LocalGroupDB from aesara.graph.rewriting.unify import eval_if_etuple -from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV +from aesara.tensor.random.basic import ( + BernoulliRV, + BinomialRV, + NegBinomialRV, + PoissonRV, + UniformRV, +) from etuples import etuple, etuplize from kanren import eq, lall, run from unification import var @@ -268,6 +274,174 @@ def local_beta_negative_binomial_posterior(fgraph, node): return rv_var.owner.outputs +def beta_bernoulli_conjugateo(observed_val, observed_rv_expr, posterior_expr): + r"""Produce a goal that represents the application of Bayes theorem + for a beta prior with a negative binomial observation model. + + .. math:: + + \frac{ + Y \sim \operatorname{P(x=1)}= p, \quad + p \sim \operatorname{Beta}\left(\alpha, \beta\right) + }{ + \left(p \mid Y=y\right) \sim \operatorname{Beta}\left(\alpha + \sum^{n}_{i=1} y_i, \beta + n - \sum^{n}_{i=1} y_i,\right) + } + + + Parameters + ---------- + observed_val + The observed value. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # beta-negative_binomial observation model + alpha_lv, beta_lv = var(), var() + p_rng_lv = var() + p_size_lv = var() + p_type_idx_lv = var() + p_et = etuple( + etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv + ) + Y_et = etuple(etuplize(at.random.bernoulli), var(), var(), var(), p_et) + + new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val) + new_beta_et = etuple(etuplize(at.add), beta_lv, 1, -observed_val) + + p_posterior_et = etuple( + etuplize(at.random.beta), + new_alpha_et, + new_beta_et, + rng=p_rng_lv, + size=p_size_lv, + dtype=p_type_idx_lv, + ) + + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, p_posterior_et), + ) + + +@node_rewriter([BernoulliRV]) +def local_beta_bernoulli_posterior(fgraph, node): + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_bernoulli_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, beta_bernoulli_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + beta_rv = rv_et[-1].evaled_obj + beta_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append( + ("local_beta_bernoulli_posterior", beta_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + +def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr): + r"""Produce a goal that represents the application of Bayes theorem + for a pareto prior with a uniform with 0 as the lower bound observation model. + + .. math:: + Y \sim \operatorname{Uniform}\left(0, \theta\right) + \theta \sim \operatorname{pareto}\(max(x), k) + + + + Parameters + ---------- + observed_val + The observed value. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # beta-negative_binomial observation model + x_lv, k_lv = var(), var() + theta_rng_lv = var() + theta_size_lv = var() + theta_type_idx_lv = var() + theta_et = etuple( + etuplize(at.random.pareto), + theta_rng_lv, + theta_size_lv, + theta_type_idx_lv, + k_lv, + x_lv, + ) + Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et) + + new_x_et = etuple(at.math.max, observed_val) + new_k_et = etuple(etuplize(at.add), k_lv, 1) + + theta_posterior_et = etuple( + etuplize(at.random.pareto), + new_k_et, + new_x_et, + rng=theta_rng_lv, + size=theta_size_lv, + dtype=theta_type_idx_lv, + ) + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, theta_posterior_et), + ) + + +@node_rewriter([UniformRV]) +def local_uniform_pareto_posterior(fgraph, node): + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_negative_binomial_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + pareto_rv = rv_et[-1].evaled_obj + pareto_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append( + ("local_uniform_pareto_posterior", pareto_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + conjugates_db = LocalGroupDB(apply_all_rewrites=True) conjugates_db.name = "conjugates_db" conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") @@ -275,6 +449,7 @@ def local_beta_negative_binomial_posterior(fgraph, node): conjugates_db.register( "negative_binomial", local_beta_negative_binomial_posterior, "basic" ) +conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic") sampler_finder_db.register( diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index a86fa06..6f8c431 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -3,13 +3,16 @@ import pytest from aesara.graph.rewriting.unify import eval_if_etuple from aesara.tensor.random import RandomStream +from etuples import etuple, etuplize from kanren import run from unification import var from aemcmc.conjugates import ( + beta_bernoulli_conjugateo, beta_binomial_conjugateo, beta_negative_binomial_conjugateo, gamma_poisson_conjugateo, + uniform_pareto_conjugateo, ) @@ -157,3 +160,122 @@ def test_beta_negative_binomial_conjugate_expand(): expanded = eval_if_etuple(expanded_expr) assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_beta_bernoulli_conjugate_contract(): + """Produce the closed-form posterior for the binomial observation model with + a beta prior. + + """ + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + p_rv = srng.beta(alpha_tt, beta_tt, name="p") + + Y_rv = srng.bernoulli(p_rv) + y_vv = Y_rv.clone() + y_vv.tag.name = "y" + + q_lv = var() + (posterior_expr,) = run(1, q_lv, beta_bernoulli_conjugateo(y_vv, Y_rv, q_lv)) + posterior = eval_if_etuple(posterior_expr) + + assert isinstance(posterior.owner.op, type(at.random.beta)) + + # Build the sampling function and check the results on limiting cases. + sample_fn = aesara.function((alpha_tt, beta_tt, y_vv), posterior) + assert sample_fn(1.0, 1.0, 1) == pytest.approx(1.0, abs=0.3) # only successes + assert sample_fn(1.0, 1.0, 0) == pytest.approx(0.0, abs=0.3) # no success + + +@pytest.mark.xfail( + reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error" +) +def test_beta_bernoulli_conjugate_expand(): + """Expand a contracted beta-binomial observation model.""" + + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + y_vv = at.iscalar("y") + n_tt = at.iscalar("n") + Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt - y_vv) + + e_lv = var() + (expanded_expr,) = run(1, e_lv, beta_bernoulli_conjugateo(e_lv, y_vv, Y_rv)) + expanded = eval_if_etuple(expanded_expr) + + assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_uniform_pareto_conjugate_contract(): + """Produce the closed-form posterior for the uniform observation model with + a pareto prior. + + """ + srng = RandomStream(0) + + xm_tt = at.scalar("xm") + k_tt = at.scalar("k") + theta_rv = srng.pareto(k_tt, xm_tt, name="theta") + + # zero = at.iscalar("zero") + Y_rv = srng.uniform(0, theta_rv) + y_vv = Y_rv.clone() + y_vv.tag.name = "y" + + q_lv = var() + (posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv)) + posterior = eval_if_etuple(posterior_expr) + + assert isinstance(posterior.owner.op, type(at.random.pareto)) + + # Build the sampling function and check the results on limiting cases. + sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior) + assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01) # k = 1000 + assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01) # all zeros + + +def test_uniform_pareto_binomial_conjugate_expand(): + """Expand a contracted beta-binomial observation model.""" + + srng = RandomStream(0) + + k_tt = at.scalar("k") + y_vv = at.iscalar("y") + n_tt = at.scalar("n") + + Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt) + etuplize(Y_rv) + + # e_lv = var() + # (expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv)) + # expanded = eval_if_etuple(expanded_expr) + + # assert isinstance(expanded.owner.op, type(at.random.pareto)) + from aesara.tensor.math import MaxAndArgmax + from kanren import eq, run + from unification import var + + observed_val = var() + axis_lv = var() + new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val) + + k_lv, n_lv = var(), var() + new_k_et = etuple(etuplize(at.add), k_lv, n_lv) + + theta_rng_lv = var() + theta_size_lv = var() + theta_type_idx_lv = var() + theta_posterior_et = etuple( + etuplize(at.random.pareto), + theta_rng_lv, + theta_size_lv, + theta_type_idx_lv, + new_x_et, + new_k_et, + ) + + run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))