Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement column-pivoted QR via geqp3 (CPU lowering only) #20282

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 113 additions & 15 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,23 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
return lu, pivots, permutation


def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
@overload
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
) -> tuple[Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
) -> tuple[Array, Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
...

def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.

Computes the QR decomposition
Expand All @@ -323,11 +339,15 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:

Args:
x: A batch of matrices with shape ``[..., m, n]``.
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
is chosen such that the diagonal of ``R`` is non-increasing. Currently
supported on CPU backends only.
full_matrices: Determines if full or reduced matrices are returned; see
below.

Returns:
A pair of arrays ``(q, r)``.
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.

Array ``q`` is a unitary (orthogonal) matrix,
with shape ``[..., m, m]`` if ``full_matrices=True``, or
Expand All @@ -336,8 +356,12 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
Array ``r`` is an upper-triangular matrix with shape ``[..., m, n]`` if
``full_matrices=True``, or ``[..., min(m, n), n]`` if
``full_matrices=False``.

Array ``p`` is an index vector with shape [..., n]
"""
q, r = qr_p.bind(x, full_matrices=full_matrices)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
if pivoting:
return q, r, p[0]
return q, r


Expand Down Expand Up @@ -1846,6 +1870,61 @@ def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str):
platform='rocm')


def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
"""Computes the column-pivoted QR decomposition of a matrix.

Args:
a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type.
jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type,
Returns:
A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``,
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors, and ``jpvt`` is the column-pivot indices
such that ``a[:, jpvt] = q @ r``.
"""
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt)
return a_out, jpvt_out, taus

def _geqp3_abstract_eval(a, jpvt):
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
f"{a.aval}, {jpvt.aval}")
if a.ndim < 2:
raise ValueError("Argument to column-pivoted QR decomposition must have ndims >= 2")
*batch_dims, m, n = a.shape
*jpvt_batch_dims, jpvt_n = jpvt.shape
if batch_dims != jpvt_batch_dims or jpvt_n != n:
raise ValueError(f"Type mismatch for pivoted QR decomposition: {a=} {jpvt=}")
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
return a, jpvt, taus

def _geqp3_batching_rule(batched_args, batch_dims):
a, jpvt = batched_args
b_a, b_jpvt = batch_dims
a = batching.moveaxis(a, b_a, 0)
jpvt = batching.moveaxis(jpvt, b_jpvt, 0)
return geqp3(a, jpvt), (0, 0, 0)

def _geqp3_cpu_lowering(ctx, a, jpvt):
a_aval, jpvt_aval = ctx.avals_in
batch_dims = a_aval.shape[:-2]
nb = len(batch_dims)
layout = [(nb, nb + 1) + tuple(range(nb - 1, -1, -1)), tuple(range(nb, -1, -1))]
result_layouts = layout + [tuple(range(nb, -1, -1))]
target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
rule = ffi.ffi_lowering(target_name, operand_layouts=layout,
result_layouts=result_layouts,
operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt)


geqp3_p = Primitive('geqp3')
geqp3_p.multiple_results = True
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")

# householder_product: product of elementary Householder reflectors

def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
Expand Down Expand Up @@ -1938,32 +2017,37 @@ def _householder_product_cpu_gpu_lowering(ctx, a, taus, *,
platform='rocm')


def _qr_impl(operand, *, full_matrices):
q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
def _qr_impl(operand, *, pivoting, full_matrices):
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
full_matrices=full_matrices)
return (q, r, p[0]) if pivoting else (q, r)

def _qr_abstract_eval(operand, *, full_matrices):
def _qr_abstract_eval(operand, *, pivoting, full_matrices):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
k = m if full_matrices else core.min_dim(m, n)
q = operand.update(shape=(*batch_dims, m, k))
r = operand.update(shape=(*batch_dims, k, n))
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
else:
q = operand
r = operand
return q, r
p = operand
return (q, r, p) if pivoting else (q, r)

def qr_jvp_rule(primals, tangents, *, full_matrices):
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
q, r = qr_p.bind(x, full_matrices=False)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
*_, m, n = x.shape
if m < n or (full_matrices and m != n):
raise NotImplementedError(
"Unimplemented case of QR decomposition derivative")
if pivoting:
dx = dx[..., p[0]]
dx_rinv = triangular_solve(r, dx) # Right side solve by default
qt_dx_rinv = _H(q) @ dx_rinv
qt_dx_rinv_lower = _tril(qt_dx_rinv, -1)
Expand All @@ -1973,25 +2057,38 @@ def qr_jvp_rule(primals, tangents, *, full_matrices):
do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype))
dq = q @ (do - qt_dx_rinv) + dx_rinv
dr = (qt_dx_rinv - do) @ r
if pivoting:
dp = ad_util.Zero.from_primal_value(p[0])
return (q, r, p[0]), (dq, dr, dp)
return (q, r), (dq, dr)

def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
out_axes = (0, 0, 0) if pivoting else (0, 0)
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes

def _qr_lowering(a, *, full_matrices):
def _qr_lowering(a, *, pivoting, full_matrices):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else core.min_dim(m, n)
q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)),
(*batch_dims, m, k),
(len(batch_dims), len(batch_dims) + 1))
r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype)
if pivoting:
p = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
return q, r, p
return q, r

r, taus = geqrf(a)
if pivoting:
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
r, p, taus = geqp3(a, jpvt)
p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1.
else:
r, taus = geqrf(a)

if m < n:
q = householder_product(r[..., :m, :m], taus)
elif full_matrices:
Expand All @@ -2002,6 +2099,8 @@ def _qr_lowering(a, *, full_matrices):
q = householder_product(r, taus)
r = r[..., :n, :n]
r = _triu(r)
if pivoting:
return q, r, p
return q, r


Expand All @@ -2015,7 +2114,6 @@ def _qr_lowering(a, *, full_matrices):

mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))


# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None,
algorithm=None):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def svd(

reduce_to_square = False
if full_matrices:
q_full, a_full = lax.linalg.qr(a, full_matrices=True)
q_full, a_full = lax.linalg.qr(a, pivoting=False, full_matrices=True)
q = q_full[:, :n]
u_out_null = q_full[:, n:]
a = a_full[:n, :]
Expand All @@ -206,7 +206,7 @@ def svd(
# The constant `1.15` comes from Yuji Nakatsukasa's implementation
# https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
if m > 1.15 * n:
q, a = lax.linalg.qr(a, full_matrices=False)
q, a = lax.linalg.qr(a, pivoting=False, full_matrices=False)
reduce_to_square = True

if not compute_uv:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
full_matrices = True
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
q, r = lax_linalg.qr(a, pivoting=False, full_matrices=full_matrices)
if mode == "r":
return r
return QRResult(q, r)
Expand Down
93 changes: 72 additions & 21 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,52 +820,96 @@ def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
del overwrite_a, check_finite # unused
return _lu(a, permute_l)


@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[False]
) -> tuple[Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[True]
) -> tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[False]
) -> tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[True]
) -> tuple[Array, Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: Literal[False]
) -> tuple[Array] | tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: Literal[True]
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...


@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]:
if pivoting:
raise NotImplementedError(
"The pivoting=True case of qr is not implemented.")
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
if mode in ("full", "r"):
full_matrices = True
elif mode == "economic":
full_matrices = False
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
a, = promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
q, r, *p = lax_linalg.qr(a, pivoting=pivoting, full_matrices=full_matrices)
if mode == "r":
if pivoting:
return r, p[0]
return (r,)
if pivoting:
return q, r, p[0]
return q, r


@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal["full", "economic"] = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[False] = False,
check_finite: bool = True) -> tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[True] = True,
check_finite: bool = True) -> tuple[Array, Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: bool = False,
check_finite: bool = True
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[False] = False, check_finite: bool = True
) -> tuple[Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[True] = True, check_finite: bool = True
) -> tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ...
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...


def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]:
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
"""Compute the QR decomposition of an array

JAX implementation of :func:`scipy.linalg.qr`.
Expand All @@ -888,22 +932,29 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
- ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``,
where K = min(M, N).

pivoting: Not implemented in JAX.
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute
the column-pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such
that the diagonal of ``R`` is non-increasing.
overwrite_a: unused in JAX
lwork: unused in JAX
check_finite: unused in JAX

Returns:
A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``,
where:
A tuple ``(Q, R)`` or ``(Q, R, P)``, if ``mode`` is not ``"r"`` and ``pivoting`` is
respectively ``False`` or ``True``, otherwise an array ``R`` or tuple ``(R, P)`` if
mode is ``"r"``, and ``pivoting`` is respectively ``False`` or ``True``, where:

- ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``)
or ``(..., M, K)`` (if ``mode`` is ``"economic"``).
or ``(..., M, K)`` (if ``mode`` is ``"economic"``),
- ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is
``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``)
``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``),
- ``P`` is an index vector of shape ``(..., N)``.

with ``K = min(M, N)``.

Notes:
- At present, pivoting is only implemented on CPU backends.

See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
- :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
Expand Down
Loading
Loading