Skip to content

Commit

Permalink
Add rewrite for sum of normal RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Apr 9, 2023
1 parent 58b57c2 commit 4f4a30b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 11 deletions.
1 change: 1 addition & 0 deletions aeppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# isort: off
# Add rewrites to the DBs
import aeppl.censoring
import aeppl.convolutions
import aeppl.cumsum
import aeppl.mixture
import aeppl.scan
Expand Down
56 changes: 56 additions & 0 deletions aeppl/convolutions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import aesara
import aesara.tensor as at
from aesara.graph.rewriting.basic import EquilibriumGraphRewriter, node_rewriter
from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.basic import NormalRV, normal

from aeppl.rewriting import logprob_rewrites_db


@node_rewriter((Elemwise,))
def add_independent_normals(fgraph, node):
if not isinstance(node.op.scalar_op, Add):
return None

X_rv, Y_rv = node.inputs

if not (X_rv.owner and Y_rv.owner) or not (
isinstance(X_rv.owner.op, NormalRV) and isinstance(Y_rv.owner.op, NormalRV)
):
return None

old_rv = node.outputs[0]

mu_x, sigma_x, mu_y, sigma_y, _ = at.broadcast_arrays(
*(X_rv.owner.inputs[-2:] + Y_rv.owner.inputs[-2:] + [old_rv])
)

new_rng = X_rv.owner.inputs[0]

new_node = normal.make_node(
new_rng,
old_rv.shape,
old_rv.dtype,
mu_x + mu_y,
at.sqrt(sigma_x**2 + sigma_y**2),
)

# new_rng must be updated with values of the RNGs output by `new_node
new_rng.default_update = new_node.outputs[0]
new_normal_rv = new_node.default_output()

if old_rv.name:
new_normal_rv.name = old_rv.name

return [new_normal_rv]


logprob_rewrites_db.register(
"add_independent_normals",
EquilibriumGraphRewriter(
[add_independent_normals],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
"basic",
)
90 changes: 90 additions & 0 deletions tests/test_convolutions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import aesara.tensor as at
import numpy as np
import pytest
from aesara.tensor.random.basic import NormalRV

from aeppl.rewriting import construct_ir_fgraph


@pytest.mark.parametrize(
"mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape",
[
(
np.array([1, 10, 100]),
np.array(2),
np.array(0.03),
np.tile(0.04, 3),
(),
(),
),
(
np.array([1, 10, 100]),
np.array(2),
np.array(0.03),
np.full((5, 1), 0.04),
(),
(5, 3),
),
(
np.array([[1, 10, 100]]),
np.array([[0.2], [2], [20], [200], [2000]]),
np.array(0.03),
np.array(0.04),
(),
(),
),
(
np.broadcast_to(np.array([1, 10, 100]), (5, 3)),
np.array([2, 20, 200]),
np.array(0.03),
np.array(0.04),
(2, 5, 3),
(),
),
(
np.array([[1, 10, 100]]),
np.array([[0.2], [2], [20], [200], [2000]]),
np.array([[0.5], [5], [50], [500], [5000]]),
np.array([[0.4, 4, 40]]),
(2, 5, 3),
(),
),
(
np.array(1),
np.array(2),
np.array(3),
np.array(4),
(5, 1),
(1,),
),
],
)
def test_add_independent_normals(mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape):
srng = at.random.RandomStream(29833)

X_rv = srng.normal(mu_x, sigma_x, size=x_shape)
X_rv.name = "X"

Y_rv = srng.normal(mu_y, sigma_y, size=y_shape)
Y_rv.name = "Y"

Z_rv = X_rv + Y_rv
Z_rv.name = "Z"
z_vv = Z_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})

new_rv = fgraph.outputs[0].owner.inputs[0]

new_rv_mu = mu_x + mu_y
new_rv_sigma = np.sqrt(sigma_x**2 + sigma_y**2)

new_rv_shape = np.broadcast_shapes(new_rv_mu.shape, new_rv_sigma.shape, x_shape, y_shape)

new_rv_mu = np.broadcast_to(new_rv_mu, new_rv_shape)
new_rv_sigma = np.broadcast_to(new_rv_sigma, new_rv_shape)

assert isinstance(new_rv.owner.op, NormalRV)
assert np.allclose(new_rv.owner.inputs[3].eval(), new_rv_mu)
assert np.allclose(new_rv.owner.inputs[4].eval(), new_rv_sigma)
assert new_rv.name == "Z"
11 changes: 0 additions & 11 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,17 +675,6 @@ def test_transformed_rv_and_value():
)


def test_loc_transform_multiple_rvs_fails1():
srng = at.random.RandomStream(0)

x_rv1 = srng.normal(name="x_rv1")
x_rv2 = srng.normal(name="x_rv2")
y_rv = x_rv1 + x_rv2

with pytest.raises(DensityNotFound):
joint_logprob(y_rv)


def test_nested_loc_transform_multiple_rvs_fails2():
srng = at.random.RandomStream(0)

Expand Down

0 comments on commit 4f4a30b

Please sign in to comment.