diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 3ab2960562..798d590d7f 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -887,3 +887,82 @@ def rewrite_slogdet_kronecker(fgraph, node): logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_remove_useless_cholesky(fgraph, node): + """ + This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself + + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input to Cholesky is Eye and the 1's are on main diagonal + potential_eye = node.inputs[0] + if not ( + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and hasattr(potential_eye.owner.inputs[-1], "data") + and potential_eye.owner.inputs[-1].data.item() == 0 + ): + return None + return [potential_eye] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + [input] = node.inputs + # Check for use of pt.diag first + if ( + input.owner + and isinstance(input.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(input.owner) + ): + diag_input = input.owner.inputs[0] + cholesky_val = pt.diag(diag_input**0.5) + return [cholesky_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(input) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + [non_eye_input] = non_eye_inputs + + # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) + if eye_input.type.ndim > 2: + non_eye_input = pt.shape_padaxis(non_eye_input, -2) + + return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 211facb484..9dd2a247a8 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_cholesky_eye_rewrite(): + x = pt.eye(10) + L = pt.linalg.cholesky(x) + f_rewritten = function([], L, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + # Rewrite Test + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # Value Test + x_test = np.eye(10) + L = np.linalg.cholesky(x_test) + rewritten_val = f_rewritten() + + assert_allclose( + L, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (7, 7), (5, 7, 7)], + ids=["scalar", "vector", "matrix", "batched"], +) +def test_cholesky_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + y = pt.eye(7) * x + # Performing cholesky decomposition using pt.linalg.cholesky + z_cholesky = pt.linalg.cholesky(y) + + # REWRITE TEST + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + cholesky_val = np.linalg.cholesky(x_test_matrix) + rewritten_val = f_rewritten(x_test) + + assert_allclose( + cholesky_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_cholesky_diag_from_diag(): + x = pt.dvector("x") + x_diag = pt.diag(x) + x_cholesky = pt.linalg.cholesky(x_diag) + + # REWRITE TEST + f_rewritten = function([x], x_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.random.rand(10) + x_test_matrix = np.eye(10) * x_test + cholesky_val = np.linalg.cholesky(x_test_matrix) + rewritten_cholesky = f_rewritten(x_test) + + assert_allclose( + cholesky_val, + rewritten_cholesky, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): + # Case 1 : y is not a diagonal matrix because of k = -1 + x = pt.tensor("x", shape=(7, 7)) + y = pt.eye(7, k=-1) * x + z_cholesky = pt.linalg.cholesky(y) + + # REWRITE TEST (should not be applied) + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert any(isinstance(node.op, Cholesky) for node in nodes) + + # Case 2 : eye is degenerate + x = pt.scalar("x") + y = pt.eye(1) * x + z_cholesky = pt.linalg.cholesky(y) + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert any(isinstance(node.op, Cholesky) for node in nodes)