Skip to content

Commit

Permalink
Adds functions to rewrite cholesky decomposition of identity and diag…
Browse files Browse the repository at this point in the history
…onal matrices (pymc-devs#925)

* fixed merge conflicts

* fixed failing tests and added rewrite for pt.diag

* minor changes; added test to not apply rewrite

* added test for batched case and more cases of not applying rewrite

* minor changes
  • Loading branch information
tanish1729 authored Oct 8, 2024
1 parent 3e98b9f commit be6a032
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
79 changes: 79 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,82 @@ def rewrite_slogdet_kronecker(fgraph, node):
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_remove_useless_cholesky(fgraph, node):
"""
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

# Check whether input to Cholesky is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
potential_eye.owner
and isinstance(potential_eye.owner.op, Eye)
and hasattr(potential_eye.owner.inputs[-1], "data")
and potential_eye.owner.inputs[-1].data.item() == 0
):
return None
return [potential_eye]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

[input] = node.inputs
# Check for use of pt.diag first
if (
input.owner
and isinstance(input.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(input.owner)
):
diag_input = input.owner.inputs[0]
cholesky_val = pt.diag(diag_input**0.5)
return [cholesky_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(input)
if inputs_or_none is None:
return None

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

[non_eye_input] = non_eye_inputs

# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
if non_eye_input.type.broadcastable[-2:] == (False, False):
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
if eye_input.type.ndim > 2:
non_eye_input = pt.shape_padaxis(non_eye_input, -2)

return [eye_input * (non_eye_input**0.5)]
103 changes: 103 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_cholesky_eye_rewrite():
x = pt.eye(10)
L = pt.linalg.cholesky(x)
f_rewritten = function([], L, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

# Rewrite Test
assert not any(isinstance(node.op, Cholesky) for node in nodes)

# Value Test
x_test = np.eye(10)
L = np.linalg.cholesky(x_test)
rewritten_val = f_rewritten()

assert_allclose(
L,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


@pytest.mark.parametrize(
"shape",
[(), (7,), (7, 7), (5, 7, 7)],
ids=["scalar", "vector", "matrix", "batched"],
)
def test_cholesky_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x
# Performing cholesky decomposition using pt.linalg.cholesky
z_cholesky = pt.linalg.cholesky(y)

# REWRITE TEST
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Cholesky) for node in nodes)

# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
cholesky_val = np.linalg.cholesky(x_test_matrix)
rewritten_val = f_rewritten(x_test)

assert_allclose(
cholesky_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_cholesky_diag_from_diag():
x = pt.dvector("x")
x_diag = pt.diag(x)
x_cholesky = pt.linalg.cholesky(x_diag)

# REWRITE TEST
f_rewritten = function([x], x_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

assert not any(isinstance(node.op, Cholesky) for node in nodes)

# NUMERIC VALUE TEST
x_test = np.random.rand(10)
x_test_matrix = np.eye(10) * x_test
cholesky_val = np.linalg.cholesky(x_test_matrix)
rewritten_cholesky = f_rewritten(x_test)

assert_allclose(
cholesky_val,
rewritten_cholesky,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
# Case 1 : y is not a diagonal matrix because of k = -1
x = pt.tensor("x", shape=(7, 7))
y = pt.eye(7, k=-1) * x
z_cholesky = pt.linalg.cholesky(y)

# REWRITE TEST (should not be applied)
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)

# Case 2 : eye is degenerate
x = pt.scalar("x")
y = pt.eye(1) * x
z_cholesky = pt.linalg.cholesky(y)
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 comments on commit be6a032

Please sign in to comment.