diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 64bb3017f780..6c3c128cff99 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -310,7 +310,23 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]: return lu, pivots, permutation -def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: +@overload +def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True, + ) -> tuple[Array, Array]: + ... + +@overload +def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True, + ) -> tuple[Array, Array, Array]: + ... + +@overload +def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, + ) -> tuple[Array, Array] | tuple[Array, Array, Array]: + ... + +def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, + ) -> tuple[Array, Array] | tuple[Array, Array, Array]: """QR decomposition. Computes the QR decomposition @@ -323,11 +339,15 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: Args: x: A batch of matrices with shape ``[..., m, n]``. + pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, + compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` + is chosen such that the diagonal of ``R`` is non-increasing. Currently + supported on CPU backends only. full_matrices: Determines if full or reduced matrices are returned; see below. Returns: - A pair of arrays ``(q, r)``. + A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``. Array ``q`` is a unitary (orthogonal) matrix, with shape ``[..., m, m]`` if ``full_matrices=True``, or @@ -336,8 +356,12 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]: Array ``r`` is an upper-triangular matrix with shape ``[..., m, n]`` if ``full_matrices=True``, or ``[..., min(m, n), n]`` if ``full_matrices=False``. + + Array ``p`` is an index vector with shape [..., n] """ - q, r = qr_p.bind(x, full_matrices=full_matrices) + q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices) + if pivoting: + return q, r, p[0] return q, r @@ -1846,6 +1870,61 @@ def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str): platform='rocm') +def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]: + """Computes the column-pivoted QR decomposition of a matrix. + + Args: + a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type. + jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type, + Returns: + A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``, + ``q`` is represented in the lower triangle of ``a`` and in ``taus`` as + elementary Householder reflectors, and ``jpvt`` is the column-pivot indices + such that ``a[:, jpvt] = q @ r``. + """ + a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt) + return a_out, jpvt_out, taus + +def _geqp3_abstract_eval(a, jpvt): + if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray): + raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: " + f"{a.aval}, {jpvt.aval}") + if a.ndim < 2: + raise ValueError("Argument to column-pivoted QR decomposition must have ndims >= 2") + *batch_dims, m, n = a.shape + *jpvt_batch_dims, jpvt_n = jpvt.shape + if batch_dims != jpvt_batch_dims or jpvt_n != n: + raise ValueError(f"Type mismatch for pivoted QR decomposition: {a=} {jpvt=}") + taus = a.update(shape=(*batch_dims, core.min_dim(m, n))) + return a, jpvt, taus + +def _geqp3_batching_rule(batched_args, batch_dims): + a, jpvt = batched_args + b_a, b_jpvt = batch_dims + a = batching.moveaxis(a, b_a, 0) + jpvt = batching.moveaxis(jpvt, b_jpvt, 0) + return geqp3(a, jpvt), (0, 0, 0) + +def _geqp3_cpu_lowering(ctx, a, jpvt): + a_aval, jpvt_aval = ctx.avals_in + batch_dims = a_aval.shape[:-2] + nb = len(batch_dims) + layout = [(nb, nb + 1) + tuple(range(nb - 1, -1, -1)), tuple(range(nb, -1, -1))] + result_layouts = layout + [tuple(range(nb, -1, -1))] + target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype) + rule = ffi.ffi_lowering(target_name, operand_layouts=layout, + result_layouts=result_layouts, + operand_output_aliases={0: 0, 1: 1}) + return rule(ctx, a, jpvt) + + +geqp3_p = Primitive('geqp3') +geqp3_p.multiple_results = True +geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p)) +geqp3_p.def_abstract_eval(_geqp3_abstract_eval) +batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule +mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu") + # householder_product: product of elementary Householder reflectors def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: @@ -1938,11 +2017,12 @@ def _householder_product_cpu_gpu_lowering(ctx, a, taus, *, platform='rocm') -def _qr_impl(operand, *, full_matrices): - q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices) - return q, r +def _qr_impl(operand, *, pivoting, full_matrices): + q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting, + full_matrices=full_matrices) + return (q, r, p[0]) if pivoting else (q, r) -def _qr_abstract_eval(operand, *, full_matrices): +def _qr_abstract_eval(operand, *, pivoting, full_matrices): if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") @@ -1950,20 +2030,24 @@ def _qr_abstract_eval(operand, *, full_matrices): k = m if full_matrices else core.min_dim(m, n) q = operand.update(shape=(*batch_dims, m, k)) r = operand.update(shape=(*batch_dims, k, n)) + p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32)) else: q = operand r = operand - return q, r + p = operand + return (q, r, p) if pivoting else (q, r) -def qr_jvp_rule(primals, tangents, *, full_matrices): +def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices): # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. x, = primals dx, = tangents - q, r = qr_p.bind(x, full_matrices=False) + q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False) *_, m, n = x.shape if m < n or (full_matrices and m != n): raise NotImplementedError( "Unimplemented case of QR decomposition derivative") + if pivoting: + dx = dx[..., p[0]] dx_rinv = triangular_solve(r, dx) # Right side solve by default qt_dx_rinv = _H(q) @ dx_rinv qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) @@ -1973,15 +2057,19 @@ def qr_jvp_rule(primals, tangents, *, full_matrices): do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r + if pivoting: + dp = ad_util.Zero.from_primal_value(p[0]) + return (q, r, p[0]), (dq, dr, dp) return (q, r), (dq, dr) -def _qr_batching_rule(batched_args, batch_dims, *, full_matrices): +def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) - return qr_p.bind(x, full_matrices=full_matrices), (0, 0) + out_axes = (0, 0, 0) if pivoting else (0, 0) + return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes -def _qr_lowering(a, *, full_matrices): +def _qr_lowering(a, *, pivoting, full_matrices): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else core.min_dim(m, n) @@ -1989,9 +2077,18 @@ def _qr_lowering(a, *, full_matrices): (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) + if pivoting: + p = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32)) + return q, r, p return q, r - r, taus = geqrf(a) + if pivoting: + jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32)) + r, p, taus = geqp3(a, jpvt) + p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1. + else: + r, taus = geqrf(a) + if m < n: q = householder_product(r[..., :m, :m], taus) elif full_matrices: @@ -2002,6 +2099,8 @@ def _qr_lowering(a, *, full_matrices): q = householder_product(r, taus) r = r[..., :n, :n] r = _triu(r) + if pivoting: + return q, r, p return q, r @@ -2015,7 +2114,6 @@ def _qr_lowering(a, *, full_matrices): mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering)) - # Singular value decomposition def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None, algorithm=None): diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 77ff4297e137..9f22f130cbb2 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -197,7 +197,7 @@ def svd( reduce_to_square = False if full_matrices: - q_full, a_full = lax.linalg.qr(a, full_matrices=True) + q_full, a_full = lax.linalg.qr(a, pivoting=False, full_matrices=True) q = q_full[:, :n] u_out_null = q_full[:, n:] a = a_full[:n, :] @@ -206,7 +206,7 @@ def svd( # The constant `1.15` comes from Yuji Nakatsukasa's implementation # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav if m > 1.15 * n: - q, a = lax.linalg.qr(a, full_matrices=False) + q, a = lax.linalg.qr(a, pivoting=False, full_matrices=False) reduce_to_square = True if not compute_uv: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index ff4e4e07e0e6..e2cd25607c01 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -1288,7 +1288,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: full_matrices = True else: raise ValueError(f"Unsupported QR decomposition mode '{mode}'") - q, r = lax_linalg.qr(a, full_matrices=full_matrices) + q, r = lax_linalg.qr(a, pivoting=False, full_matrices=full_matrices) if mode == "r": return r return QRResult(q, r) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 2e3632700759..a613e301f4fa 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -820,20 +820,39 @@ def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, del overwrite_a, check_finite # unused return _lu(a, permute_l) + +@overload +def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[False] + ) -> tuple[Array]: ... + +@overload +def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[True] + ) -> tuple[Array, Array]: ... + +@overload +def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[False] + ) -> tuple[Array, Array]: ... + +@overload +def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[True] + ) -> tuple[Array, Array, Array]: ... + @overload -def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ... +def _qr(a: ArrayLike, mode: str, pivoting: Literal[False] + ) -> tuple[Array] | tuple[Array, Array]: ... @overload -def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ... +def _qr(a: ArrayLike, mode: str, pivoting: Literal[True] + ) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @overload -def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: ... +def _qr(a: ArrayLike, mode: str, pivoting: bool + ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ... + @partial(jit, static_argnames=('mode', 'pivoting')) -def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: - if pivoting: - raise NotImplementedError( - "The pivoting=True case of qr is not implemented.") +def _qr(a: ArrayLike, mode: str, pivoting: bool + ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: if mode in ("full", "r"): full_matrices = True elif mode == "economic": @@ -841,31 +860,56 @@ def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, else: raise ValueError(f"Unsupported QR decomposition mode '{mode}'") a, = promote_dtypes_inexact(jnp.asarray(a)) - q, r = lax_linalg.qr(a, full_matrices=full_matrices) + q, r, *p = lax_linalg.qr(a, pivoting=pivoting, full_matrices=full_matrices) if mode == "r": + if pivoting: + return r, p[0] return (r,) + if pivoting: + return q, r, p[0] return q, r @overload -def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal["full", "economic"] = "full", - pivoting: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ... +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["full", "economic"], pivoting: Literal[False] = False, + check_finite: bool = True) -> tuple[Array, Array]: ... + +@overload +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["full", "economic"], pivoting: Literal[True] = True, + check_finite: bool = True) -> tuple[Array, Array, Array]: ... @overload -def qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal["r"], - pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ... +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["full", "economic"], pivoting: bool = False, + check_finite: bool = True + ) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... @overload -def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"], - pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ... +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["r"], pivoting: Literal[False] = False, check_finite: bool = True + ) -> tuple[Array]: ... + +@overload +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["r"], pivoting: Literal[True] = True, check_finite: bool = True + ) -> tuple[Array, Array]: ... + +@overload +def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, + mode: Literal["r"], pivoting: bool = False, check_finite: bool = True + ) -> tuple[Array] | tuple[Array, Array]: ... @overload def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", - pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ... + pivoting: bool = False, check_finite: bool = True + ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ... def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", - pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: + pivoting: bool = False, check_finite: bool = True + ) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: """Compute the QR decomposition of an array JAX implementation of :func:`scipy.linalg.qr`. @@ -888,22 +932,29 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " - ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``, where K = min(M, N). - pivoting: Not implemented in JAX. + pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute + the column-pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such + that the diagonal of ``R`` is non-increasing. overwrite_a: unused in JAX lwork: unused in JAX check_finite: unused in JAX Returns: - A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``, - where: + A tuple ``(Q, R)`` or ``(Q, R, P)``, if ``mode`` is not ``"r"`` and ``pivoting`` is + respectively ``False`` or ``True``, otherwise an array ``R`` or tuple ``(R, P)`` if + mode is ``"r"``, and ``pivoting`` is respectively ``False`` or ``True``, where: - ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``) - or ``(..., M, K)`` (if ``mode`` is ``"economic"``). + or ``(..., M, K)`` (if ``mode`` is ``"economic"``), - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is - ``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``) + ``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``), + - ``P`` is an index vector of shape ``(..., N)``. with ``K = min(M, N)``. + Notes: + - At present, pivoting is only implemented on CPU backends. + See also: - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 6fb3a6744160..317cd73134c3 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1522,6 +1522,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "lu_pivots_to_permutation", "xla_pmap", "geqrf", + "geqp3", "householder_product", "hessenberg", "tridiagonal", diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 7924e3980dcc..0d27f3d2f041 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -129,6 +129,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeqrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgeqrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgeqrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgeqrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgeqp3_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgeqp3_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgeqp3_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgeqp3_ffi); JAX_CPU_REGISTER_HANDLER(lapack_sorgqr_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dorgqr_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cungqr_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index a22a8f6d1cd3..c50fc5ff4d86 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -82,6 +82,11 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("cgeqrf")); AssignKernelFn>(lapack_ptr("zgeqrf")); + AssignKernelFn>(lapack_ptr("sgeqp3")); + AssignKernelFn>(lapack_ptr("dgeqp3")); + AssignKernelFn>(lapack_ptr("cgeqp3")); + AssignKernelFn>(lapack_ptr("zgeqp3")); + AssignKernelFn>(lapack_ptr("sorgqr")); AssignKernelFn>(lapack_ptr("dorgqr")); AssignKernelFn>>(lapack_ptr("cungqr")); @@ -246,6 +251,10 @@ nb::dict Registrations() { dict["lapack_dgeqrf_ffi"] = EncapsulateFunction(lapack_dgeqrf_ffi); dict["lapack_cgeqrf_ffi"] = EncapsulateFunction(lapack_cgeqrf_ffi); dict["lapack_zgeqrf_ffi"] = EncapsulateFunction(lapack_zgeqrf_ffi); + dict["lapack_sgeqp3_ffi"] = EncapsulateFunction(lapack_sgeqp3_ffi); + dict["lapack_dgeqp3_ffi"] = EncapsulateFunction(lapack_dgeqp3_ffi); + dict["lapack_cgeqp3_ffi"] = EncapsulateFunction(lapack_cgeqp3_ffi); + dict["lapack_zgeqp3_ffi"] = EncapsulateFunction(lapack_zgeqp3_ffi); dict["lapack_sorgqr_ffi"] = EncapsulateFunction(lapack_sorgqr_ffi); dict["lapack_dorgqr_ffi"] = EncapsulateFunction(lapack_dorgqr_ffi); dict["lapack_cungqr_ffi"] = EncapsulateFunction(lapack_cungqr_ffi); diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 31c180581946..b0e2935a88e8 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -354,6 +354,76 @@ template struct QrFactorization; template struct QrFactorization; template struct QrFactorization; +//== Column Pivoting QR Factorization ==// + +// lapack geqp3 +template +ffi::Error PivotingQrFactorization::Kernel( + ffi::Buffer x, ffi::Buffer jpvt, + ffi::ResultBuffer x_out, ffi::ResultBuffer jpvt_out, + ffi::ResultBuffer tau) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + auto* x_out_data = x_out->typed_data(); + auto* jpvt_out_data = jpvt_out->typed_data(); + auto* tau_data = tau->typed_data(); + lapack_int info; + const int64_t work_size = GetWorkspaceSize(x_rows, x_cols); + auto work_data = AllocateScratchMemory(work_size); + constexpr bool is_complex_dtype = ffi::IsComplexType(); + std::unique_ptr rwork_data; + if constexpr (is_complex_dtype) { + rwork_data = AllocateScratchMemory(2 * x_cols); + } + + CopyIfDiffBuffer(x, x_out); + CopyIfDiffBuffer(jpvt, jpvt_out); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, + MaybeCastNoOverflow(work_size)); + FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + auto x_leading_dim_v = x_rows_v; + + const int64_t x_out_step{x_rows * x_cols}; + const int64_t jpvt_step{x_cols}; + const int64_t tau_step{std::min(x_rows, x_cols)}; + for (int64_t i = 0; i < batch_count; ++i) { + if constexpr (is_complex_dtype) { + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, jpvt_out_data, + tau_data, work_data.get(), &workspace_dim_v, rwork_data.get(), &info); + } else { + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, jpvt_out_data, + tau_data, work_data.get(), &workspace_dim_v, &info); + } + x_out_data += x_out_step; + jpvt_out_data += jpvt_step; + tau_data += tau_step; + } + return ffi::Error::Success(); +} + +template +int64_t PivotingQrFactorization::GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols) { + ValueType optimal_size{}; + lapack_int x_leading_dim_v = x_rows; + lapack_int info = 0; + lapack_int workspace_query = -1; + if constexpr (ffi::IsComplexType()) { + fn(&x_rows, &x_cols, nullptr, &x_leading_dim_v, nullptr, nullptr, + &optimal_size, &workspace_query, nullptr, &info); + } else { + fn(&x_rows, &x_cols, nullptr, &x_leading_dim_v, nullptr, nullptr, + &optimal_size, &workspace_query, &info); + } + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +} + +template struct PivotingQrFactorization; +template struct PivotingQrFactorization; +template struct PivotingQrFactorization; +template struct PivotingQrFactorization; + //== Orthogonal QR ==// //== Computes orthogonal matrix Q from QR Decomposition ==// @@ -2012,6 +2082,16 @@ template struct TridiagonalReduction; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*tau*/)) +#define JAX_CPU_DEFINE_GEQP3(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, PivotingQrFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*jpvt*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*jpvt_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/)) + #define JAX_CPU_DEFINE_ORGQR(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ name, OrthogonalQr::Kernel, \ @@ -2172,6 +2252,11 @@ JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_GEQP3(lapack_sgeqp3_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEQP3(lapack_dgeqp3_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEQP3(lapack_cgeqp3_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEQP3(lapack_zgeqp3_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_ORGQR(lapack_sorgqr_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_ORGQR(lapack_dorgqr_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_ORGQR(lapack_cungqr_ffi, ::xla::ffi::DataType::C64); @@ -2215,6 +2300,7 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_TRSM #undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_GEQRF +#undef JAX_CPU_DEFINE_GEQP3 #undef JAX_CPU_DEFINE_ORGQR #undef JAX_CPU_DEFINE_POTRF #undef JAX_CPU_DEFINE_GESDD diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 8c24ddfdd9f8..6194cbde5b34 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -213,6 +213,34 @@ struct QrFactorization { static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; +//== Column Pivoting QR Factorization ==// + +// lapack geqp3 +template <::xla::ffi::DataType dtype> +struct PivotingQrFactorization { + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using ValueType = ::xla::ffi::NativeType; + using FnType = std::conditional_t< + ::xla::ffi::IsComplexType(), + void(lapack_int* m, lapack_int* n, ValueType* a, lapack_int* lda, + lapack_int* jpvt, ValueType* tau, ValueType* work, lapack_int* lwork, + RealType* rwork, lapack_int* info), + void(lapack_int* m, lapack_int* n, ValueType* a, lapack_int* lda, + lapack_int* jpvt, ValueType* tau, ValueType* work, lapack_int* lwork, + lapack_int* info)>; + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::Buffer jpvt, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer jpvt_out, + ::xla::ffi::ResultBuffer tau); + + static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); +}; + + //== Orthogonal QR ==// // lapack orgqr @@ -466,9 +494,9 @@ struct EigenvalueDecompositionHermitian { // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. -template -static void UnpackEigenvectors(Int n, const T* eigenvals_imag, - const T* packed, std::complex* unpacked) { +template +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, const T* packed, + std::complex* unpacked) { for (int j = 0; j < n;) { if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { // Real values in each row without imaginary part @@ -753,6 +781,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeqrf_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeqrf_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeqrf_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeqrf_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeqp3_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeqp3_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeqp3_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeqp3_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sorgqr_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dorgqr_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cungqr_ffi); diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index ad64069a2499..af4947b12e9c 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -41,6 +41,11 @@ jax::QrFactorization::FnType dgeqrf_; jax::QrFactorization::FnType cgeqrf_; jax::QrFactorization::FnType zgeqrf_; +jax::PivotingQrFactorization::FnType sgeqp3_; +jax::PivotingQrFactorization::FnType dgeqp3_; +jax::PivotingQrFactorization::FnType cgeqp3_; +jax::PivotingQrFactorization::FnType zgeqp3_; + jax::OrthogonalQr::FnType sorgqr_; jax::OrthogonalQr::FnType dorgqr_; jax::OrthogonalQr::FnType cungqr_; @@ -335,6 +340,11 @@ static auto init = []() -> int { AssignKernelFn>(cgeqrf_); AssignKernelFn>(zgeqrf_); + AssignKernelFn>(sgeqp3_); + AssignKernelFn>(dgeqp3_); + AssignKernelFn>(cgeqp3_); + AssignKernelFn>(zgeqp3_); + AssignKernelFn>(sorgqr_); AssignKernelFn>(dorgqr_); AssignKernelFn>(cungqr_); diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 65f7c8145138..b1b68ecc8662 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1712,17 +1712,36 @@ def osp_fun(A): @jtu.sample_product( # Skip empty shapes because scipy fails: https://github.com/scipy/scipy/issues/1532 shape=[(3, 4), (3, 3), (4, 3)], - dtype=[np.float32], + dtype=float_types + complex_types, mode=["full", "r", "economic"], + pivoting=[False, True] ) - def testScipyQrModes(self, shape, dtype, mode): + def testScipyQrModes(self, shape, dtype, mode, pivoting): + is_not_cpu_test_device = not jtu.test_device_matches(["cpu"]) + is_not_valid_jaxlib_version = jtu.jaxlib_version() <= (0, 4, 38) + if pivoting and (is_not_cpu_test_device or is_not_valid_jaxlib_version): + self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38") rng = jtu.rand_default(self.rng()) - jsp_func = partial(jax.scipy.linalg.qr, mode=mode) - sp_func = partial(scipy.linalg.qr, mode=mode) + jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting) + sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5) self._CompileAndCheck(jsp_func, args_maker) + # Pivoting is unsupported by the numpy api - repeat the jvp checks performed + # in NumpyLinalgTest::testQR for the `pivoting=True` modes here. Like in the + # numpy test, `qr_and_mul` expresses the identity function. + def qr_and_mul(a): + q, r, *p = jsp_func(a) + # To express the identity function we must "undo" the pivoting of `q @ r`. + inverted_pivots = p[0][p[0]] + return (q @ r)[:, inverted_pivots] + + m, n = shape + if pivoting and mode != "r" and (m == n or (m > n and mode != "full")): + for a in args_maker(): + jtu.check_jvp(qr_and_mul, partial(jvp, qr_and_mul), (a,), atol=3e-3) + @jtu.sample_product( [dict(shape=shape, k=k) for shape in [(1, 1), (3, 4, 4), (10, 5)]