Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more conjugates #113

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 176 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from aesara.tensor.random.basic import (
BernoulliRV,
BinomialRV,
NegBinomialRV,
PoissonRV,
UniformRV,
)
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -268,13 +274,182 @@ def local_beta_negative_binomial_posterior(fgraph, node):
return rv_var.owner.outputs


def beta_bernoulli_conjugateo(observed_val, observed_rv_expr, posterior_expr):
    r"""Produce a goal that represents the application of Bayes theorem
    for a beta prior with a Bernoulli observation model.

    .. math::

        \frac{
            Y \sim \operatorname{Bernoulli}\left(p\right), \quad
            p \sim \operatorname{Beta}\left(\alpha, \beta\right)
        }{
            \left(p \mid Y=y\right) \sim \operatorname{Beta}\left(\alpha + y, \beta + 1 - y\right)
        }

    NOTE(review): the update built below is ``alpha + y`` and
    ``beta + 1 - y``, i.e. the single-observation (``n = 1``) conjugate
    update; if ``observed_val`` is a vector the addition is elementwise, not
    a sum over observations — confirm the intended usage.

    Parameters
    ----------
    observed_val
        The observed value.
    observed_rv_expr
        An expression that represents the observed variable.
    posterior_expr
        An expression that represents the posterior distribution of the latent
        variable.

    """
    # beta-Bernoulli observation model: match a Bernoulli RV whose
    # probability parameter is itself a beta RV.
    alpha_lv, beta_lv = var(), var()
    p_rng_lv = var()
    p_size_lv = var()
    p_type_idx_lv = var()
    p_et = etuple(
        etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv
    )
    Y_et = etuple(etuplize(at.random.bernoulli), var(), var(), var(), p_et)

    # Conjugate update of the beta parameters given the observed value.
    new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
    new_beta_et = etuple(etuplize(at.add), beta_lv, 1, -observed_val)

    p_posterior_et = etuple(
        etuplize(at.random.beta),
        new_alpha_et,
        new_beta_et,
        rng=p_rng_lv,
        size=p_size_lv,
        dtype=p_type_idx_lv,
    )

    return lall(
        eq(observed_rv_expr, Y_et),
        eq(posterior_expr, p_posterior_et),
    )


@node_rewriter([BernoulliRV])
def local_beta_bernoulli_posterior(fgraph, node):
    """Rewrite a beta-Bernoulli observation model into its closed-form posterior.

    Uses ``beta_bernoulli_conjugateo`` to find the contracted posterior and
    records it in the graph's ``sampler_mappings``.
    """
    mappings = getattr(fgraph, "sampler_mappings", None)

    rv_var = node.outputs[1]
    key = ("local_beta_bernoulli_posterior", rv_var)

    # Skip graphs without sampler bookkeeping, and variables already handled.
    if mappings is None or key in mappings.rvs_seen:
        return None  # pragma: no cover

    observed_et = etuplize(rv_var)
    posterior_lv = var()

    # Take the first solution of the conjugacy relation, if any.
    solutions = run(None, posterior_lv, beta_bernoulli_conjugateo(rv_var, observed_et, posterior_lv))
    solution = next(solutions, None)

    if solution is None:
        return None  # pragma: no cover

    # The last element of the observed etuple is the beta prior RV.
    prior_rv = observed_et[-1].evaled_obj
    posterior = eval_if_etuple(solution)

    mappings.rvs_to_samplers.setdefault(prior_rv, []).append(
        ("local_beta_bernoulli_posterior", posterior, None)
    )
    mappings.rvs_seen.add(key)

    return rv_var.owner.outputs


