diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index d49d7dd..4fcc9ae 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -2,7 +2,7 @@ from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.rewriting.db import LocalGroupDB from aesara.graph.rewriting.unify import eval_if_etuple -from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV +from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV from etuples import etuple, etuplize from kanren import eq, lall, run from unification import var @@ -268,6 +268,89 @@ def local_beta_negative_binomial_posterior(fgraph, node): return rv_var.owner.outputs +def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr): + r"""Produce a goal that represents the application of Bayes theorem + for a pareto prior with a uniform with 0 as the lower bound observation model. + + .. math:: + Y \sim \operatorname{Uniform}\left(0, \theta\right) + + + + Parameters + ---------- + observed_val + The observed value. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # beta-negative_binomial observation model + x_lv, k_lv = var(), var() + theta_rng_lv = var() + theta_size_lv = var() + theta_type_idx_lv = var() + theta_et = etuple( + etuplize(at.random.pareto), + theta_rng_lv, + theta_size_lv, + theta_type_idx_lv, + k_lv, + x_lv, + ) + Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et) + + new_x_et = etuple(at.math.max, observed_val) + new_k_et = etuple(etuplize(at.add), k_lv, 1) + + theta_posterior_et = etuple( + etuplize(at.random.pareto), + new_k_et, + new_x_et, + rng=theta_rng_lv, + size=theta_size_lv, + dtype=theta_type_idx_lv, + ) + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, theta_posterior_et), + ) + + +@node_rewriter([UniformRV]) +def local_uniform_pareto_posterior(fgraph, node): + sampler_mappings = getattr(fgraph, "sampler_mappings", None) + + rv_var = node.outputs[1] + key = ("local_beta_negative_binomial_posterior", rv_var) + + if sampler_mappings is None or key in sampler_mappings.rvs_seen: + return None # pragma: no cover + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + pareto_rv = rv_et[-1].evaled_obj + pareto_posterior = eval_if_etuple(res) + + sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append( + ("local_uniform_pareto_posterior", pareto_posterior, None) + ) + sampler_mappings.rvs_seen.add(key) + + return rv_var.owner.outputs + + conjugates_db = LocalGroupDB(apply_all_rewrites=True) conjugates_db.name = "conjugates_db" conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") @@ -275,6 +358,7 @@ def local_beta_negative_binomial_posterior(fgraph, node): conjugates_db.register( "negative_binomial", local_beta_negative_binomial_posterior, "basic" ) +conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic") sampler_finder_db.register( diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index a86fa06..0936dc9 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -10,6 +10,7 @@ beta_binomial_conjugateo, beta_negative_binomial_conjugateo, gamma_poisson_conjugateo, + uniform_pareto_conjugateo, ) @@ -157,3 +158,49 @@ def test_beta_negative_binomial_conjugate_expand(): expanded = eval_if_etuple(expanded_expr) assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_uniform_pareto_conjugate_contract(): + """Produce the closed-form posterior for the uniform observation model with + a pareto prior. + + """ + srng = RandomStream(0) + + xm_tt = at.scalar("xm") + k_tt = at.scalar("k") + theta_rv = srng.pareto(k_tt, xm_tt, name="theta") + + # zero = at.iscalar("zero") + Y_rv = srng.uniform(0, theta_rv) + y_vv = Y_rv.clone() + y_vv.tag.name = "y" + + q_lv = var() + (posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv)) + posterior = eval_if_etuple(posterior_expr) + + assert isinstance(posterior.owner.op, type(at.random.pareto)) + + # Build the sampling function and check the results on limiting cases. + sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior) + assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01) # k = 1000 + assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01) # all zeros + + +def test_uniform_pareto_binomial_conjugate_expand(): + """Expand a contracted beta-binomial observation model.""" + + srng = RandomStream(0) + + k_tt = at.scalar("k") + y_vv = at.iscalar("y") + n_tt = at.scalar("n") + + Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt) + + e_lv = var() + (expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv)) + expanded = eval_if_etuple(expanded_expr) + + assert isinstance(expanded.owner.op, type(at.random.pareto))