Skip to content

Commit

Permalink
Adding rewrites involving kronecker product (pymc-devs#975)
Browse files Browse the repository at this point in the history
* Added rewrite for diag of kronecker product

* Added rewrite for slogdet; added docstrings for  rewrites

* fixed typo
  • Loading branch information
tanish1729 authored Oct 8, 2024
1 parent 5632777 commit 3e98b9f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 1 deletion.
71 changes: 70 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
Expand Down Expand Up @@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node):
)

return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_kronecker(fgraph, node):
"""
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
diag(kron(a,b)) -> outer(diag(a), diag(b))
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner kron operation
potential_kron = node.inputs[0].owner
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return None

# Find the matrices
a, b = potential_kron.inputs
diag_a, diag_b = diag(a), diag(b)
outer_prod_as_vector = outer(diag_a, diag_b).flatten()

return [outer_prod_as_vector]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner kron operation
potential_kron = node.inputs[0].owner
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return None

# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
52 changes: 52 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_diag_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
diag_kron_prod = pt.diag(kron_prod)
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
diag_kron_prod_test = np.diag(kron_prod_test)
rewritten_val = f_rewritten(a_test, b_test)
assert_allclose(
diag_kron_prod_test,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_slogdet_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
logdet_output_test,
rewritten_logdet_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

0 comments on commit 3e98b9f

Please sign in to comment.