Skip to content

Commit

Permalink
Added rewrites involving block diagonal matrices (#967)
Browse files Browse the repository at this point in the history
* added rewrite for diag(block_diag)

* added rewrite for determinant of blockdiag

* Added rewrite for slogdet; added docstrings for all 3 rewrites

* fixed typecasting for tests
  • Loading branch information
tanish1729 authored Oct 8, 2024
1 parent 2086aeb commit 3eea7d0
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 0 deletions.
117 changes: 117 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Eye,
TensorVariable,
concatenate,
diag,
diagonal,
)
from pytensor.tensor.blas import Dot22
Expand All @@ -29,6 +32,7 @@
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

return [eye_input / non_eye_input]


@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
def rewrite_diag_blockdiag(fgraph, node):
    """
    Simplify the diagonal of a block-diagonal matrix by concatenating the
    diagonals of the individual sub-matrices.

    diag(block_diag(a, b, c, ...)) = concat(diag(a), diag(b), diag(c), ...)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only fire when the diagonal is taken directly from a block_diag output.
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # NOTE(review): this assumes the ExtractDiag node extracts the main
    # diagonal (offset 0) of a 2D input — confirm a non-default offset/axis
    # cannot reach this rewrite.

    # The diagonal of the block matrix is the diagonals of the blocks, in order.
    sub_matrices = potential_block_diag.inputs
    return [concatenate([diag(sub_matrix) for sub_matrix in sub_matrices])]


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def rewrite_det_blockdiag(fgraph, node):
    """
    Simplify the determinant of a block-diagonal matrix to the product of the
    determinants of the individual sub-matrices.

    det(block_diag(a, b, c, ...)) = prod(det(a), det(b), det(c), ...)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only fire when the determinant is taken directly of a block_diag output.
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # det of the block matrix is the product of the per-block determinants.
    sub_matrices = potential_block_diag.inputs
    return [prod([det(sub_matrix) for sub_matrix in sub_matrices])]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
    """
    Simplify the slogdet of a block-diagonal matrix using the slogdet values
    of the individual sub-matrices.

    slogdet(block_diag(a, b, c, ...)) =
        prod(sign(a), sign(b), sign(c), ...), sum(logdet(a), logdet(b), logdet(c), ...)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only fire when slogdet is applied directly to a block_diag output.
    potential_block_diag = node.inputs[0].owner
    if not (
        potential_block_diag
        and isinstance(potential_block_diag.op, Blockwise)
        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
    ):
        return None

    # Signs multiply and log-determinants add across the blocks.
    sub_matrices = potential_block_diag.inputs
    sign_sub_matrices, logdet_sub_matrices = zip(
        *[slogdet(sub_matrix) for sub_matrix in sub_matrices]
    )
    return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
89 changes: 89 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,92 @@ def test_inv_diag_from_diag(inv_op):
atol=ATOL,
rtol=RTOL,
)


def test_diag_blockdiag_rewrite():
    """diag(block_diag(...)) should be rewritten away and stay numerically correct."""
    n_matrices = 10
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
    diag_output = pt.diag(bd_output)
    f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")

    # The rewrite must have removed every BlockDiagonal op from the graph.
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Compare against scipy's reference block_diag construction.
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected = np.diag(
        scipy.linalg.block_diag(*[sub_matrices_test[i] for i in range(n_matrices)])
    )
    assert_allclose(expected, f_rewritten(sub_matrices_test), atol=tol, rtol=tol)


def test_det_blockdiag_rewrite():
    """det(block_diag(...)) should be rewritten away and stay numerically correct."""
    n_matrices = 100
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
    det_output = pt.linalg.det(bd_output)
    f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")

    # The rewrite must have removed every BlockDiagonal op from the graph.
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Compare against numpy's determinant of scipy's reference block matrix.
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected = np.linalg.det(
        scipy.linalg.block_diag(*[sub_matrices_test[i] for i in range(n_matrices)])
    )
    assert_allclose(expected, f_rewritten(sub_matrices_test), atol=tol, rtol=tol)


def test_slogdet_blockdiag_rewrite():
    """slogdet(block_diag(...)) should be rewritten away and stay numerically correct."""
    n_matrices = 100
    matrix_size = (5, 5)
    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
    sign_output, logdet_output = pt.linalg.slogdet(bd_output)
    f_rewritten = function(
        [sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
    )

    # The rewrite must have removed every BlockDiagonal op from the graph.
    assert not any(
        isinstance(node.op, BlockDiagonal)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Compare both outputs against numpy's slogdet of the reference block matrix.
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
    expected_sign, expected_logdet = np.linalg.slogdet(
        scipy.linalg.block_diag(*[sub_matrices_test[i] for i in range(n_matrices)])
    )
    actual_sign, actual_logdet = f_rewritten(sub_matrices_test)
    assert_allclose(expected_sign, actual_sign, atol=tol, rtol=tol)
    assert_allclose(expected_logdet, actual_logdet, atol=tol, rtol=tol)

0 comments on commit 3eea7d0

Please sign in to comment.