Skip to content

Commit

Permalink
Add a function to generate prior samples
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Mar 23, 2023
1 parent ab429b2 commit 51db0a2
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 2 deletions.
31 changes: 31 additions & 0 deletions aemcmc/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import aesara
import aesara.tensor as at
import aesara.tensor.random as ar

from aemcmc.utils import get_rv_updates


def sample_prior(
srng: ar.RandomStream, num_samples: at.TensorVariable, *rvs: at.TensorVariable
) -> at.TensorVariable:
"""Sample from a model's prior distributions.
Parameters
----------
srng:
`RandomStream` instance with which the model was defined.
num_samples:
The number of prior samples to generate.
rvs:
The random variables whose prior distribution we want to sample.
"""

rv_updates = get_rv_updates(srng, *rvs)

def step_fn():
return rvs, rv_updates

samples, updates = aesara.scan(step_fn, n_steps=num_samples)

return samples, updates
58 changes: 56 additions & 2 deletions aemcmc/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from aesara.graph.basic import Constant, Variable
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.tensor import as_tensor_variable
from aesara.tensor.random.type import RandomType
from aesara.tensor.var import TensorVariable

if TYPE_CHECKING:
from aesara.tensor.random.utils import RandomStream


@dataclass(frozen=True)
class ModelInfo:
Expand Down Expand Up @@ -73,3 +78,52 @@ def remove_constants(inputs):
res.append(inp_t)

return res


def get_rv_updates(
srng: "RandomStream", *rvs: TensorVariable
) -> Dict[SharedVariable, "Variable"]:
r"""Get the updates needed to update RNG objects during sampling of `rvs`.
A search is performed over `rvs` for `SharedVariable`\s with default
updates and the updates stored in `srng`.
Parameters
----------
srng:
`RandomStream` instance with which the model was defined.
rvs:
The random variables whose prior distribution we want to sample.
Returns
-------
A dict containing the updates needed to sample from the models given by
`rvs`.
"""
# TODO: It's kind of weird that this is an alist-like data structure; we
# should revisit this in `RandomStream`
srng_updates = dict(srng.state_updates)
rv_updates = {}

for var in ancestors(rvs):
if not isinstance(var, SharedVariable) and not isinstance(var.type, RandomType):
continue

# TODO: Consider making sure the updates correspond to "in-place"
# updates of the RNGs for relevant `RandomVariable`s?
# More generally, a function like this could be used to determine the
# consistency of `RandomVariable` updates in general (e.g. find
# bad/disassociated updates).
srng_update = srng_updates.get(var)

if var.default_update:
if srng_update:
assert srng_update == var.default_update

# We prefer the default update (for no particular reason)
rv_updates[var] = var.default_update
elif srng_update:
rv_updates[var] = srng_update

return rv_updates
50 changes: 50 additions & 0 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import aesara
import aesara.tensor as at
import numpy as np
from aesara.compile.sharedvalue import SharedVariable

from aemcmc.sample import sample_prior


def test_sample_prior():
srng = at.random.RandomStream(123)

mu_rv = srng.normal(0, 1, name="mu")
Y_rv = srng.normal(mu_rv, 1.0, name="Y")
Z_rv = srng.gamma(0.5, 0.5, name="Z")

samples, updates = sample_prior(srng, 10, Y_rv)
fn = aesara.function([], samples, updates=updates)

# Make sure that `Z_rv` doesn't sneak into our prior sampling.
rng_objects = set(
var.get_value(borrow=True)
for var in fn.maker.fgraph.variables
if isinstance(var, SharedVariable)
)

assert mu_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Y_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Z_rv.owner.inputs[0].get_value(borrow=True) not in rng_objects

samples_vals = fn()
assert np.shape(np.unique(samples_vals)) == (10,)

# Try it again, but without a default update
Y_rv.owner.inputs[0].default_update = None

samples, updates = sample_prior(srng, 10, Y_rv)
fn = aesara.function([], samples, updates=updates)

rng_objects = set(
var.get_value(borrow=True)
for var in fn.maker.fgraph.variables
if isinstance(var, SharedVariable)
)

assert mu_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Y_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
assert Z_rv.owner.inputs[0].get_value(borrow=True) not in rng_objects

samples_vals = fn()
assert np.shape(np.unique(samples_vals)) == (10,)

0 comments on commit 51db0a2

Please sign in to comment.