Skip to content

Commit

Permalink
Merge pull request #20282 from tttc3:pivoted-qr
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714053620
  • Loading branch information
Google-ML-Automation committed Jan 10, 2025
2 parents 1fe72ee + c89be05 commit 564b6b0
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 46 deletions.
128 changes: 113 additions & 15 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,23 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
return lu, pivots, permutation


def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
@overload
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
) -> tuple[Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
) -> tuple[Array, Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
...

def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.
Computes the QR decomposition
Expand All @@ -323,11 +339,15 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
Args:
x: A batch of matrices with shape ``[..., m, n]``.
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
is chosen such that the diagonal of ``R`` is non-increasing. Currently
supported on CPU backends only.
full_matrices: Determines if full or reduced matrices are returned; see
below.
Returns:
A pair of arrays ``(q, r)``.
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.
Array ``q`` is a unitary (orthogonal) matrix,
with shape ``[..., m, m]`` if ``full_matrices=True``, or
Expand All @@ -336,8 +356,12 @@ def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
Array ``r`` is an upper-triangular matrix with shape ``[..., m, n]`` if
``full_matrices=True``, or ``[..., min(m, n), n]`` if
``full_matrices=False``.
Array ``p`` is an index vector with shape ``[..., n]``.
"""
q, r = qr_p.bind(x, full_matrices=full_matrices)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
if pivoting:
return q, r, p[0]
return q, r


Expand Down Expand Up @@ -1846,6 +1870,61 @@ def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str):
platform='rocm')


def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
  """Computes a column-pivoted QR factorization in packed Householder form.

  This is a thin wrapper that binds the ``geqp3`` primitive.

  Args:
    a: a ``[..., m, n]`` batch of matrices, with floating-point or complex
      type.
    jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer
      type. The QR lowering in this file passes all zeros.
      NOTE(review): in LAPACK's xGEQP3 a zero entry marks a "free" column --
      confirm the FFI kernel follows the same convention.

  Returns:
    A ``(a, jpvt, taus)`` triple. ``r`` is stored in the upper triangle of
    ``a``; ``q`` is represented by the elementary Householder reflectors held
    in the lower triangle of ``a`` together with ``taus``. ``jpvt`` holds the
    column-pivot indices; the QR lowering in this file treats them as 1-based
    and subtracts 1, after which ``a[:, jpvt] = q @ r``.
  """
  packed, pivots, reflectors = geqp3_p.bind(a, jpvt)
  return packed, pivots, reflectors

def _geqp3_abstract_eval(a, jpvt):
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
f"{a.aval}, {jpvt.aval}")
if a.ndim < 2:
raise ValueError("Argument to column-pivoted QR decomposition must have ndims >= 2")
*batch_dims, m, n = a.shape
*jpvt_batch_dims, jpvt_n = jpvt.shape
if batch_dims != jpvt_batch_dims or jpvt_n != n:
raise ValueError(f"Type mismatch for pivoted QR decomposition: {a=} {jpvt=}")
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
return a, jpvt, taus

def _geqp3_batching_rule(batched_args, batch_dims):
  """Batching rule for ``geqp3_p``.

  Moves the mapped axis of each operand to the front and re-applies
  ``geqp3``, which treats leading dimensions as batch dimensions.

  Args:
    batched_args: the ``(a, jpvt)`` operands.
    batch_dims: the mapped axis of each operand, in the same order.

  Returns:
    The ``geqp3`` results and ``(0, 0, 0)``: every output is batched along
    its leading axis.
  """
  # NOTE(review): each batch dim is handed straight to ``batching.moveaxis``;
  # this presumably assumes both operands are mapped -- confirm callers never
  # pass an unmapped operand.
  (a, jpvt), (a_bdim, jpvt_bdim) = batched_args, batch_dims
  a_front = batching.moveaxis(a, a_bdim, 0)
  jpvt_front = batching.moveaxis(jpvt, jpvt_bdim, 0)
  return geqp3(a_front, jpvt_front), (0, 0, 0)

def _geqp3_cpu_lowering(ctx, a, jpvt):
  """Lowers ``geqp3_p`` on CPU to the LAPACK ``geqp3_ffi`` custom call.

  Args:
    ctx: lowering rule context; ``ctx.avals_in`` holds the operand avals.
    a: IR value of the ``[..., m, n]`` matrix operand.
    jpvt: IR value of the ``[..., n]`` pivot-index operand.

  Returns:
    The result of the FFI lowering rule, i.e. the IR values for the
    ``(a, jpvt, taus)`` outputs.
  """
  a_aval, _ = ctx.avals_in
  nb = len(a_aval.shape[:-2])  # number of leading batch dimensions
  # The matrix layout lists the two matrix dims first (row dim, then column
  # dim), followed by the batch dims in reverse order.
  # NOTE(review): presumably minor-to-major, i.e. column-major matrices as
  # LAPACK expects -- confirm against the FFI layout convention.
  matrix_layout = (nb, nb + 1, *reversed(range(nb)))
  # Vectors (jpvt, taus) list all of their dims in reverse order.
  vector_layout = tuple(reversed(range(nb + 1)))
  operand_layouts = [matrix_layout, vector_layout]
  result_layouts = [matrix_layout, vector_layout, vector_layout]
  target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
  # Operands 0 (a) and 1 (jpvt) are aliased to outputs 0 and 1.
  lower = ffi.ffi_lowering(target_name,
                           operand_layouts=operand_layouts,
                           result_layouts=result_layouts,
                           operand_output_aliases={0: 0, 1: 1})
  return lower(ctx, a, jpvt)


# geqp3 primitive: column-pivoted QR in packed Householder form.
geqp3_p = Primitive('geqp3')
# The primitive returns several outputs: (a, jpvt, taus).
geqp3_p.multiple_results = True
# Eager implementation: dispatch through the generic apply_primitive path.
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
# A lowering is registered for the CPU platform only.
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")

# householder_product: product of elementary Householder reflectors

def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
Expand Down Expand Up @@ -1938,32 +2017,37 @@ def _householder_product_cpu_gpu_lowering(ctx, a, taus, *,
platform='rocm')


def _qr_impl(operand, *, full_matrices):
q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
def _qr_impl(operand, *, pivoting, full_matrices):
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
full_matrices=full_matrices)
return (q, r, p[0]) if pivoting else (q, r)

def _qr_abstract_eval(operand, *, full_matrices):
def _qr_abstract_eval(operand, *, pivoting, full_matrices):
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
k = m if full_matrices else core.min_dim(m, n)
q = operand.update(shape=(*batch_dims, m, k))
r = operand.update(shape=(*batch_dims, k, n))
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
else:
q = operand
r = operand
return q, r
p = operand
return (q, r, p) if pivoting else (q, r)

def qr_jvp_rule(primals, tangents, *, full_matrices):
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
q, r = qr_p.bind(x, full_matrices=False)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
*_, m, n = x.shape
if m < n or (full_matrices and m != n):
raise NotImplementedError(
"Unimplemented case of QR decomposition derivative")
if pivoting:
dx = dx[..., p[0]]
dx_rinv = triangular_solve(r, dx) # Right side solve by default
qt_dx_rinv = _H(q) @ dx_rinv
qt_dx_rinv_lower = _tril(qt_dx_rinv, -1)
Expand All @@ -1973,25 +2057,38 @@ def qr_jvp_rule(primals, tangents, *, full_matrices):
do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype))
dq = q @ (do - qt_dx_rinv) + dx_rinv
dr = (qt_dx_rinv - do) @ r
if pivoting:
dp = ad_util.Zero.from_primal_value(p[0])
return (q, r, p[0]), (dq, dr, dp)
return (q, r), (dq, dr)

def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
out_axes = (0, 0, 0) if pivoting else (0, 0)
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes

def _qr_lowering(a, *, full_matrices):
def _qr_lowering(a, *, pivoting, full_matrices):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else core.min_dim(m, n)
q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)),
(*batch_dims, m, k),
(len(batch_dims), len(batch_dims) + 1))
r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype)
if pivoting:
p = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
return q, r, p
return q, r

r, taus = geqrf(a)
if pivoting:
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
r, p, taus = geqp3(a, jpvt)
p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1.
else:
r, taus = geqrf(a)

if m < n:
q = householder_product(r[..., :m, :m], taus)
elif full_matrices:
Expand All @@ -2002,6 +2099,8 @@ def _qr_lowering(a, *, full_matrices):
q = householder_product(r, taus)
r = r[..., :n, :n]
r = _triu(r)
if pivoting:
return q, r, p
return q, r


Expand All @@ -2015,7 +2114,6 @@ def _qr_lowering(a, *, full_matrices):

mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))


# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv, subset_by_index=None,
algorithm=None):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def svd(

reduce_to_square = False
if full_matrices:
q_full, a_full = lax.linalg.qr(a, full_matrices=True)
q_full, a_full = lax.linalg.qr(a, pivoting=False, full_matrices=True)
q = q_full[:, :n]
u_out_null = q_full[:, n:]
a = a_full[:n, :]
Expand All @@ -206,7 +206,7 @@ def svd(
# The constant `1.15` comes from Yuji Nakatsukasa's implementation
# https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
if m > 1.15 * n:
q, a = lax.linalg.qr(a, full_matrices=False)
q, a = lax.linalg.qr(a, pivoting=False, full_matrices=False)
reduce_to_square = True

if not compute_uv:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
full_matrices = True
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
q, r = lax_linalg.qr(a, pivoting=False, full_matrices=full_matrices)
if mode == "r":
return r
return QRResult(q, r)
Expand Down
93 changes: 72 additions & 21 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,52 +820,96 @@ def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False,
del overwrite_a, check_finite # unused
return _lu(a, permute_l)


@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[False]
) -> tuple[Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: Literal[True]
) -> tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[False]
) -> tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: Literal[True]
) -> tuple[Array, Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["r"], pivoting: bool) -> tuple[Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: Literal[False]
) -> tuple[Array] | tuple[Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: Literal["full", "economic"], pivoting: bool) -> tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: Literal[True]
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...

@overload
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]: ...
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...


@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a: ArrayLike, mode: str, pivoting: bool) -> tuple[Array] | tuple[Array, Array]:
if pivoting:
raise NotImplementedError(
"The pivoting=True case of qr is not implemented.")
def _qr(a: ArrayLike, mode: str, pivoting: bool
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
if mode in ("full", "r"):
full_matrices = True
elif mode == "economic":
full_matrices = False
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
a, = promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
q, r, *p = lax_linalg.qr(a, pivoting=pivoting, full_matrices=full_matrices)
if mode == "r":
if pivoting:
return r, p[0]
return (r,)
if pivoting:
return q, r, p[0]
return q, r


@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: Literal["full", "economic"] = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[False] = False,
check_finite: bool = True) -> tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: Literal[True] = True,
check_finite: bool = True) -> tuple[Array, Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool, lwork: Any, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["full", "economic"], pivoting: bool = False,
check_finite: bool = True
) -> tuple[Array, Array] | tuple[Array, Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal["r"],
pivoting: bool = False, check_finite: bool = True) -> tuple[Array]: ...
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[False] = False, check_finite: bool = True
) -> tuple[Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: Literal[True] = True, check_finite: bool = True
) -> tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *,
mode: Literal["r"], pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array]: ...

@overload
def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ...
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]: ...


def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full",
pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]:
pivoting: bool = False, check_finite: bool = True
) -> tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]:
"""Compute the QR decomposition of an array
JAX implementation of :func:`scipy.linalg.qr`.
Expand All @@ -888,22 +932,29 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
- ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``,
where K = min(M, N).
pivoting: Not implemented in JAX.
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute
the column-pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such
that the diagonal of ``R`` is non-increasing.
overwrite_a: unused in JAX
lwork: unused in JAX
check_finite: unused in JAX
Returns:
A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``,
where:
A tuple ``(Q, R)`` or ``(Q, R, P)``, if ``mode`` is not ``"r"`` and ``pivoting`` is
respectively ``False`` or ``True``, otherwise an array ``R`` or tuple ``(R, P)`` if
mode is ``"r"``, and ``pivoting`` is respectively ``False`` or ``True``, where:
- ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``)
or ``(..., M, K)`` (if ``mode`` is ``"economic"``).
or ``(..., M, K)`` (if ``mode`` is ``"economic"``),
- ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is
``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``)
``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``),
- ``P`` is an index vector of shape ``(..., N)``.
with ``K = min(M, N)``.
Notes:
- At present, pivoting is only implemented on CPU backends.
See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
- :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
Expand Down
Loading

0 comments on commit 564b6b0

Please sign in to comment.