diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index d34966775a..96f4daefba 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -12,8 +12,11 @@
 from pytensor.scalar.basic import Mul
 from pytensor.tensor.basic import (
     AllocDiag,
+    ExtractDiag,
     Eye,
     TensorVariable,
+    concatenate,
+    diag,
     diagonal,
 )
 from pytensor.tensor.blas import Dot22
@@ -29,6 +32,7 @@
     inv,
     kron,
     pinv,
+    slogdet,
     svd,
 )
 from pytensor.tensor.rewriting.basic import (
@@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
         non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
 
     return [eye_input / non_eye_input]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([ExtractDiag])
+def rewrite_diag_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies extracting the diagonal of a block diagonal matrix by concatenating the diagonals of the individual sub-matrices.
+
+    diag(block_diag(a, b, c, ...)) = concat(diag(a), diag(b), diag(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices
+    sub_matrices = potential_block_diag.inputs
+    sub_matrices_diag = [diag(m) for m in sub_matrices]
+
+    return [concatenate(sub_matrices_diag)]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([det])
+def rewrite_det_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies the determinant of a block diagonal matrix by taking the product of the determinants of the individual sub-matrices.
+
+    det(block_diag(a, b, c, ...)) = prod(det(a), det(b), det(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices
+    sub_matrices = potential_block_diag.inputs
+    det_sub_matrices = [det(m) for m in sub_matrices]
+
+    return [prod(det_sub_matrices)]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([slogdet])
+def rewrite_slogdet_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies the slogdet of a block diagonal matrix by combining the signs and log-determinants computed from the individual sub-matrices.
+
+    slogdet(block_diag(a, b, c, ...)) = prod(sign(a), sign(b), sign(c), ...), sum(logdet(a), logdet(b), logdet(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices
+    sub_matrices = potential_block_diag.inputs
+    sign_sub_matrices, logdet_sub_matrices = zip(
+        *[slogdet(m) for m in sub_matrices]
+    )
+
+    return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 0bee56eb30..133e8d6a31 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -662,3 +662,92 @@ def test_inv_diag_from_diag(inv_op):
         atol=ATOL,
         rtol=RTOL,
     )
+
+
+def test_diag_blockdiag_rewrite():
+    n_matrices = 10
+    matrix_size = (5, 5)
+    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
+    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
+    diag_output = pt.diag(bd_output)
+    f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN")
+
+    # Rewrite Test
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
+
+    # Value Test
+    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
+    bd_output_test = scipy.linalg.block_diag(
+        *[sub_matrices_test[i] for i in range(n_matrices)]
+    )
+    diag_output_test = np.diag(bd_output_test)
+    rewritten_val = f_rewritten(sub_matrices_test)
+    assert_allclose(
+        diag_output_test,
+        rewritten_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_det_blockdiag_rewrite():
+    n_matrices = 100
+    matrix_size = (5, 5)
+    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
+    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
+    det_output = pt.linalg.det(bd_output)
+    f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN")
+
+    # Rewrite Test
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
+
+    # Value Test
+    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
+    bd_output_test = scipy.linalg.block_diag(
+        *[sub_matrices_test[i] for i in range(n_matrices)]
+    )
+    det_output_test = np.linalg.det(bd_output_test)
+    rewritten_val = f_rewritten(sub_matrices_test)
+    assert_allclose(
+        det_output_test,
+        rewritten_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_slogdet_blockdiag_rewrite():
+    n_matrices = 100
+    matrix_size = (5, 5)
+    sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
+    bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
+    sign_output, logdet_output = pt.linalg.slogdet(bd_output)
+    f_rewritten = function(
+        [sub_matrices], [sign_output, logdet_output], mode="FAST_RUN"
+    )
+
+    # Rewrite Test
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
+
+    # Value Test
+    sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
+    bd_output_test = scipy.linalg.block_diag(
+        *[sub_matrices_test[i] for i in range(n_matrices)]
+    )
+    sign_output_test, logdet_output_test = np.linalg.slogdet(bd_output_test)
+    rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test)
+    assert_allclose(
+        sign_output_test,
+        rewritten_sign_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        logdet_output_test,
+        rewritten_logdet_val,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
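
Note: the three identities the new rewrites rely on can be sanity-checked numerically with NumPy/SciPy alone. A minimal sketch, independent of PyTensor (variable names here are illustrative):

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal((3, 3)) for _ in range(3))
m = scipy.linalg.block_diag(a, b, c)

# diag(block_diag(a, b, c)) == concat(diag(a), diag(b), diag(c))
np.testing.assert_allclose(
    np.diag(m), np.concatenate([np.diag(x) for x in (a, b, c)])
)

# det(block_diag(a, b, c)) == det(a) * det(b) * det(c)
np.testing.assert_allclose(
    np.linalg.det(m), np.prod([np.linalg.det(x) for x in (a, b, c)])
)

# slogdet(block_diag(a, b, c)) == (product of signs, sum of log-determinants)
signs, logdets = zip(*(np.linalg.slogdet(x) for x in (a, b, c)))
sign_m, logdet_m = np.linalg.slogdet(m)
np.testing.assert_allclose(sign_m, np.prod(signs))
np.testing.assert_allclose(logdet_m, np.sum(logdets))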
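
And a minimal end-to-end sketch of the det rewrite firing, assuming this branch of PyTensor is installed (mode="FAST_RUN" applies the canonicalize/stabilize rewrites registered above; the names va/vb are illustrative):

import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.slinalg import BlockDiagonal

a, b = pt.dmatrix("a"), pt.dmatrix("b")
f = function([a, b], pt.linalg.det(pt.linalg.block_diag(a, b)), mode="FAST_RUN")

# After the rewrite, det(block_diag(a, b)) is computed as det(a) * det(b),
# so no BlockDiagonal op (plain or Blockwise-wrapped) survives compilation.
assert not any(
    isinstance(getattr(node.op, "core_op", node.op), BlockDiagonal)
    for node in f.maker.fgraph.apply_nodes
)

rng = np.random.default_rng(1)
va, vb = rng.standard_normal((4, 4)), rng.standard_normal((2, 2))
np.testing.assert_allclose(f(va, vb), np.linalg.det(va) * np.linalg.det(vb))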