From bb8cc024de71d97701d6e89152470be0307aa1e6 Mon Sep 17 00:00:00 2001 From: Larry Dong Date: Sat, 21 Oct 2023 13:40:26 -0400 Subject: [PATCH] Add closed-form posterior for gamma-exponential observation model --- aemcmc/conjugates.py | 87 +++++++++++++++++++++++++++++++++++++++- tests/test_conjugates.py | 37 +++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/aemcmc/conjugates.py b/aemcmc/conjugates.py index 2c2617f..8766937 100644 --- a/aemcmc/conjugates.py +++ b/aemcmc/conjugates.py @@ -5,7 +5,12 @@ from aesara.graph.rewriting.basic import in2out from aesara.graph.rewriting.db import LocalGroupDB from aesara.graph.rewriting.unify import eval_if_etuple -from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV +from aesara.tensor.random.basic import ( + BinomialRV, + ExponentialRV, + NegBinomialRV, + PoissonRV, +) from etuples import etuple, etuplize from kanren import eq, lall, run from unification import var @@ -238,6 +243,85 @@ def local_beta_negative_binomial_posterior(fgraph, node, srng): return [(beta_rv, beta_posterior, None)] +def gamma_exponential_conjugateo( + srng: "RandomStream", observed_rv_expr, posterior_expr +): + r""" + Relation for the conjugate posterior of a gamma prior with an exponential observation model. + + .. math:: + + \frac{ + Y \sim \operatorname{Exp}\left(\lambda\right), \quad + \lambda \sim \operatorname{Gamma}\left(\alpha, \beta\right) + }{ + \left(\lambda|Y=y\right) \sim \operatorname{Gamma}\left(\alpha+1, \beta+y\right) + } + + Parameters + ---------- + srng + The `RandomStream` used to generate the posterior variates. + observed_rv_expr + An expression that represents the observed variable. + posterior_exp + An expression that represents the posterior distribution of the latent + variable. + + """ + # Gamma-exponential observation model + alpha_lv, beta_lv = var(), var() + lam_rng_lv = var() + lam_size_lv = var() + lam_type_idx_lv = var() + lam_et = etuple( + etuplize(at.random.gamma), + lam_rng_lv, + lam_size_lv, + lam_type_idx_lv, + alpha_lv, + beta_lv, + ) + Y_et = etuple(etuplize(at.random.exponential), var(), var(), var(), lam_et) + + # Posterior distribution for lambda + new_alpha_et = etuple(etuplize(at.add), alpha_lv, 1) + new_beta_et = etuple(etuplize(at.add), beta_lv, observed_rv_expr) + + lam_posterior_et = etuple( + partial(srng.gen, at.random.gamma), + new_alpha_et, + new_beta_et, + size=lam_size_lv, + dtype=lam_type_idx_lv, + ) + + return lall( + eq(observed_rv_expr, Y_et), + eq(posterior_expr, lam_posterior_et), + ) + + +@sampler_finder([ExponentialRV]) +def local_gamma_exponential_posterior(fgraph, node, srng): + rv_var = node.outputs[1] + + q = var() + + rv_et = etuplize(rv_var) + + res = run(None, q, partial(beta_negative_binomial_conjugateo, srng)(rv_et, q)) + res = next(res, None) + + if res is None: + return None # pragma: no cover + + lam_rv = rv_et[-1].evaled_obj + lam_posterior = eval_if_etuple(res) + + return [(lam_rv, lam_posterior, None)] + + conjugates_db = LocalGroupDB(apply_all_rewrites=True) conjugates_db.name = "conjugates_db" conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic") @@ -245,6 +329,7 @@ def local_beta_negative_binomial_posterior(fgraph, node, srng): conjugates_db.register( "negative_binomial", local_beta_negative_binomial_posterior, "basic" ) +conjugates_db.register("gamma_exponential", local_gamma_exponential_posterior, "basic") sampler_finder_db.register( diff --git a/tests/test_conjugates.py b/tests/test_conjugates.py index 26f0ef0..3c299b1 100644 --- a/tests/test_conjugates.py +++ b/tests/test_conjugates.py @@ -9,6 +9,7 @@ from aemcmc.conjugates import ( beta_binomial_conjugateo, beta_negative_binomial_conjugateo, + gamma_exponential_conjugateo, gamma_poisson_conjugateo, ) @@ -142,3 +143,39 @@ def test_beta_negative_binomial_conjugate_expand(): expanded = expanded_expr assert isinstance(expanded.owner.op, type(at.random.beta)) + + +def test_gamma_exponential_conjugate_contract(): + """Produce the closed-form posterior for the exponential observation model with a gamma prior.""" + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + lam_rv = srng.gamma(alpha_tt, beta_tt) + Y_rv = srng.exponential(lam_rv) + + q_lv = var() + (posterior_expr,) = run(1, q_lv, gamma_exponential_conjugateo(srng, Y_rv, q_lv)) + posterior = posterior_expr.evaled_obj + + assert isinstance(posterior.owner.op, type(at.random.gamma)) + + +@pytest.mark.xfail( + reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error" +) +def test_gamma_exponential_conjugate_expand(): + """Expand a contracted gamma-exponential observation model.""" + + srng = RandomStream(0) + + alpha_tt = at.scalar("alpha") + beta_tt = at.scalar("beta") + y_vv = at.iscalar("y") + Y_rv = srng.gamma(alpha_tt + y_vv, beta_tt + 1) + + e_lv = var() + (expanded_expr,) = run(1, e_lv, gamma_exponential_conjugateo(srng, e_lv, Y_rv)) + expanded = expanded_expr.evaled_obj + + assert isinstance(expanded.owner.op, type(at.random.gamma))