diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 9c9c800b92..8ac3f5680d 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -37,7 +37,7 @@ from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Cholesky, Solve +from pytensor.tensor.slinalg import Solve from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -809,41 +809,6 @@ def softplus(x): return softplus -@numba_funcify.register(Cholesky) -def numba_funcify_Cholesky(op, node, **kwargs): - lower = op.lower - - out_dtype = node.outputs[0].type.numpy_dtype - - if lower: - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def cholesky(a): - return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) - - else: - # TODO: Use SciPy's BLAS/LAPACK Cython wrappers. - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`lower` argument to `scipy.linalg.cholesky`." - ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def cholesky(a): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype) - return ret - - return cholesky - - @numba_funcify.register(Solve) def numba_funcify_Solve(op, node, **kwargs): assume_a = op.assume_a diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index a5ac0c6348..7051fa2117 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -9,7 +9,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import numba_funcify -from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular +from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular _PTR = ctypes.POINTER @@ -25,6 +25,15 @@ _ptr_int = _PTR(_int) +@numba.core.extending.register_jitable +def _check_finite_matrix(a, func_name): + for v in np.nditer(a): + if not np.isfinite(v.item()): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input to " + func_name + ) + + @intrinsic def val_to_dptr(typingctx, data): def impl(context, builder, signature, args): @@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype): return functype(lapack_ptr) + @classmethod + def numba_xpotrf(cls, dtype): + """ + Called by scipy.linalg.cholesky + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO, + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # INFO + ) + return functype(lapack_ptr) + def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): return linalg.solve_triangular( @@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): _check_scipy_linalg_matrix(A, "solve_triangular") _check_scipy_linalg_matrix(B, "solve_triangular") - dtype = A.dtype - if str(dtype).startswith("complex"): - raise ValueError( - "Complex inputs not currently supported by solve_triangular in Numba mode" - ) - w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) @@ -249,8 +268,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): ) if B_is_1d: - return B_copy[..., 0] - return B_copy + return B_copy[..., 0], int_ptr_to_val(INFO) + return B_copy, int_ptr_to_val(INFO) return impl @@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): unit_diagonal = op.unit_diagonal check_finite = op.check_finite + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by solve_triangular in Numba mode" + ) + @numba_basic.numba_njit(inline="always") def solve_triangular(a, b): - res = _solve_triangular(a, b, trans, lower, unit_diagonal) if check_finite: - if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): - raise ValueError( - "Non-numeric values (nan or inf) returned by solve_triangular" + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input A to solve_triangular" ) + if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) in input b to solve_triangular" + ) + + res, info = _solve_triangular(a, b, trans, lower, unit_diagonal) + if info != 0: + raise np.linalg.LinAlgError( + "Singular matrix in input A to solve_triangular" + ) return res return solve_triangular +def _cholesky(a, lower=False, overwrite_a=False, check_finite=True): + return linalg.cholesky( + a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite + ) + + +@overload(_cholesky) +def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True): + ensure_lapack() + _check_scipy_linalg_matrix(A, "cholesky") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_potrf = _LAPACK().numba_xpotrf(dtype) + + def impl(A, lower=0, overwrite_a=False, check_finite=True): + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + if not overwrite_a: + A_copy = _copy_to_fortran_order(A) + else: + A_copy = A + + numba_potrf( + UPLO, + N, + A_copy.view(w_type).ctypes, + LDA, + INFO, + ) + + return A_copy, int_ptr_to_val(INFO) + + return impl + + +@numba_funcify.register(Cholesky) +def numba_funcify_Cholesky(op, node, **kwargs): + """ + Overload scipy.linalg.cholesky with a numba function. + + Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments. + In particular, the `inplace` argument is not supported, which is why we choose to implement our own version. + """ + lower = op.lower + overwrite_a = False + check_finite = op.check_finite + on_error = op.on_error + + dtype = node.inputs[0].dtype + if str(dtype).startswith("complex"): + raise NotImplementedError( + "Complex inputs not currently supported by cholesky in Numba mode" + ) + + @numba_basic.numba_njit(inline="always") + def nb_cholesky(a): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) found in input to cholesky" + ) + res, info = _cholesky(a, lower, overwrite_a, check_finite) + + if on_error == "raise": + if info > 0: + raise np.linalg.LinAlgError( + "Input to cholesky is not positive definite" + ) + if info < 0: + raise ValueError( + 'LAPACK reported an illegal value in input on entry to "POTRF."' + ) + else: + if info != 0: + res = np.full_like(res, np.nan) + + return res + + return nb_cholesky + + @numba_funcify.register(BlockDiagonal) def numba_funcify_BlockDiagonal(op, node, **kwargs): dtype = node.outputs[0].dtype diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index aae80fb578..963ebe3821 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -51,9 +51,10 @@ class Cholesky(Op): __props__ = ("lower", "destructive", "on_error") gufunc_signature = "(m,m)->(m,m)" - def __init__(self, *, lower=True, on_error="raise"): + def __init__(self, *, lower=True, check_finite=True, on_error="raise"): self.lower = lower self.destructive = False + self.check_finite = check_finite if on_error not in ("raise", "nan"): raise ValueError('on_error must be one of "raise" or ""nan"') self.on_error = on_error @@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs): x = inputs[0] z = outputs[0] try: - z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) + z[0] = scipy.linalg.cholesky( + x, lower=self.lower, check_finite=self.check_finite + ).astype(x.dtype) except scipy.linalg.LinAlgError: if self.on_error == "raise": raise @@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner): return [grad] -def cholesky(x, lower=True, on_error="raise"): - return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) +def cholesky(x, lower=True, on_error="raise", check_finite=False): + return Blockwise( + Cholesky(lower=lower, on_error=on_error, check_finite=check_finite) + )(x) class SolveBase(Op): diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 4732a8f3d0..5e25bbf53d 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -14,57 +14,6 @@ rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "x, lower, exc", - [ - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - True, - None, - ), - ( - set_test_value( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - None, - ), - ( - set_test_value( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - False, - UserWarning, - ), - ], -) -def test_Cholesky(x, lower, exc): - g = slinalg.Cholesky(lower=lower)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - @pytest.mark.parametrize( "A, x, lower, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 33ec1a529c..ada495b79b 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -6,7 +6,10 @@ import pytensor import pytensor.tensor as pt from pytensor import config +from pytensor.compile import SharedVariable +from pytensor.graph import Constant, FunctionGraph from tests.link.numba.test_basic import compare_numba_and_py +from tests.tensor.test_extra_ops import set_test_value numba = pytest.importorskip("numba") @@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value): b = np.full((5, 1), value) with pytest.raises( - ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") + np.linalg.LinAlgError, + match=re.escape("Non-numeric values"), ): f(A_tri, b) +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +def test_numba_Cholesky(lower): + x = set_test_value( + pt.tensor(dtype=config.floatX, shape=(3, 3)), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)), + ) + + g = pt.linalg.cholesky(x, lower=lower) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_numba_Cholesky_raises_on_nan_input(): + test_value = rng.random(size=(3, 3)).astype(config.floatX) + test_value[0, 0] = np.nan + + x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + x = x.T.dot(x) + g = pt.linalg.cholesky(x, check_finite=True) + f = pytensor.function([x], g, mode="NUMBA") + + with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"): + f(test_value) + + +@pytest.mark.parametrize("on_error", ["nan", "raise"]) +def test_numba_Cholesky_raise_on(on_error): + test_value = rng.random(size=(3, 3)).astype(config.floatX) + + x = pt.tensor(dtype=config.floatX, shape=(3, 3)) + g = pt.linalg.cholesky(x, on_error=on_error) + f = pytensor.function([x], g, mode="NUMBA") + + if on_error == "raise": + with pytest.raises( + np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite" + ): + f(test_value) + else: + assert np.all(np.isnan(f(test_value))) + + def test_block_diag(): A = pt.matrix("A") B = pt.matrix("B")