Skip to content

Commit

Permalink
Add in NB beta conjugates and its tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing Xie authored and rlouf committed Feb 16, 2023
1 parent 50b7068 commit fede3b1
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 2 deletions.
93 changes: 92 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, PoissonRV
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -181,10 +181,101 @@ def local_beta_binomial_posterior(fgraph, node):
return rv_var.owner.outputs


def beta_negative_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a negative binomial observation model.
.. math::
\frac{
Y \sim \operatorname{NB}\left(k, p\right), \quad
p \sim \operatorname{Beta}\left(\alpha, \beta\right)
}{
\left(p|Y=y\right) \sim \operatorname{Beta}\left(\alpha+\sum^{n}_{i=1} y_i, \beta+kN\right)
}
, k is the number of successes before experiment ended
Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.
"""
# beta-negative_binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
p_size_lv = var()
p_type_idx_lv = var()
p_et = etuple(
etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv
)
n_lv = var() # success
Y_et = etuple(
etuplize(at.random.negative_binomial), var(), var(), var(), n_lv, p_et
)

new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(etuplize(at.add), beta_lv, n_lv)
p_posterior_et = etuple(
etuplize(at.random.beta),
new_alpha_et,
new_beta_et,
rng=p_rng_lv,
size=p_size_lv,
dtype=p_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, p_posterior_et),
)


@node_rewriter([NegBinomialRV])
def local_beta_negative_binomial_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_beta_negative_binomial_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, beta_negative_binomial_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

beta_rv = rv_et[-1].evaled_obj
beta_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append(
("local_beta_negative_binomial_posterior", beta_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)


sampler_finder_db.register(
"conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
Expand Down
59 changes: 58 additions & 1 deletion tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from kanren import run
from unification import var

from aemcmc.conjugates import beta_binomial_conjugateo, gamma_poisson_conjugateo
from aemcmc.conjugates import (
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
)


def test_gamma_poisson_conjugate_contract():
Expand Down Expand Up @@ -101,3 +105,56 @@ def test_beta_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_beta_negative_binomial_conjugate_contract():
"""Produce the closed-form posterior for the binomial observation model with
a beta prior.
"""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
p_rv = srng.beta(alpha_tt, beta_tt, name="p")

n_tt = at.iscalar("n")
Y_rv = srng.negative_binomial(n_tt, p_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(
1, q_lv, beta_negative_binomial_conjugateo(y_vv, Y_rv, q_lv)
)
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.beta))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((alpha_tt, beta_tt, y_vv, n_tt), posterior)
assert sample_fn(1.0, 1.0, 1000, 0) == pytest.approx(
1.0, abs=0.01
) # only successes
assert sample_fn(1.0, 1.0, 0, 1000) == pytest.approx(0.0, abs=0.01) # no success


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_beta_negative_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
n_tt = at.iscalar("n")
Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt)

e_lv = var()
(expanded_expr,) = run(1, e_lv, beta_negative_binomial_conjugateo(e_lv, y_vv, Y_rv))
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))

0 comments on commit fede3b1

Please sign in to comment.