Skip to content

Commit

Permalink
Add Uniform pareto conjugates
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing Xie committed Feb 24, 2023
1 parent 64b0e50 commit 089e8b3
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 1 deletion.
88 changes: 87 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -268,13 +268,99 @@ def local_beta_negative_binomial_posterior(fgraph, node):
return rv_var.owner.outputs


def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a pareto prior with a uniform with 0 as the lower bound observation model.
.. math::
Y \sim \operatorname{Uniform}\left(0, \theta\right)
Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.
"""
# beta-negative_binomial observation model
x_lv, k_lv = var(), var()
theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
k_lv,
x_lv,
)
Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), 1, theta_et)

# new_x_et = at.max(observed_val)
new_x_et = at.max(observed_val, x_lv)
new_k_et = etuple(etuplize(at.add), k_lv, 1)

theta_posterior_et = etuple(
etuplize(at.random.pareto),
new_k_et,
new_x_et,
rng=theta_rng_lv,
size=theta_size_lv,
dtype=theta_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, theta_posterior_et),
)


@node_rewriter([UniformRV])
def local_uniform_pareto_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_beta_negative_binomial_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

pareto_rv = rv_et[-1].evaled_obj
pareto_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(
("local_uniform_pareto_posterior", pareto_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")


sampler_finder_db.register(
Expand Down
23 changes: 23 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
uniform_pareto_conjugateo,
)


Expand Down Expand Up @@ -157,3 +158,25 @@ def test_beta_negative_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_uniform_pareto_conjugate_contract():
"""Produce the closed-form posterior for the uniform observation model with
a pareto prior.
"""
srng = RandomStream(0)

xm_tt = at.scalar("xm")
k_tt = at.scalar("k")
theta_rv = srng.pareto(k_tt, xm_tt, name="theta")

Y_rv = srng.uniform(0, theta_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.pareto))

0 comments on commit 089e8b3

Please sign in to comment.