Implement truncated variables
Ricardo committed Apr 7, 2022
1 parent 6b58a43 commit 982e469
Showing 3 changed files with 344 additions and 3 deletions.
6 changes: 6 additions & 0 deletions aeppl/logprob.py
@@ -45,6 +45,12 @@ def xlogy0(m, x):
return at.switch(at.eq(x, 0), at.switch(at.eq(m, 0), 0.0, -np.inf), m * at.log(x))


def logdiffexp(a, b):
"""log(exp(a) - exp(b))"""
# TODO: This should be a basic Aesara stabilization
return a + at.log1mexp(b - a)


def logprob(rv_var, *rv_values, **kwargs):
"""Create a graph for the log-probability of a ``RandomVariable``."""
logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
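For context, the logdiffexp helper added above rewrites log(exp(a) - exp(b)) as a + log(1 - exp(b - a)), so the inner term can be evaluated with Aesara's log1mexp instead of forming the exponentials directly. A minimal NumPy sketch of the same identity (logdiffexp_np is an illustrative name, not part of the commit):

import numpy as np

def logdiffexp_np(a, b):
    # log(exp(a) - exp(b)) = a + log(1 - exp(b - a)); Aesara's log1mexp computes
    # the second term stably when b - a is close to 0, i.e. when a and b are close.
    return a + np.log1p(-np.exp(b - a))

a, b = np.log(0.9), np.log(0.2)
print(np.isclose(logdiffexp_np(a, b), np.log(0.9 - 0.2)))  # True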
221 changes: 219 additions & 2 deletions aeppl/truncation.py
@@ -1,18 +1,33 @@
import warnings
from functools import singledispatch
from typing import List, Optional

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan, shared
from aesara.compile.builders import OpFromGraph
from aesara.graph import Op
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.raise_op import Assert
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.scan import until
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.var import TensorConstant
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
from aeppl.logprob import (
CheckParameterValue,
_logcdf,
_logprob,
icdf,
logcdf,
logdiffexp,
)
from aeppl.opt import rv_sinking_db


@@ -123,3 +138,205 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
)

return logprob


class TruncatedRV(OpFromGraph):
"""An `Op` constructed from an Aesara graph that represents a truncated univariate RV."""

default_output = 1
base_rv_op = None

def __init__(self, base_rv_op: Op, *args, **kwargs):
self.base_rv_op = base_rv_op
super().__init__(*args, **kwargs)


@singledispatch
def _truncated(op: Op, lower, upper, *params):
"""Return the truncated equivalent of another ``RandomVariable``."""
raise NotImplementedError(
f"{op} does not have an equivalent truncated version implemented"
)


def truncate(
rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
):
"""Truncate a univariate RandomVariable between lower and upper.
If lower or upper is ``None``, the variable is not truncated on that side.
Depending on dispatched implementations, this function returns either a specialized
`Op`, or equivalent graph representing the truncation process, via inverse CDF
sampling, or rejection sampling.
The argument `max_n_steps` controls the maximum number of resamples that are
attempted when performing rejection sampling. An Error is raised if convergence is
not reached after that many steps.
TODO: Add Note about updates being necessary for compilation when graph uses
rejection sampling
"""

if lower is None and upper is None:
raise ValueError("lower and upper cannot both be None")

lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

if not isinstance(rv.owner.op, RandomVariable):
raise ValueError(f"truncation not implemented for Op {rv.owner.op}")

if rv.owner.op.ndim_supp > 0:
raise NotImplementedError(
"truncation not implemented for multivariate variables"
)

if rng is None:
rng = shared(np.random.RandomState(), borrow=True)

# Try to use specialized Op
try:
return _truncated(rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:])
except NotImplementedError:
pass

# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs = [rng, *rv.owner.inputs[1:], lower, upper]
graph_inputs_ = [inp.type() for inp in graph_inputs]
*rv_inputs_, lower_, upper_ = graph_inputs_

# Try to use inverted cdf sampling
try:
rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()
cdf_lower_ = at.exp(logcdf(rv_, lower_))
cdf_upper_ = at.exp(logcdf(rv_, upper_))
uniform_ = at.random.uniform(
cdf_lower_,
cdf_upper_,
rng=rv_inputs_[0],
size=rv_inputs_[1],
)
truncated_rv_ = icdf(rv_, uniform_)
return TruncatedRV(
base_rv_op=rv.owner.op,
inputs=graph_inputs_,
outputs=[uniform_.owner.outputs[0], truncated_rv_],
inline=True,
)(*graph_inputs)
except NotImplementedError:
pass

# Fallback to rejection sampling
# TODO: Handle potential broadcast by lower / upper

# Scan forces us to use a shared variable for the RNG
graph_inputs = graph_inputs[1:]
graph_inputs_ = graph_inputs_[1:]
*rv_inputs_, lower_, upper_ = (rng, *graph_inputs_)

rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()

def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
# We need to set default_update for scan to generate updates
next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
rng.default_update = next_rng

truncated_rv = at.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

return (truncated_rv, reject_draws), until(~at.any(reject_draws))

(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
at.empty_like(rv_),
at.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)

truncated_rv_ = truncated_rv_[-1]
convergence_ = ~at.any(reject_draws_[-1])
truncated_rv_ = Assert(
f"truncation did not converge in predefined {max_n_steps} steps"
)(truncated_rv_, convergence_)

# TODO: Scan does not return updates when a single step is performed, so this
# will fail with max_n_steps = 1
return TruncatedRV(
base_rv_op=rv.owner.op,
inputs=graph_inputs_,
outputs=[tuple(updates.values())[0], truncated_rv_],
inline=True,
)(*graph_inputs)
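As a quick orientation for the three code paths above (specialized dispatch, inverse CDF sampling, rejection sampling), here is a minimal usage sketch based on the specialized uniform dispatch registered at the end of this file; the numbers mirror the new test_truncation_specialized_op test:

import aesara.tensor as at
from aeppl.truncation import truncate

x = at.random.uniform(0, 10, size=100)
xt = truncate(x, lower=5, upper=15)  # dispatches to the specialized UniformRV case
draws = xt.eval()                    # effectively draws from uniform(5, 10)
assert draws.min() >= 5 and draws.max() <= 10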


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
# TODO: Check if adjustment is needed for discrete variables
(value,) = values

# rng shows up as the last input when using rejection sampling
if op.shared_inputs:
*rv_inputs, lower_bound, upper_bound, rng = inputs
rv_inputs = [rng, *rv_inputs]
else:
*rv_inputs, lower_bound, upper_bound = inputs

base_rv_op = op.base_rv_op
logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
lower_logcdf = _logcdf(base_rv_op, lower_bound, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)

if base_rv_op.name:
logp.name = f"{base_rv_op}_logprob"
lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

is_lower_bounded = not (
isinstance(lower_bound, TensorConstant)
and np.all(np.isneginf(lower_bound.value))
)
is_upper_bounded = not (
isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
)

if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = at.log1mexp(lower_logcdf)
elif is_upper_bounded:
lognorm = upper_logcdf
else:
lognorm = 0

logp = logp - lognorm

if is_lower_bounded:
logp = at.switch(value < lower_bound, -np.inf, logp)

if is_upper_bounded:
logp = at.switch(value <= upper_bound, logp, -np.inf)

if is_lower_bounded and is_upper_bounded:
logp = CheckParameterValue("lower_bound <= upper_bound")(
logp, at.all(at.le(lower_bound, upper_bound))
)

return logp
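The normalization above follows the standard truncated-density identity log p_trunc(x) = log p(x) - log(CDF(upper) - CDF(lower)), with the one-sided cases falling out when a bound is infinite. A small SciPy reference check of that identity for a doubly truncated standard normal (purely illustrative, not part of the commit):

import numpy as np
import scipy.stats as st

lower, upper, x = -1.0, 1.5, 0.3
manual = st.norm.logpdf(x) - np.log(st.norm.cdf(upper) - st.norm.cdf(lower))
assert np.isclose(st.truncnorm(lower, upper).logpdf(x), manual)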


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
return at.random.uniform(
at.max((lower_orig, lower)),
at.min((upper_orig, upper)),
rng=rng,
size=size,
dtype=dtype,
)
120 changes: 119 additions & 1 deletion tests/test_truncation.py
@@ -3,10 +3,14 @@
import numpy as np
import pytest
import scipy as sp
import scipy.stats
import scipy.stats as st
from aesara.tensor.random.basic import NormalRV, UniformRV

from aeppl import factorized_joint_logprob, joint_logprob
from aeppl import factorized_joint_logprob, joint_logprob, logprob
from aeppl.logprob import _icdf
from aeppl.transforms import LogTransform, TransformValuesOpt
from aeppl.truncation import TruncatedRV, _truncated, truncate
from tests.utils import assert_no_rvs


@@ -189,3 +193,117 @@ def test_censored_transform():
)

assert np.isclose(obs_logp, exp_logp)


class IcdfNormalRV(NormalRV):
"""Normal RV that has icdf but not truncated dispatching"""


class RejectionNormalRV(NormalRV):
"Normal RV that has nethir icdf nor truncated dispatching" ""


icdf_normal = IcdfNormalRV()
rejection_normal = RejectionNormalRV()


@_truncated.register(IcdfNormalRV)
@_truncated.register(RejectionNormalRV)
def custom_normal_truncated_not_implemented(*args, **kwargs):
raise NotImplementedError()


@_icdf.register(RejectionNormalRV)
def custom_normal_icdf_not_implemented(*args, **kwargs):
raise NotImplementedError()


def test_truncation_specialized_op():
x = at.random.uniform(0, 10, name="x", size=100)
xt = truncate(x, lower=5, upper=15)

assert isinstance(xt.owner.op, UniformRV)

lower_upper = at.stack(xt.owner.inputs[3:])
assert np.all(lower_upper.eval() == [5, 10])


@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_sampling(op_type):
loc = 0.15
scale = 10
lower = -1
upper = 1.5

normal_op = icdf_normal if op_type == "icdf" else rejection_normal

x = normal_op(loc, scale, name="x", size=100)
xt = truncate(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)

# Check that original op can be used on its own
assert x.eval()

ref_xt = scipy.stats.truncnorm(
(lower - loc) / scale,
(upper - loc) / scale,
loc,
scale,
)

# rng shows up as the last input when using rejection sampling
if op_type == "rejection":
updates = {xt.owner.inputs[-1]: xt.owner.outputs[0]}
else:
updates = {xt.owner.inputs[0]: xt.owner.outputs[0]}

xt_fn = aesara.function([], xt, updates=updates)
xt_draws = np.array([xt_fn() for _ in range(5)])
assert np.all(xt_draws >= lower)
assert np.all(xt_draws <= upper)
assert np.unique(xt_draws).size == xt_draws.size
_, p = scipy.stats.ks_2samp(xt_draws.ravel(), ref_xt.rvs(500))
assert p > 0.001

# Test max_n_steps
xt = truncate(x, lower=lower, upper=upper, max_n_steps=2)
# TODO: Function cannot be compiled without passing updates!
if op_type == "rejection":
updates = {xt.owner.inputs[-1]: xt.owner.outputs[0]}
xt_fn = aesara.function([], xt, updates=updates)

if op_type == "icdf":
xt_draws = xt_fn()
assert np.all(xt_draws >= lower)
assert np.all(xt_draws <= upper)
assert np.unique(xt_draws).size == xt_draws.size
else:
with pytest.raises(AssertionError, match="^truncation did not converge"):
xt_fn()


@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_rejection_sampling_op_logp(op_type):
loc = 0.15
scale = 10
lower = -1
upper = 1.5

op = icdf_normal if op_type == "icdf" else rejection_normal

x = op(loc, scale, name="x")
xt = truncate(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)

ref_xt = scipy.stats.truncnorm(
(lower - loc) / scale,
(upper - loc) / scale,
loc,
scale,
)

xt_v = xt.clone()
xt_logp_fn = aesara.function([xt_v], logprob(xt, xt_v))

for test_xt_v in (lower - 1, lower, lower + 1, upper - 1, upper, upper + 1):
assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
