Skip to content

Commit

Permalink
Introduce make_inplace helper function for destructive rewrites
Browse files Browse the repository at this point in the history
Refactor cholesky destructive re-write to use `make_inplace` helper
  • Loading branch information
jessegrabowski committed Jan 9, 2024
1 parent 5d13604 commit 470ea60
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 58 deletions.
12 changes: 12 additions & 0 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ def alloc_like(
return rval


def make_inplace(node, inplace_prop="inplace"):
op = getattr(node.op, "core_op", node.op)
props = op._props_dict()
if props[inplace_prop]:
return False

props[inplace_prop] = True
inplace_op = type(op)(**props)

return inplace_op.make_node(*node.inputs).outputs


def register_useless(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
):
Expand Down
11 changes: 3 additions & 8 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
make_inplace,
register_canonicalize,
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -313,15 +314,9 @@ def local_log_prod_sqr(fgraph, node):
# returns the sign of the prod multiplication.


cholesky_no_inplace = Cholesky(overwrite_a=False)
cholesky_inplace = Cholesky(overwrite_a=True)


@node_rewriter([cholesky_no_inplace], inplace=True)
@node_rewriter([Cholesky], inplace=True)
def local_inplace_cholesky(fgraph, node):
new_out = [cholesky_inplace(*node.inputs)]
copy_stack_trace(node.outputs, new_out)
return new_out
return make_inplace(node, "overwrite_a")


# After destroyhandler(49.5) but before we try to make elemwise things
Expand Down
70 changes: 22 additions & 48 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@


class Cholesky(Op):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
"""

# TODO: for specific dtypes
# TODO: LAPACK wrapper with in-place behavior, for solve also

Expand All @@ -36,13 +52,11 @@ class Cholesky(Op):

def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
self.lower = lower

self.overwrite_a = overwrite_a
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error

self.overwrite_a = overwrite_a
if self.overwrite_a:
if overwrite_a:
self.destroy_map = {0: [0]}

def infer_shape(self, fgraph, node, shapes):
Expand Down Expand Up @@ -73,7 +87,7 @@ def perform(self, node, inputs, outputs):
if self.on_error == "raise":
raise
else:
x = np.full_like(x, np.nan)
x = np.zeros(x.shape) * np.nan
z[0] = x.astype(input_dtype)

def L_op(self, inputs, outputs, gradients):
Expand Down Expand Up @@ -129,49 +143,9 @@ def conjugate_solve_triangular(outer, inner):


def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky.
Returns
-------
TensorVariable
Lower or upper triangular Cholesky factor of `x`
Example
-------
.. code-block:: python
import pytensor
import pytensor.tensor as pt
import numpy as np
x = pt.tensor('x', size=(5, 5), dtype='float64')
L = pt.linalg.cholesky(x)
f = pytensor.function([x], L)
x_value = np.random.normal(size=(5, 5))
x_value = x_value @ x_value.T # Ensures x is positive definite
L_value = f(x_value)
print(np.allclose(L_value @ L_value.T, x_value))
>>> True
"""

return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))(
x
)


class SolveBase(Op):
Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,12 @@ def test_local_inplace_cholesky():
L = cholesky(X, overwrite_a=False, lower=True)
f = function([pytensor.In(X, mutable=True)], L)

assert not L.owner.op.core_op.destructive
assert not L.owner.op.core_op.overwrite_a

nodes = f.maker.fgraph.toposort()
for node in nodes:
if isinstance(node, Cholesky):
assert node.destructive
assert node.overwrite_a
break

X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
Expand Down

0 comments on commit 470ea60

Please sign in to comment.