Skip to content

Commit

Permalink
Add a function to generate prior samples
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Mar 17, 2023
1 parent ab429b2 commit d2fdc5f
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aemcmc/sample/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .prior import sample_prior

__all__ = ["sample_prior"]
86 changes: 86 additions & 0 deletions aemcmc/sample/prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import TYPE_CHECKING, Dict

import aesara
import aesara.tensor as at
import aesara.tensor.random as ar
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import ancestors
from aesara.tensor.random.type import RandomType

if TYPE_CHECKING:
from aesara.graph.basic import Variable


def get_rv_updates(
srng: ar.RandomStream, *rvs: at.TensorVariable
) -> Dict[SharedVariable, "Variable"]:
r"""Get the updates needed to update RNG objects during sampling of `rvs`.
A search is performed over `rvs` for `SharedVariable`\s with default
updates and the updates stored in `srng`.
Parameters
----------
srng:
`RandomStream` instance with which the model was defined.
rvs:
The random variables whose prior distribution we want to sample.
Returns
-------
A dict containing the updates needed to sample from the models given by
`rvs`.
"""
# TODO: It's kind of weird that this is an alist-like data structure; we
# should revisit this in `RandomStream`
srng_updates = dict(srng.state_updates)
rv_updates = {}

for var in ancestors(rvs):
if not isinstance(var, SharedVariable) and not isinstance(var.type, RandomType):
continue

# TODO: Consider making sure the updates correspond to "in-place"
# updates of the RNGs for relevant `RandomVariable`s?
# More generally, a function like this could be used to determine the
# consistency of `RandomVariable` updates in general (e.g. find
# bad/disassociated updates).
srng_update = srng_updates.get(var)

if var.default_update:
if srng_update:
assert srng_update == var.default_update

# We prefer the default update (for no particular reason)
rv_updates[var] = var.default_update
elif srng_update:
rv_updates[var] = srng_update

return rv_updates


def sample_prior(
srng: ar.RandomStream, num_samples: at.TensorVariable, *rvs: at.TensorVariable
) -> at.TensorVariable:
"""Sample from a model's prior distributions.
Parameters
----------
srng:
`RandomStream` instance with which the model was defined.
num_samples:
The number of prior samples to generate.
rvs:
The random variables whose prior distribution we want to sample.
"""

rv_updates = get_rv_updates(srng, *rvs)

def step_fn():
return rvs, rv_updates

samples, updates = aesara.scan(step_fn, n_steps=num_samples)

return samples, updates
31 changes: 31 additions & 0 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import aesara
import aesara.tensor as at
import numpy as np
from aesara.compile.sharedvalue import SharedVariable

from aemcmc.sample import sample_prior


def test_sample_prior():
srng = at.random.RandomStream(123)

mu_rv = srng.normal(0, 1, name="mu")
Y_rv = srng.normal(mu_rv, 1.0, name="Y")
Z_rv = srng.gamma(0.5, 0.5, name="Z")

samples, updates = sample_prior(srng, 10, Y_rv)
fn = aesara.function([], samples, updates=updates)

# Make sure that `Z_rv` doesn't sneak into our prior sampling.
rng_objects = set(
var.get_value(borrow=True)
for var in fn.maker.fgraph.variables
if isinstance(var, SharedVariable)
)

assert mu_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Y_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Z_rv.owner.inputs[0].get_value(borrow=True) not in rng_objects

samples_vals = fn()
assert np.shape(np.unique(samples_vals)) == (10,)

0 comments on commit d2fdc5f

Please sign in to comment.