Skip to content

Commit

Permalink
Add destructive in-place rewrite for pt.linalg.cholesky
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 9, 2024
1 parent e180927 commit 5d13604
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 27 deletions.
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,6 @@ def softplus(x):
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower

out_dtype = node.outputs[0].type.numpy_dtype

if lower:
Expand Down
27 changes: 26 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import cast

from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
Expand Down Expand Up @@ -310,3 +311,27 @@ def local_log_prod_sqr(fgraph, node):

# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.


# Module-level Cholesky Op instances built with the default ``lower`` and
# ``on_error`` settings; they differ only in the ``overwrite_a`` flag.
cholesky_no_inplace = Cholesky(overwrite_a=False)
cholesky_inplace = Cholesky(overwrite_a=True)


@node_rewriter([Cholesky], inplace=True)
def local_inplace_cholesky(fgraph, node):
    """Replace a non-destructive `Cholesky` node with a destructive one.

    The rewrite tracks the `Cholesky` class rather than a single default
    instance, so nodes built with a non-default ``lower`` or ``on_error`` are
    rewritten too. Those two settings are carried over to the new Op; only
    ``overwrite_a`` is switched on.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        The candidate Apply node.

    Returns
    -------
    list or None
        A one-element list holding the replacement output, or ``None`` when
        the node is not a `Cholesky` or is already destructive.
    """
    op = node.op
    if not isinstance(op, Cholesky) or op.overwrite_a:
        # Nothing to do: not our Op, or it already overwrites its input.
        return None

    inplace_op = Cholesky(lower=op.lower, on_error=op.on_error, overwrite_a=True)
    new_out = [inplace_op(*node.inputs)]
    copy_stack_trace(node.outputs, new_out)
    return new_out


# Register the in-place Cholesky rewrite in the global rewrite database under
# the "fast_run" and "inplace" tags.
# Position 69.0 runs it after destroyhandler (49.5) but before the rewrites
# that try to make elemwise operations in-place (75).
linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
optdb.register(
"InplaceLinalgOpt",
linalg_opt_inplace,
"fast_run",
"inplace",
"linalg_opt_inplace",
position=69.0,
)
91 changes: 66 additions & 25 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,23 @@


class Cholesky(Op):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
"""

# TODO: inplace
# TODO: for specific dtypes
# TODO: LAPACK wrapper with in-place behavior, for solve also

__props__ = ("lower", "destructive", "on_error")
__props__ = ("lower", "overwrite_a", "on_error")
gufunc_signature = "(m,m)->(m,m)"

def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
    """Configure the Cholesky Op.

    Parameters
    ----------
    lower : bool, default=True
        Whether to compute the lower (True) or upper (False) factor.
    on_error : {"raise", "nan"}, default="raise"
        What `perform` does when the decomposition fails: re-raise the
        error, or return a matrix filled with NaN.
    overwrite_a : bool, default=False
        Whether the input buffer may be overwritten. When True, a
        ``destroy_map`` is declared on the instance.

    Raises
    ------
    ValueError
        If `on_error` is not one of the accepted values.
    """
    if on_error not in ("raise", "nan"):
        raise ValueError('on_error must be one of "raise" or "nan"')

    self.lower = lower
    self.on_error = on_error
    self.overwrite_a = overwrite_a

    if self.overwrite_a:
        # Output 0 may reuse, and therefore clobber, the memory of input 0.
        self.destroy_map = {0: [0]}

def infer_shape(self, fgraph, node, shapes):
    """Return the output shape, which is the shape of the first input."""
    input_shape = shapes[0]
    return [input_shape]

Expand All @@ -67,15 +54,27 @@ def make_node(self, x):
return Apply(self, [x], [x.type()])

def perform(self, node, inputs, outputs):
    """Compute the Cholesky factor of the input and store it in the output cell.

    Parameters
    ----------
    node
        The Apply node being evaluated (unused here).
    inputs
        One-element sequence holding the matrix to decompose.
    outputs
        One-element sequence holding the output storage cell; the result is
        written to ``outputs[0][0]``.

    Raises
    ------
    scipy.linalg.LinAlgError
        If the decomposition fails and ``self.on_error == "raise"``. With
        ``on_error == "nan"`` a NaN-filled matrix is returned instead.
    """
    (x,) = inputs
    (z,) = outputs
    input_dtype = x.dtype

    try:
        if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
            # Inputs to the LAPACK functions need to be exactly as expected
            # for overwrite_a to work correctly, see
            # https://github.com/scipy/scipy/issues/8155#issuecomment-343996798
            # Factor the transposed view with the opposite triangle and
            # transpose the result back.
            factor = scipy.linalg.cholesky(
                x.T, lower=not self.lower, overwrite_a=True
            ).T
        else:
            factor = scipy.linalg.cholesky(
                x, lower=self.lower, overwrite_a=self.overwrite_a
            )
    except scipy.linalg.LinAlgError:
        if self.on_error == "raise":
            raise
        factor = np.full_like(x, np.nan)

    # copy=False: skip the extra allocation when the dtype already matches,
    # so the in-place path does not end with a fresh copy of the result.
    z[0] = factor.astype(input_dtype, copy=False)

def L_op(self, inputs, outputs, gradients):
"""
Expand Down Expand Up @@ -129,7 +128,49 @@ def conjugate_solve_triangular(outer, inner):
return [grad]


