From af8e86f21713781b1484fb93bac8fbda9184334a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 7 Jan 2024 00:27:09 +0100 Subject: [PATCH 1/8] Implement numba overload for POTRF, LAPACK cholesky routine --- pytensor/link/numba/dispatch/basic.py | 68 ++++++++++----------- pytensor/link/numba/dispatch/slinalg.py | 80 ++++++++++++++++++++++++- tests/link/numba/test_nlinalg.py | 51 ---------------- tests/link/numba/test_slinalg.py | 56 +++++++++++++++++ 4 files changed, 169 insertions(+), 86 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 9c9c800b92..244f2f45e5 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -37,7 +37,7 @@ from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Cholesky, Solve +from pytensor.tensor.slinalg import Solve from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -809,39 +809,39 @@ def softplus(x): return softplus -@numba_funcify.register(Cholesky) -def numba_funcify_Cholesky(op, node, **kwargs): - lower = op.lower - - out_dtype = node.outputs[0].type.numpy_dtype - - if lower: - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def cholesky(a): - return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) - - else: - # TODO: Use SciPy's BLAS/LAPACK Cython wrappers. - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`lower` argument to `scipy.linalg.cholesky`." - ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def cholesky(a): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype) - return ret - - return cholesky +# @numba_funcify.register(Cholesky) +# def numba_funcify_Cholesky(op, node, **kwargs): +# lower = op.lower +# +# out_dtype = node.outputs[0].type.numpy_dtype +# +# if lower: +# inputs_cast = int_to_float_fn(node.inputs, out_dtype) +# +# @numba_njit +# def cholesky(a): +# return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) +# +# else: +# # TODO: Use SciPy's BLAS/LAPACK Cython wrappers. +# +# warnings.warn( +# ( +# "Numba will use object mode to allow the " +# "`lower` argument to `scipy.linalg.cholesky`." +# ), +# UserWarning, +# ) +# +# ret_sig = get_numba_type(node.outputs[0].type) +# +# @numba_njit +# def cholesky(a): +# with numba.objmode(ret=ret_sig): +# ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype) +# return ret +# +# return cholesky @numba_funcify.register(Solve) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index ad8065defd..28e797d01d 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -9,7 +9,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import SolveTriangular +from pytensor.tensor.slinalg import Cholesky, SolveTriangular _PTR = ctypes.POINTER @@ -177,6 +177,22 @@ def numba_xtrtrs(cls, dtype): return functype(lapack_ptr) + @classmethod + def numba_xpotrf(cls, dtype): + """ + Called by scipy.linalg.cholesky + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO, + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # INFO + ) + return functype(lapack_ptr) + def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): return linalg.solve_triangular( @@ -273,3 +289,65 @@ def solve_triangular(a, b): return res return solve_triangular + + +def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): + return linalg.cholesky( + a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite + ) + + +@overload(_cholesky) +def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(A, "cholesky") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_potrf = _LAPACK().numba_xpotrf(dtype) + + def impl(A, lower=0, overwrite_a=False, check_finite=True): + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + numba_potrf( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + INFO, + ) + + return A_copy + + return impl + + +@numba_funcify.register(Cholesky) +def numba_funcify_Cholesky(op, node, **kwargs): + lower = op.lower + # overwrite_a = op.overwrite_a + overwrite_a = False + check_finite = op.check_finite + + @numba_basic.numba_njit(inline="always") + def nb_cholesky(a): + if check_finite: + if np.any(np.isinf(a)) or np.any(np.isnan(a)): + raise ValueError( + "Non-numeric values (nan or inf) in input to ", op.name + ) + res = _cholesky(a, lower, overwrite_a, check_finite) + return res + + return nb_cholesky diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 4732a8f3d0..5e25bbf53d 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -14,57 +14,6 @@ rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - True, - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - None, - ), - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - False, - UserWarning, - ), - ], -) -def test_Cholesky(x, lower, exc): - g = slinalg.Cholesky(lower=lower)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - @pytest.mark.parametrize( "A, x, lower, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 75e016f1e0..7376c8bb33 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,3 +1,4 @@ +import contextlib import re import numpy as np @@ -6,6 +7,10 @@ import pytensor import pytensor.tensor as pt from pytensor import config +from pytensor.compile import SharedVariable +from pytensor.graph import Constant, FunctionGraph +from tests.link.numba.test_basic import compare_numba_and_py +from tests.tensor.test_extra_ops import set_test_value numba = pytest.importorskip("numba") @@ -102,3 +107,54 @@ def test_solve_triangular_raises_on_nan_inf(value): ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") ): f(A_tri, b) + + +@pytest.mark.parametrize( + "x, lower, exc", + [ + ( + set_test_value( + pt.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + True, + None, + ), + ( + set_test_value( + pt.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + True, + None, + ), + ( + set_test_value( + pt.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + False, + UserWarning, + ), + ], +) +def test_Cholesky(x, lower, exc): + g = pt.linalg.cholesky(x, lower=lower) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) From a7858e3e45d054edd7e45bb6925b0488429fe800 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 7 Jan 2024 00:30:43 +0100 Subject: [PATCH 2/8] Delete old numba_funcify_Cholesky --- pytensor/link/numba/dispatch/basic.py | 35 --------------------------- 1 file changed, 35 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 244f2f45e5..8ac3f5680d 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -809,41 +809,6 @@ def softplus(x): return softplus -# @numba_funcify.register(Cholesky) -# def numba_funcify_Cholesky(op, node, **kwargs): -# lower = op.lower -# -# out_dtype = node.outputs[0].type.numpy_dtype -# -# if lower: -# inputs_cast = int_to_float_fn(node.inputs, out_dtype) -# -# @numba_njit -# def cholesky(a): -# return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) -# -# else: -# # TODO: Use SciPy's BLAS/LAPACK Cython wrappers. -# -# warnings.warn( -# ( -# "Numba will use object mode to allow the " -# "`lower` argument to `scipy.linalg.cholesky`." -# ), -# UserWarning, -# ) -# -# ret_sig = get_numba_type(node.outputs[0].type) -# -# @numba_njit -# def cholesky(a): -# with numba.objmode(ret=ret_sig): -# ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype) -# return ret -# -# return cholesky - - @numba_funcify.register(Solve) def numba_funcify_Solve(op, node, **kwargs): assume_a = op.assume_a From 27614068a696e7b72314fc99042fb83defda451c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 9 Jan 2024 22:51:46 +0100 Subject: [PATCH 3/8] Refactor tests to include supported keywords and datatypes --- pytensor/link/numba/dispatch/slinalg.py | 22 ++++--- tests/link/numba/test_slinalg.py | 77 ++++++++++--------------- 2 files changed, 43 insertions(+), 56 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 28e797d01d..37e5f592f1 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -25,6 +25,15 @@ _ptr_int = _PTR(_int) +@numba.core.extending.register_jitable +def _check_finite_matrix(a, func_name): + for v in np.nditer(a): + if not np.isfinite(v.item()): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input to " + func_name + ) + + @intrinsic def val_to_dptr(typingctx, data): def impl(context, builder, signature, args): @@ -310,6 +319,9 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): if A.shape[-2] != _N: raise linalg.LinAlgError("Last 2 dimensions of A must be square") + if check_finite: + _check_finite_matrix(A, "cholesky") + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) N = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N) @@ -336,18 +348,12 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): @numba_funcify.register(Cholesky) def numba_funcify_Cholesky(op, node, **kwargs): lower = op.lower - # overwrite_a = op.overwrite_a overwrite_a = False - check_finite = op.check_finite + on_error = op.on_error @numba_basic.numba_njit(inline="always") def nb_cholesky(a): - if check_finite: - if np.any(np.isinf(a)) or np.any(np.isnan(a)): - raise ValueError( - "Non-numeric values (nan or inf) in input to ", op.name - ) - res = _cholesky(a, lower, overwrite_a, check_finite) + res = _cholesky(a, lower, overwrite_a, on_error) return res return nb_cholesky diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 7376c8bb33..ac872d075f 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,4 +1,3 @@ -import contextlib import re import numpy as np @@ -109,52 +108,34 @@ def test_solve_triangular_raises_on_nan_inf(value): f(A_tri, b) -@pytest.mark.parametrize( - "x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - True, - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - None, - ), - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - False, - UserWarning, - ), - ], -) -def test_Cholesky(x, lower, exc): +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +def test_numba_Cholesky(lower): + x = set_test_value( + pt.tensor(dtype=config.floatX, shape=(3, 3)), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)), + ) + g = pt.linalg.cholesky(x, lower=lower) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_numba_Cholesky_raises_on_nan(): + test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value[0, 0] = np.nan + + x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = x.T.dot(x) + g = pt.linalg.cholesky(x, on_error="raise") + f = pytensor.function([x], g, mode="NUMBA") - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) + with pytest.raises(ValueError, match=r"Non-numeric values"): + f(test_value) From 69af4ccb1c20aac2972df9da709ab2f2574db54f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 10 Jan 2024 20:40:44 +0100 Subject: [PATCH 4/8] Validate inputs and outputs of numba cholesky function --- pytensor/link/numba/dispatch/slinalg.py | 30 +++++++++++++++++++++---- pytensor/tensor/slinalg.py | 13 +++++++---- tests/link/numba/test_slinalg.py | 26 +++++++++++++++++---- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 997847144b..ccae3df63a 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -9,7 +9,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import Cholesky, BlockDiagonal, SolveTriangular +from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular _PTR = ctypes.POINTER @@ -292,13 +292,14 @@ def solve_triangular(a, b): res = _solve_triangular(a, b, trans, lower, unit_diagonal) if check_finite: if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): - raise ValueError( + raise np.linalg.LinAlgError( "Non-numeric values (nan or inf) returned by solve_triangular" ) return res return solve_triangular + def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): return linalg.cholesky( a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite @@ -339,7 +340,7 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): INFO, ) - return A_copy + return A_copy, int_ptr_to_val(INFO) return impl @@ -348,15 +349,36 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): def numba_funcify_Cholesky(op, node, **kwargs): lower = op.lower overwrite_a = False + check_finite = op.check_finite on_error = op.on_error @numba_basic.numba_njit(inline="always") def nb_cholesky(a): - res = _cholesky(a, lower, overwrite_a, on_error) + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) found in input to cholesky" + ) + res, info = _cholesky(a, lower, overwrite_a, check_finite) + + if on_error == "raise": + if info > 0: + raise np.linalg.LinAlgError( + "Input to cholesky is not positive definite" + ) + if info < 0: + raise ValueError( + 'LAPACK reported an illegal value in input on entry to "POTRF."' + ) + else: + if info != 0: + res = np.full_like(res, np.nan) + return res return nb_cholesky + @numba_funcify.register(BlockDiagonal) def numba_funcify_BlockDiagonal(op, node, **kwargs): dtype = node.outputs[0].dtype diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index aae80fb578..73d5b5c540 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -51,9 +51,10 @@ class Cholesky(Op): __props__ = ("lower", "destructive", "on_error") gufunc_signature = "(m,m)->(m,m)" - def __init__(self, *, lower=True, on_error="raise"): + def __init__(self, *, lower=True, check_finite=True, on_error="raise"): self.lower = lower self.destructive = False + self.check_finite = check_finite if on_error not in ("raise", "nan"): raise ValueError('on_error must be one of "raise" or ""nan"') self.on_error = on_error @@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs): x = inputs[0] z = outputs[0] try: - z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) + z[0] = scipy.linalg.cholesky( + x, lower=self.lower, check_finite=self.check_finite + ).astype(x.dtype) except scipy.linalg.LinAlgError: if self.on_error == "raise": raise @@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner): return [grad] -def cholesky(x, lower=True, on_error="raise"): - return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) +def cholesky(x, lower=True, on_error="raise", check_finite=True): + return Blockwise( + Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) + )(x) class SolveBase(Op): diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 6384b1a31a..1b8e06bbef 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -102,7 +102,8 @@ def test_solve_triangular_raises_on_nan_inf(value): b = np.full((5, 1), value) with pytest.raises( - ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") + np.linalg.LinAlgError, + match=re.escape("Non-numeric values (nan or inf) returned "), ): f(A_tri, b) @@ -127,19 +128,36 @@ def test_numba_Cholesky(lower): ) -def test_numba_Cholesky_raises_on_nan(): +def test_numba_Cholesky_raises_on_nan_input(): test_value = rng.random(size=(3, 3)).astype(config.floatX) test_value[0, 0] = np.nan x = pt.tensor(dtype=config.floatX, shape=(3, 3)) x = x.T.dot(x) - g = pt.linalg.cholesky(x, on_error="raise") + g = pt.linalg.cholesky(x, check_finite=True) f = pytensor.function([x], g, mode="NUMBA") - with pytest.raises(ValueError, match=r"Non-numeric values"): + with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): f(test_value) +@pytest.mark.parametrize("on_error", ["nan", "raise"]) +def test_numba_Cholesky_raise_on(on_error): + test_value = rng.random(size=(3, 3)).astype(config.floatX) + + x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + g = pt.linalg.cholesky(x, on_error=on_error) + f = pytensor.function([x], g, mode="NUMBA") + + if on_error == "raise": + with pytest.raises( + np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite" + ): + f(test_value) + else: + assert np.all(np.isnan(f(test_value))) + + def test_block_diag(): A = pt.matrix("A") B = pt.matrix("B") From 5aaa1fc66110d893e6dee07d78b78119854f6dcb Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 11 Jan 2024 00:06:08 +0100 Subject: [PATCH 5/8] Raise on complex inputs --- pytensor/link/numba/dispatch/slinalg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index ccae3df63a..05dff8fca1 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -311,6 +311,10 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ensure_lapack() _check_scipy_linalg_matrix(A, "cholesky") dtype = A.dtype + if str(dtype).startswith("complex"): + raise ValueError( + "Complex inputs not currently supported by cholesky in Numba mode" + ) w_type = _get_underlying_float(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype) From a0cbdc1619142f9c950e4da1fa303f99f59d0f54 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 11 Jan 2024 00:25:40 +0100 Subject: [PATCH 6/8] Change `cholesky` default for `check_finite` to `False` --- pytensor/tensor/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 73d5b5c540..963ebe3821 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -132,7 +132,7 @@ def conjugate_solve_triangular(outer, inner): return [grad] -def cholesky(x, lower=True, on_error="raise", check_finite=True): +def cholesky(x, lower=True, on_error="raise", check_finite=False): return Blockwise( Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) )(x) From 8b4b0941fb65d98e6daf91b49bb0044f59163d4d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 14 Jan 2024 13:21:06 +0100 Subject: [PATCH 7/8] Remove redundant dtype checks from numba linalg dispatchers --- pytensor/link/numba/dispatch/slinalg.py | 44 +++++++++++++++---------- tests/link/numba/test_slinalg.py | 2 +- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 05dff8fca1..f950768d54 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -215,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): _check_scipy_linalg_matrix(A, "solve_triangular") _check_scipy_linalg_matrix(B, "solve_triangular") - dtype = A.dtype - if str(dtype).startswith("complex"): - raise ValueError( - "Complex inputs not currently supported by solve_triangular in Numba mode" - ) - w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) @@ -274,8 +268,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): ) if B_is_1d: - return B_copy[..., 0] - return B_copy + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) return impl @@ -287,14 +281,29 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): unit_diagonal = op.unit_diagonal check_finite = op.check_finite + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by solve_triangular in Numba mode" + ) + @numba_basic.numba_njit(inline="always") def solve_triangular(a, b): - res = _solve_triangular(a, b, trans, lower, unit_diagonal) if check_finite: - if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to solve_triangular" + ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): raise np.linalg.LinAlgError( - "Non-numeric values (nan or inf) returned by solve_triangular" + "Non-numeric values (nan or inf) in input b to solve_triangular" ) + + res, info = _solve_triangular(a, b, trans, lower, unit_diagonal) + if info != 0: + raise np.linalg.LinAlgError( + "Singular matrix in input A to solve_triangular" + ) return res return solve_triangular @@ -311,10 +320,6 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): ensure_lapack() _check_scipy_linalg_matrix(A, "cholesky") dtype = A.dtype - if str(dtype).startswith("complex"): - raise ValueError( - "Complex inputs not currently supported by cholesky in Numba mode" - ) w_type = _get_underlying_float(dtype) numba_potrf = _LAPACK().numba_xpotrf(dtype) @@ -323,9 +328,6 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): if A.shape[-2] != _N: raise linalg.LinAlgError("Last 2 dimensions of A must be square") - if check_finite: - _check_finite_matrix(A, "cholesky") - UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) N = val_to_int_ptr(_N) LDA = val_to_int_ptr(_N) @@ -356,6 +358,12 @@ def numba_funcify_Cholesky(op, node, **kwargs): check_finite = op.check_finite on_error = op.on_error + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by cholesky in Numba mode" + ) + @numba_basic.numba_njit(inline="always") def nb_cholesky(a): if check_finite: diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 1b8e06bbef..ada495b79b 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -103,7 +103,7 @@ def test_solve_triangular_raises_on_nan_inf(value): with pytest.raises( np.linalg.LinAlgError, - match=re.escape("Non-numeric values (nan or inf) returned "), + match=re.escape("Non-numeric values"), ): f(A_tri, b) From 614fb1a6d3d0b765a6c197dc14cec8e0e547deb5 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 14 Jan 2024 21:10:08 +0100 Subject: [PATCH 8/8] Add docstring to `numba_funcify_Cholesky` explaining why the overload is necessary. --- pytensor/link/numba/dispatch/slinalg.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index f950768d54..7051fa2117 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -353,6 +353,12 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True): @numba_funcify.register(Cholesky) def numba_funcify_Cholesky(op, node, **kwargs): + """ + Overload scipy.linalg.cholesky with a numba function. + + Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments. + In particular, the `inplace` argument is not supported, which is why we choose to implement our own version. + """ lower = op.lower overwrite_a = False check_finite = op.check_finite