From 5d13604eb28c1bdd5533508496a4d65fba8918d2 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 22:05:31 +0100 Subject: [PATCH] Add destructive in-place rewrite for `pt.linalg.cholesky` --- pytensor/link/numba/dispatch/basic.py | 1 - pytensor/tensor/rewriting/linalg.py | 27 +++++++- pytensor/tensor/slinalg.py | 91 +++++++++++++++++++-------- tests/tensor/rewriting/test_linalg.py | 36 +++++++++++ 4 files changed, 128 insertions(+), 27 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 9c9c800b92..ad93441f94 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -812,7 +812,6 @@ def softplus(x): @numba_funcify.register(Cholesky) def numba_funcify_Cholesky(op, node, **kwargs): lower = op.lower - out_dtype = node.outputs[0].type.numpy_dtype if lower: diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index bc3eef6fca..84a98beaee 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1,7 +1,8 @@ import logging from typing import cast -from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise @@ -310,3 +311,27 @@ def local_log_prod_sqr(fgraph, node): # TODO: have a reduction like prod and sum that simply # returns the sign of the prod multiplication. 
+ + +cholesky_no_inplace = Cholesky(overwrite_a=False) +cholesky_inplace = Cholesky(overwrite_a=True) + + +@node_rewriter([cholesky_no_inplace], inplace=True) +def local_inplace_cholesky(fgraph, node): + new_out = [cholesky_inplace(*node.inputs)] + copy_stack_trace(node.outputs, new_out) + return new_out + + +# After destroyhandler(49.5) but before we try to make elemwise things +# inplace (75) +linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace") +optdb.register( + "InplaceLinalgOpt", + linalg_opt_inplace, + "fast_run", + "inplace", + "linalg_opt_inplace", + position=69.0, +) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f96dec5a35..5a879cbf29 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -28,36 +28,23 @@ class Cholesky(Op): - """ - Return a triangular matrix square root of positive semi-definite `x`. - - L = cholesky(X, lower=True) implies dot(L, L.T) == X. - - Parameters - ---------- - lower : bool, default=True - Whether to return the lower or upper cholesky factor - on_error : ['raise', 'nan'] - If on_error is set to 'raise', this Op will raise a - `scipy.linalg.LinAlgError` if the matrix is not positive definite. - If on_error is set to 'nan', it will return a matrix containing - nans instead. 
- """ - - # TODO: inplace # TODO: for specific dtypes # TODO: LAPACK wrapper with in-place behavior, for solve also - __props__ = ("lower", "destructive", "on_error") + __props__ = ("lower", "overwrite_a", "on_error") gufunc_signature = "(m,m)->(m,m)" - def __init__(self, *, lower=True, on_error="raise"): + def __init__(self, *, lower=True, on_error="raise", overwrite_a=False): self.lower = lower - self.destructive = False + if on_error not in ("raise", "nan"): raise ValueError('on_error must be one of "raise" or ""nan"') self.on_error = on_error + self.overwrite_a = overwrite_a + if self.overwrite_a: + self.destroy_map = {0: [0]} + def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -67,15 +54,27 @@ def make_node(self, x): return Apply(self, [x], [x.type()]) def perform(self, node, inputs, outputs): - x = inputs[0] - z = outputs[0] + (x,) = inputs + (z,) = outputs + input_dtype = x.dtype try: - z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) + if x.flags["C_CONTIGUOUS"] and self.overwrite_a: + # Inputs to the LAPACK functions need to be exactly as expected for overwrite_a to work correctly, + # see https://github.com/scipy/scipy/issues/8155#issuecomment-343996798 + x = scipy.linalg.cholesky( + x.T, lower=not self.lower, overwrite_a=self.overwrite_a + ).T + else: + x = scipy.linalg.cholesky( + x, lower=self.lower, overwrite_a=self.overwrite_a + ) + except scipy.linalg.LinAlgError: if self.on_error == "raise": raise else: - z[0] = (np.zeros(x.shape) * np.nan).astype(x.dtype) + x = np.full_like(x, np.nan) + z[0] = x.astype(input_dtype) def L_op(self, inputs, outputs, gradients): """ @@ -129,7 +128,49 @@ def conjugate_solve_triangular(outer, inner): return [grad] -def cholesky(x, lower=True, on_error="raise"): +def cholesky(x, lower=True, on_error="raise", overwrite_a=False): + """ + Return a triangular matrix square root of positive semi-definite `x`. + + L = cholesky(X, lower=True) implies dot(L, L.T) == X. 
+ + Parameters + ---------- + lower : bool, default=True + Whether to return the lower or upper cholesky factor + on_error : ['raise', 'nan'] + If on_error is set to 'raise', this Op will raise a + `scipy.linalg.LinAlgError` if the matrix is not positive definite. + If on_error is set to 'nan', it will return a matrix containing + nans instead. + overwrite_a : bool, ignored + Whether to use the same memory for the output as the input `x`. This argument is ignored, and is present here only + for consistency with scipy.linalg.cholesky. + + Returns + ------- + TensorVariable + Lower or upper triangular Cholesky factor of `x` + + Example + ------- + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + import numpy as np + + x = pt.tensor('x', size=(5, 5), dtype='float64') + L = pt.linalg.cholesky(x) + + f = pytensor.function([x], L) + x_value = np.random.normal(size=(5, 5)) + x_value = x_value @ x_value.T # Ensures x is positive definite + L_value = f(x_value) + print(np.allclose(L_value @ L_value.T, x_value)) + >>> True + """ + return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9cdb69ce6b..fae63c9dfa 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -306,3 +306,39 @@ def test_invalid_batched_a(self): ref_fn(test_a, test_b), rtol=1e-7 if config.floatX == "float64" else 1e-5, ) + + +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="inplace rewrites disabled when mode is FAST_COMPILE", +) +def test_local_inplace_cholesky(): + X = matrix("X") + L = cholesky(X, overwrite_a=False, lower=True) + f = function([pytensor.In(X, mutable=True)], L) + + assert not L.owner.op.core_op.destructive + + nodes = f.maker.fgraph.toposort() + for node in nodes: + if isinstance(node, Cholesky): + assert node.destructive + break + + X_val = np.random.normal(size=(10, 10)).astype(config.floatX) + X_val_in = X_val @ 
X_val.T + X_val_in_copy = X_val_in.copy() + f(X_val_in) + + assert_allclose( + X_val_in[np.triu_indices_from(X_val_in, k=1)], + 0.0, + atol=1e-4 if config.floatX == "float32" else 1e-8, + rtol=1e-4 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + X_val_in @ X_val_in.T, + X_val_in_copy, + atol=1e-4 if config.floatX == "float32" else 1e-8, + rtol=1e-4 if config.floatX == "float32" else 1e-8, + )