Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement numba overload for POTRF, LAPACK cholesky routine #578

Merged
merged 9 commits into from
Jan 15, 2024
37 changes: 1 addition & 36 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -809,41 +809,6 @@ def softplus(x):
return softplus


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
    """Build the Numba implementation used for a ``Cholesky`` node.

    When ``op.lower`` is true the factorization is delegated to
    ``np.linalg.cholesky``; the input is first passed through the callable
    returned by ``int_to_float_fn`` and the result is cast to the node's
    output dtype.  Otherwise a ``UserWarning`` is emitted and
    ``scipy.linalg.cholesky`` is called from Numba object mode.
    """
    lower = op.lower
    out_dtype = node.outputs[0].type.numpy_dtype

    if not lower:
        # TODO: Use SciPy's BLAS/LAPACK Cython wrappers.

        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`lower` argument to `scipy.linalg.cholesky`."
            ),
            UserWarning,
        )

        # Object-mode blocks need the type of the value they hand back.
        ret_sig = get_numba_type(node.outputs[0].type)

        @numba_njit
        def cholesky(a):
            with numba.objmode(ret=ret_sig):
                ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
            return ret

        return cholesky

    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba_njit
    def cholesky(a):
        return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

    return cholesky


@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
Expand Down
112 changes: 110 additions & 2 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular


_PTR = ctypes.POINTER
Expand All @@ -25,6 +25,15 @@
_ptr_int = _PTR(_int)


@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
    """Raise ``np.linalg.LinAlgError`` if ``a`` holds a NaN or infinite entry.

    ``func_name`` is appended to the error message so the caller can be
    identified.  Nothing is returned when every element is finite.
    """
    for element in np.nditer(a):
        if np.isfinite(element.item()):
            continue
        raise np.linalg.LinAlgError(
            "Non-numeric values (nan or inf) in input to " + func_name
        )


@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
Expand Down Expand Up @@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype):

return functype(lapack_ptr)

@classmethod
def numba_xpotrf(cls, dtype):
    """
    Return a ctypes callable for the LAPACK ``potrf`` routine matching ``dtype``.

    Called by scipy.linalg.cholesky.  Every argument is passed by pointer and
    the routine itself returns nothing; its status is written to INFO.
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
    argument_types = (
        _ptr_int,  # UPLO
        _ptr_int,  # N
        float_pointer,  # A
        _ptr_int,  # LDA
        _ptr_int,  # INFO
    )
    prototype = ctypes.CFUNCTYPE(None, *argument_types)
    return prototype(lapack_ptr)


def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
Expand Down Expand Up @@ -267,14 +292,97 @@ def solve_triangular(a, b):
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite:
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
raise ValueError(
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) returned by solve_triangular"
)
return res

return solve_triangular


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
)


@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
    """Numba overload of ``_cholesky`` that calls LAPACK ``potrf`` directly.

    Complex dtypes are rejected with a ``ValueError`` while the implementation
    is being built.  The returned ``impl`` yields a tuple
    ``(factor, info)`` where ``info`` is the INFO value reported by LAPACK;
    interpreting a non-zero ``info`` is left to the caller.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "cholesky")
    dtype = A.dtype
    if str(dtype).startswith("complex"):
        raise ValueError(
            "Complex inputs not currently supported by cholesky in Numba mode"
        )
    w_type = _get_underlying_float(dtype)
    numba_potrf = _LAPACK().numba_xpotrf(dtype)

    def impl(A, lower=0, overwrite_a=False, check_finite=True):
        _N = np.int32(A.shape[-1])
        if A.shape[-2] != _N:
            raise linalg.LinAlgError("Last 2 dimensions of A must be square")

        if check_finite:
            _check_finite_matrix(A, "cholesky")

        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        # LAPACK expects column-major storage.  Only factorize in place when
        # the caller allowed it AND the buffer is already Fortran-ordered;
        # otherwise LAPACK would see the transpose and write the factor into
        # the wrong triangle of the caller's array.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        numba_potrf(
            UPLO,
            N,
            A_copy.view(w_type).ctypes,
            LDA,
            INFO,
        )

        # potrf neither reads nor writes the opposite strict triangle, so it
        # still holds the input values.  Zero it so the result is a proper
        # triangular factor, as scipy.linalg.cholesky returns.
        # NOTE(review): assumes A is a single 2-D matrix - confirm against
        # _check_scipy_linalg_matrix.
        if lower:
            for j in range(1, _N):
                for i in range(j):
                    A_copy[i, j] = 0.0
        else:
            for j in range(_N):
                for i in range(j + 1, _N):
                    A_copy[i, j] = 0.0

        return A_copy, int_ptr_to_val(INFO)

    return impl


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
    """Build the Numba implementation used for a ``Cholesky`` node.

    The generated function optionally validates that the input is finite,
    calls ``_cholesky`` and then handles the LAPACK status according to
    ``op.on_error``:

    * ``"raise"``: a positive status raises ``np.linalg.LinAlgError`` and a
      negative status raises ``ValueError``.
    * otherwise: any non-zero status yields an array of NaN with the shape
      and dtype of the result.
    """
    lower = op.lower
    overwrite_a = False
    check_finite = op.check_finite
    on_error = op.on_error

    @numba_basic.numba_njit(inline="always")
    def nb_cholesky(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to cholesky"
                )
        # Finiteness was already verified above when requested, so the inner
        # check is disabled to avoid scanning the matrix a second time.
        res, info = _cholesky(a, lower, overwrite_a, False)

        if on_error == "raise":
            if info > 0:
                raise np.linalg.LinAlgError(
                    "Input to cholesky is not positive definite"
                )
            if info < 0:
                raise ValueError(
                    'LAPACK reported an illegal value in input on entry to "POTRF."'
                )
        else:
            if info != 0:
                res = np.full_like(res, np.nan)

        return res

    return nb_cholesky