def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
    """
    Return a triangular matrix square root of positive semi-definite `x`.

    L = cholesky(X, lower=True) implies dot(L, L.T) == X.

    Parameters
    ----------
    x : tensor_like
        The matrix (or batch of matrices) to decompose.
    lower : bool, default=True
        Whether to return the lower or upper cholesky factor
    on_error : ['raise', 'nan']
        If on_error is set to 'raise', this Op will raise a
        `scipy.linalg.LinAlgError` if the matrix is not positive definite.
        If on_error is set to 'nan', it will return a matrix containing
        nans instead.
    overwrite_a : bool, ignored
        Whether to use the same memory for the output as `a`. This argument
        is ignored, and is present here only for consistency with
        scipy.linalg.cholesky.

    Returns
    -------
    TensorVariable
        Lower or upper triangular Cholesky factor of `x`

    Example
    -------
    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt
        import numpy as np

        x = pt.tensor('x', shape=(5, 5), dtype='float64')
        L = pt.linalg.cholesky(x)

        f = pytensor.function([x], L)
        x_value = np.random.normal(size=(5, 5))
        x_value = x_value @ x_value.T  # Ensures x is positive definite
        L_value = f(x_value)
        print(np.allclose(L_value @ L_value.T, x_value))  # True
    """
    # `overwrite_a` is deliberately not forwarded: the Op is always built
    # non-destructive here.
    return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)


Expand Down
36 changes: 36 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,39 @@ def test_invalid_batched_a(self):
ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
)


@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
def test_local_inplace_cholesky():
    """Check that compilation swaps in a destructive Cholesky for a mutable input.

    The user-built graph must be non-destructive, the compiled graph must
    contain at least one Cholesky Op with ``overwrite_a=True``, and calling
    the function must leave the lower Cholesky factor in the input buffer.
    """
    X = matrix("X")
    L = cholesky(X, overwrite_a=False, lower=True)
    f = function([pytensor.In(X, mutable=True)], L)

    # The graph as built by the user must not be destructive.
    assert not L.owner.op.core_op.overwrite_a

    # Collect every Cholesky Op in the compiled graph, whether bare or still
    # wrapped in a Blockwise (which exposes it as `core_op`).
    cholesky_ops = []
    for node in f.maker.fgraph.toposort():
        core_op = getattr(node.op, "core_op", node.op)
        if isinstance(core_op, Cholesky):
            cholesky_ops.append(core_op)

    # Fail loudly instead of passing vacuously when no Cholesky node is found.
    assert cholesky_ops, "no Cholesky node found in the compiled graph"
    assert all(op.overwrite_a for op in cholesky_ops)

    tol = 1e-4 if config.floatX == "float32" else 1e-8

    X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
    X_val_in = X_val @ X_val.T
    X_val_in_copy = X_val_in.copy()
    f(X_val_in)

    # The input buffer now holds the lower factor: the strict upper triangle
    # is zero and L @ L.T reproduces the original matrix.
    assert_allclose(
        X_val_in[np.triu_indices_from(X_val_in, k=1)],
        0.0,
        atol=tol,
        rtol=tol,
    )
    assert_allclose(
        X_val_in @ X_val_in.T,
        X_val_in_copy,
        atol=tol,
        rtol=tol,
    )

0 comments on commit 5d13604

Please sign in to comment.