Skip to content

Commit

Permalink
Fix bug with dummy output clients in local_det_chol rewrite (#393)
Browse files Browse the repository at this point in the history
* check for dummy outputs in local_det_chol rewrite

* add rewrite check to 2nd test case

* fix test
  • Loading branch information
jessegrabowski authored Jul 20, 2023
1 parent 82aeefc commit b2b7e28
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def local_det_chol(fgraph, node):
if isinstance(node.op, Det):
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Cholesky):
L = cl.outputs[0]
return [prod(at.extract_diag(L) ** 2)]
Expand Down
18 changes: 17 additions & 1 deletion tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytensor.configdefaults import config
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector
Expand Down Expand Up @@ -202,3 +202,19 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
f(Av),
)
)


def test_local_det_chol():
X = matrix("X")
L = at.linalg.cholesky(X)
det_X = at.linalg.det(X)

f = function([X], [L, det_X])

nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)

# This previously raised an error (issue #392)
f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)

0 comments on commit b2b7e28

Please sign in to comment.