@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
Expand Down
13 changes: 9 additions & 4 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"

def __init__(self, *, lower=True, on_error="raise"):
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
self.lower = lower
self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
Expand All @@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
x = inputs[0]
z = outputs[0]
try:
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
except scipy.linalg.LinAlgError:
if self.on_error == "raise":
raise
Expand Down Expand Up @@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
return [grad]


def cholesky(x, lower=True, on_error="raise"):
    """Apply a ``Blockwise``-wrapped ``Cholesky`` op to ``x``."""
    core_op = Cholesky(lower=lower, on_error=on_error)
    return Blockwise(core_op)(x)
def cholesky(x, lower=True, on_error="raise", check_finite=True):
    """Apply a ``Blockwise``-wrapped ``Cholesky`` op to ``x``.

    ``lower``, ``on_error`` and ``check_finite`` are forwarded unchanged to
    the ``Cholesky`` constructor.
    """
    core_op = Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
    return Blockwise(core_op)(x)


class SolveBase(Op):
Expand Down
51 changes: 0 additions & 51 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,6 @@
rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
    "x, lower, exc",
    [
        # float64 input, lower factor: no warning expected
        (
            set_test_value(
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            True,
            None,
        ),
        # int64 input, lower factor: no warning expected
        (
            set_test_value(
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            True,
            None,
        ),
        # float64 input, upper factor: a UserWarning is expected
        (
            set_test_value(
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            False,
            UserWarning,
        ),
    ],
)
def test_Cholesky(x, lower, exc):
    """Compare the Numba and Python results of ``slinalg.Cholesky``."""
    g = slinalg.Cholesky(lower=lower)(x)

    graph_outputs = g if isinstance(g, list) else [g]
    g_fg = FunctionGraph(outputs=graph_outputs)

    if exc is None:
        cm = contextlib.suppress()
    else:
        cm = pytest.warns(exc)

    with cm:
        compare_numba_and_py(
            g_fg,
            [
                inp.tag.test_value
                for inp in g_fg.inputs
                if not isinstance(inp, (SharedVariable, Constant))
            ],
        )


@pytest.mark.parametrize(
"A, x, lower, exc",
[
Expand Down
56 changes: 55 additions & 1 deletion tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value


numba = pytest.importorskip("numba")
Expand Down Expand Up @@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value):
b = np.full((5, 1), value)

with pytest.raises(
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
np.linalg.LinAlgError,
match=re.escape("Non-numeric values (nan or inf) returned "),
):
f(A_tri, b)


@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower):
    """Compare the Numba and Python results of ``pt.linalg.cholesky``."""
    base = rng.random(size=(3, 3)).astype(config.floatX)
    x = set_test_value(
        pt.tensor(dtype=config.floatX, shape=(3, 3)),
        base.T.dot(base),
    )

    g_fg = FunctionGraph(outputs=[pt.linalg.cholesky(x, lower=lower)])

    test_inputs = [
        inp.tag.test_value
        for inp in g_fg.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(g_fg, test_inputs)


def test_numba_Cholesky_raises_on_nan_input():
    """A NaN in the input must raise ``LinAlgError`` when ``check_finite=True``."""
    test_value = rng.random(size=(3, 3)).astype(config.floatX)
    test_value[0, 0] = np.nan

    # The matrix is fed straight into cholesky.  The previous `x = x.T.dot(x)`
    # rebinding made the dot output itself the function input, so the product
    # was never computed; it has been dropped to keep the graph honest.
    x = pt.tensor(dtype=config.floatX, shape=(3, 3))
    g = pt.linalg.cholesky(x, check_finite=True)
    f = pytensor.function([x], g, mode="NUMBA")

    with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
        f(test_value)


@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
    """Check both ``on_error`` modes on an input that is not positive definite."""
    test_value = rng.random(size=(3, 3)).astype(config.floatX)
    # A random matrix is not guaranteed to fail the factorization.  A negative
    # leading diagonal entry makes the very first pivot fail, so the input is
    # non-positive-definite regardless of the seed or test ordering.
    test_value[0, 0] = -1.0

    x = pt.tensor(dtype=config.floatX, shape=(3, 3))
    g = pt.linalg.cholesky(x, on_error=on_error)
    f = pytensor.function([x], g, mode="NUMBA")

    if on_error == "raise":
        with pytest.raises(
            np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
        ):
            f(test_value)
    else:
        assert np.all(np.isnan(f(test_value)))


def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
Expand Down
Loading