def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
    r"""Produce a goal that represents the application of Bayes theorem
    for a Pareto prior with a uniform (lower bound 0) observation model.

    .. math::

        Y \sim \operatorname{Uniform}\left(0, \theta\right), \quad
        \theta \sim \operatorname{Pareto}\left(k, x_m\right)

    NOTE(review): the update built below is
    :math:`\theta \mid Y \sim \operatorname{Pareto}(k + 1, \max(y))`; the
    textbook conjugate update is
    :math:`\operatorname{Pareto}(k + n, \max(x_m, \max(y)))` — confirm the
    single-observation and "observed max dominates the prior scale"
    assumptions are intended.

    Parameters
    ----------
    observed_val
        The observed value.
    observed_rv_expr
        An expression that represents the observed variable.
    posterior_expr
        An expression that represents the posterior distribution of the latent
        variable.

    """
    # uniform-Pareto observation model: match a uniform RV whose upper bound
    # is itself a Pareto RV.
    x_lv, k_lv = var(), var()
    theta_rng_lv = var()
    theta_size_lv = var()
    theta_type_idx_lv = var()
    theta_et = etuple(
        etuplize(at.random.pareto),
        theta_rng_lv,
        theta_size_lv,
        theta_type_idx_lv,
        k_lv,
        x_lv,
    )
    Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et)

    # NOTE(review): `at.math.max` is a helper function, not an Op; graphs it
    # builds contain `MaxAndArgmax` nodes, so this bare-callable etuple form
    # may not unify against an existing graph in the "expand" direction —
    # TODO confirm (a form like etuple(etuple(MaxAndArgmax, axis_lv), x)
    # matches such graphs).
    new_x_et = etuple(at.math.max, observed_val)
    new_k_et = etuple(etuplize(at.add), k_lv, 1)

    theta_posterior_et = etuple(
        etuplize(at.random.pareto),
        new_k_et,
        new_x_et,
        rng=theta_rng_lv,
        size=theta_size_lv,
        dtype=theta_type_idx_lv,
    )
    return lall(
        eq(observed_rv_expr, Y_et),
        eq(posterior_expr, theta_posterior_et),
    )


@node_rewriter([UniformRV])
def local_uniform_pareto_posterior(fgraph, node):
    """Rewrite a uniform-Pareto observation model into its closed-form posterior.

    Mirrors the other conjugate rewriters in this module: find the contracted
    posterior via ``uniform_pareto_conjugateo`` and record it in the graph's
    ``sampler_mappings``.
    """
    sampler_mappings = getattr(fgraph, "sampler_mappings", None)

    rv_var = node.outputs[1]
    # Bug fix: the dedup key must name *this* rewriter. It previously reused
    # "local_beta_negative_binomial_posterior" (a copy-paste leftover), which
    # collides with the beta-negative-binomial rewriter's bookkeeping.
    key = ("local_uniform_pareto_posterior", rv_var)

    if sampler_mappings is None or key in sampler_mappings.rvs_seen:
        return None  # pragma: no cover

    q = var()

    rv_et = etuplize(rv_var)

    # Take the first solution of the conjugacy relation, if any.
    res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
    res = next(res, None)

    if res is None:
        return None  # pragma: no cover

    # The last element of the observed etuple is the Pareto prior RV.
    pareto_rv = rv_et[-1].evaled_obj
    pareto_posterior = eval_if_etuple(res)

    sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(
        ("local_uniform_pareto_posterior", pareto_posterior, None)
    )
    sampler_mappings.rvs_seen.add(key)

    return rv_var.owner.outputs


# Database of conjugate-model contractions; each registered rewriter replaces
# a supported observation model with its closed-form posterior.
conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
    "negative_binomial", local_beta_negative_binomial_posterior, "basic"
)
# Bug fix: `local_beta_bernoulli_posterior` was defined but never registered,
# so the beta-Bernoulli contraction would never be applied.
conjugates_db.register("beta_bernoulli", local_beta_bernoulli_posterior, "basic")
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")


sampler_finder_db.register(
Expand Down
122 changes: 122 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import pytest
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random import RandomStream
from etuples import etuple, etuplize
from kanren import run
from unification import var

from aemcmc.conjugates import (
beta_bernoulli_conjugateo,
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
uniform_pareto_conjugateo,
)


Expand Down Expand Up @@ -157,3 +160,122 @@ def test_beta_negative_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_beta_bernoulli_conjugate_contract():
    """Produce the closed-form posterior for the Bernoulli observation model with
    a beta prior.

    """
    srng = RandomStream(0)

    alpha_tt = at.scalar("alpha")
    beta_tt = at.scalar("beta")
    p_rv = srng.beta(alpha_tt, beta_tt, name="p")

    Y_rv = srng.bernoulli(p_rv)
    y_vv = Y_rv.clone()
    y_vv.tag.name = "y"

    q_lv = var()
    (posterior_expr,) = run(1, q_lv, beta_bernoulli_conjugateo(y_vv, Y_rv, q_lv))
    posterior = eval_if_etuple(posterior_expr)

    assert isinstance(posterior.owner.op, type(at.random.beta))

    # Build the sampling function and check the results on limiting cases.
    # NOTE(review): these compare a single posterior draw against the limit
    # with a wide tolerance; they are smoke checks, not statistical tests.
    sample_fn = aesara.function((alpha_tt, beta_tt, y_vv), posterior)
    assert sample_fn(1.0, 1.0, 1) == pytest.approx(1.0, abs=0.3)  # only successes
    assert sample_fn(1.0, 1.0, 0) == pytest.approx(0.0, abs=0.3)  # no success


@pytest.mark.xfail(
    reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_beta_bernoulli_conjugate_expand():
    """Expand a contracted beta-Bernoulli observation model."""

    srng = RandomStream(0)

    alpha_tt = at.scalar("alpha")
    beta_tt = at.scalar("beta")
    y_vv = at.iscalar("y")
    n_tt = at.iscalar("n")
    Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt - y_vv)

    # Run the relation "backwards": given the posterior, recover the model.
    e_lv = var()
    (expanded_expr,) = run(1, e_lv, beta_bernoulli_conjugateo(e_lv, y_vv, Y_rv))
    expanded = eval_if_etuple(expanded_expr)

    assert isinstance(expanded.owner.op, type(at.random.beta))


def test_uniform_pareto_conjugate_contract():
    """Produce the closed-form posterior for the uniform observation model with
    a Pareto prior.

    """
    srng = RandomStream(0)

    xm_tt = at.scalar("xm")
    k_tt = at.scalar("k")
    theta_rv = srng.pareto(k_tt, xm_tt, name="theta")

    Y_rv = srng.uniform(0, theta_rv)
    y_vv = Y_rv.clone()
    y_vv.tag.name = "y"

    q_lv = var()
    (posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
    posterior = eval_if_etuple(posterior_expr)

    assert isinstance(posterior.owner.op, type(at.random.pareto))

    # Build the sampling function and check the results on limiting cases.
    # NOTE(review): single-draw smoke checks; a large k concentrates the
    # Pareto posterior near its scale parameter.
    sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior)
    assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01)  # k = 1000
    assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01)  # all zeros


def test_uniform_pareto_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the test graph looks like this:

>>> aesara.dprint(Y_rv)
pareto_rv{0, (0, 0), floatX, False}.1 [id A]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FC4162BFF20>) [id B]
 |TensorConstant{[]} [id C]
 |TensorConstant{11} [id D]
 |MaxAndArgmax{axis=()}.0 [id E] 'max'
 | |y [id F]
 |Elemwise{add,no_inplace} [id G]
   |k [id H]
   |n [id I]

and that MaxAndArgmax Op isn't the same as the at.math.max used in the etuple graph. at.math.max is a function that constructs a MaxAndArgmax Op and uses it to further construct a graph for the max of its argument. In other words, we need an etuple form/"pattern" that matches the types of graphs output by the helper function at.math.max.

Often the easiest way to find etuple forms for the graphs constructed by helper functions is to etuplize said graphs and spot their generalities.
For example:

>>> from etuples import etuplize
>>> etuplize(at.math.max(at.vector("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0,)), x)
>>> etuplize(at.math.max(at.matrix("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0, 1)), x)

As we can see, the axis property in the MaxAndArgmax Op will change according to the dimensions of the input (i.e. it computes the max across all dimensions), so we don't want to use a very specific value for the matching form. Instead, we can use another logic variable in place of those values.

Here's a general testing setup for that part of the problem:

import aesara
import aesara.tensor as at

from etuples import etuplize, etuple


srng = at.random.RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)

# This is what we need to match/unify:
etuplize(Y_rv)
# e(
#     e(aesara.tensor.random.basic.ParetoRV, 'pareto', 0, (0, 0), 'floatX', False),
#     RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FA1F9B3D9E0>),
#     TensorConstant{[]},
#     TensorConstant{11},
#     e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#     e(
#         e(
#             aesara.tensor.elemwise.Elemwise,
#             <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#             <frozendict {}>),
#         k,
#         n))

from unification import var
from kanren import run, eq
from aesara.tensor.math import MaxAndArgmax


observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
    etuplize(at.random.pareto),
    theta_rng_lv,
    theta_size_lv,
    theta_type_idx_lv,
    new_x_et,
    new_k_et,
)


run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))
# ((e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#   e(
#       e(
#           aesara.tensor.elemwise.Elemwise,
#           <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#           <frozendict {}>),
#       k,
#       n)),)

etuplize(Y_rv)

# e_lv = var()
# (expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv))
# expanded = eval_if_etuple(expanded_expr)

# assert isinstance(expanded.owner.op, type(at.random.pareto))
from aesara.tensor.math import MaxAndArgmax
from kanren import eq, run
from unification import var

observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
new_x_et,
new_k_et,
)

run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))