SVD #598

Merged · 2 commits · Nov 12, 2024
api_status.md (4 changes: 2 additions & 2 deletions)
@@ -113,8 +113,8 @@ A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/
| | `qr` | :white_check_mark: | | |
| | `slogdet` | :x: | | |
| | `solve` | :x: | | |
-| | `svd` | :x: | | |
-| | `svdvals` | :x: | | |
+| | `svd` | :white_check_mark: | | |
+| | `svdvals` | :white_check_mark: | | |
| | `tensordot` | :white_check_mark: | | |
| | `trace` | :x: | | |
| | `vecdot` | :white_check_mark: | | |
cubed/array_api/linalg.py (105 changes: 84 additions & 21 deletions)
@@ -1,18 +1,18 @@
from typing import NamedTuple

from cubed.array_api.array_object import Array

-# These functions are in both the main and linalg namespaces
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _floating_dtypes

+# These functions are in both the main and linalg namespaces
from cubed.array_api.linear_algebra_functions import (  # noqa: F401
    matmul,
    matrix_transpose,
    tensordot,
    vecdot,
)
from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import blockwise, general_blockwise, merge_chunks
+from cubed.core.ops import blockwise, general_blockwise, merge_chunks, squeeze
from cubed.utils import array_memory, get_item

@@ -27,6 +27,12 @@ class QRResult(NamedTuple):
    R: Array


+class SVDResult(NamedTuple):
+    U: Array
+    S: Array
+    Vh: Array
+
+
def qr(x, /, *, mode="reduced") -> QRResult:
    if x.ndim != 2:
        raise ValueError("qr requires x to have 2 dimensions.")
@@ -43,10 +49,11 @@ def qr(x, /, *, mode="reduced") -> QRResult:
            "Consider rechunking so there is only a single column chunk."
        )

-    return tsqr(x)
+    Q, R, _, _, _ = tsqr(x)
+    return QRResult(Q, R)


-def tsqr(x) -> QRResult:
+def tsqr(x, compute_svd=False, finalize_svd=True):
    """Direct Tall-and-Skinny QR algorithm

    From:
@@ -57,18 +64,22 @@ def tsqr(x) -> QRResult:
    https://arxiv.org/abs/1301.1071
    """

-    # follows Algorithm 2 from Benson et al
+    # follows Algorithm 2 from Benson et al, modified for SVD
    Q1, R1 = _qr_first_step(x)

    if _r1_is_too_big(R1):
        R1 = _rechunk_r1(R1)
-        Q2, R2 = tsqr(R1)
+        Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, finalize_svd=False)
    else:
-        Q2, R2 = _qr_second_step(R1)
+        Q2, R2, U, S, Vh = _qr_second_step(R1, compute_svd=compute_svd)

    Q, R = _qr_third_step(Q1, Q2), R2

-    return QRResult(Q, R)
+    if compute_svd and finalize_svd:
+        U = Q @ U  # fourth step (SVD only)
+        S = squeeze(S, axis=1)  # remove extra dim
+
+    return Q, R, U, S, Vh
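For orientation, here is a minimal single-process NumPy sketch of the four steps `tsqr` performs above (a sketch of ours, not code from this PR; `tsqr_svd_sketch` and `block_rows` are illustrative names, and Cubed instead runs each step as chunked blockwise operations, recursing when the stacked R factors are too big):

```python
import numpy as np


def tsqr_svd_sketch(x, block_rows):
    # First step: independent QR of each row block.
    blocks = [x[i : i + block_rows] for i in range(0, x.shape[0], block_rows)]
    qrs = [np.linalg.qr(b) for b in blocks]

    # Second step: QR of the stacked small R factors, then an SVD of
    # the single (n, n) factor R2.
    R1 = np.vstack([r for _, r in qrs])
    Q2, R2 = np.linalg.qr(R1)
    U2, S, Vh = np.linalg.svd(R2)

    # Third step: combine the Q factors block by block.
    n = x.shape[1]
    Q = np.vstack([q @ Q2[i * n : (i + 1) * n] for i, (q, _) in enumerate(qrs)])

    # Fourth step (SVD only): the U of x is Q @ U2.
    return Q @ U2, S, Vh


x = np.arange(32, dtype=np.float64).reshape(16, 2)
U, S, Vh = tsqr_svd_sketch(x, block_rows=4)
np.testing.assert_allclose(U * S @ Vh, x, atol=1e-8)
```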


def _qr_first_step(A):
@@ -108,7 +119,7 @@ def _rechunk_r1(R1, split_every=4):
    return merge_chunks(R1, chunks=chunks)


-def _qr_second_step(R1):
+def _qr_second_step(R1, compute_svd=False):
    R1_single = _merge_into_single_chunk(R1)

    Q2_shape = R1.shape
@@ -117,17 +128,38 @@ def _qr_second_step(R1):
    n = R1.shape[1]
    R2_shape = (n, n)
    R2_chunks = R2_shape  # single chunk
-    # qr implementation creates internal array buffers
-    extra_projected_mem = R1_single.chunkmem * 4
-    Q2, R2 = map_blocks_multiple_outputs(
-        nxp.linalg.qr,
-        R1_single,
-        shapes=[Q2_shape, R2_shape],
-        dtypes=[R1.dtype, R1.dtype],
-        chunkss=[Q2_chunks, R2_chunks],
-        extra_projected_mem=extra_projected_mem,
-    )
-    return QRResult(Q2, R2)
+
+    if not compute_svd:
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2 = map_blocks_multiple_outputs(
+            nxp.linalg.qr,
+            R1_single,
+            shapes=[Q2_shape, R2_shape],
+            dtypes=[R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, None, None, None
+    else:
+        U_shape = (n, n)
+        U_chunks = U_shape
+        S_shape = (n, 1)  # extra dim since multiple outputs must have same numblocks
+        S_chunks = S_shape
+        Vh_shape = (n, n)
+        Vh_chunks = Vh_shape
+
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2, U, S, Vh = map_blocks_multiple_outputs(
+            _qr2,
+            R1_single,
+            shapes=[Q2_shape, R2_shape, U_shape, S_shape, Vh_shape],
+            dtypes=[R1.dtype, R1.dtype, R1.dtype, R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks, U_chunks, S_chunks, Vh_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, U, S, Vh
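A note on `S_shape = (n, 1)`: every output of a single `map_blocks_multiple_outputs` call must have the same number of blocks, and the backend `svd` returns a 1-D `S`, so the singular values travel as a single-chunk column until `tsqr` squeezes the extra axis off at the end. A tiny NumPy illustration of the reshaping (ours, not from the diff):

```python
import numpy as np

U, S, Vh = np.linalg.svd(np.eye(3))
assert U.shape == (3, 3) and Vh.shape == (3, 3)
assert S.shape == (3,)                 # 1-D: one fewer dim than U and Vh
S_col = S[:, np.newaxis]               # (3, 1): a single 2-D block, like U and Vh
assert np.squeeze(S_col, axis=1).shape == (3,)  # undone later in tsqr
```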


def _merge_into_single_chunk(x, split_every=4):
@@ -138,6 +170,13 @@ def _merge_into_single_chunk(x, split_every=4):
    return x


+def _qr2(a):
+    Q, R = nxp.linalg.qr(a)
+    U, S, Vh = nxp.linalg.svd(R)
+    S = S[:, nxp.newaxis]  # add extra dim
+    return Q, R, U, S, Vh
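Why factoring only the small R is enough: by this point `tsqr` has reduced x to Q @ R with R of shape (n, n), so R = U2 @ diag(S) @ Vh gives x = (Q @ U2) @ diag(S) @ Vh, and Q @ U2 still has orthonormal columns. A quick NumPy check of that identity (our illustration, not part of the diff):

```python
import numpy as np

# If A = Q @ R (from TSQR) and R = U2 @ diag(S) @ Vh (small SVD), then
# A = (Q @ U2) @ diag(S) @ Vh is a valid SVD of A, because Q @ U2 is a
# product of matrices with orthonormal columns.
A = np.random.default_rng(0).standard_normal((16, 2))
Q, R = np.linalg.qr(A)
U2, S, Vh = np.linalg.svd(R)
np.testing.assert_allclose((Q @ U2) * S @ Vh, A, atol=1e-8)
np.testing.assert_allclose((Q @ U2).T @ (Q @ U2), np.eye(2), atol=1e-8)
```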


def _qr_third_step(Q1, Q2):
    m, n = Q1.chunksize
    k, _ = Q1.numblocks
@@ -174,6 +213,30 @@ def _q_matmul(a1, a2, q2_chunks=None, block_id=None):
    return q1 @ q2


+def svd(x, /, *, full_matrices=True) -> SVDResult:
+    if full_matrices:
+        raise ValueError("Cubed arrays only support using full_matrices=False")
+
+    nb = x.numblocks
+    # TODO: optimize case nb[0] == nb[1] == 1
+    if nb[0] > nb[1]:
+        _, _, U, S, Vh = tsqr(x, compute_svd=True)
+        truncate = x.shape[0] < x.shape[1]
+    else:
+        _, _, Vht, S, Ut = tsqr(x.T, compute_svd=True)
+        U, S, Vh = Ut.T, S, Vht.T
+        truncate = x.shape[0] > x.shape[1]
+    if truncate:  # from dask
+        k = min(x.shape)
+        U, Vh = U[:, :k], Vh[:k, :]
+    return SVDResult(U, S, Vh)
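The else branch handles short-and-wide arrays by factoring the transpose: if x.T = Ut @ diag(S) @ Vht, then x = Vht.T @ diag(S) @ Ut.T, which is exactly the relabelling above; the truncation then trims U or Vh to k = min(x.shape) singular vectors, following dask. A small NumPy sketch of the transpose identity (ours, for illustration):

```python
import numpy as np

# For a short-and-wide x, run the tall-and-skinny SVD on x.T and
# transpose the factors back: if x.T = Ut @ diag(S) @ Vht, then
# x = Vht.T @ diag(S) @ Ut.T.
x = np.arange(8.0).reshape(2, 4)  # short and wide
Ut, S, Vht = np.linalg.svd(x.T, full_matrices=False)
U, Vh = Vht.T, Ut.T
np.testing.assert_allclose(U * S @ Vh, x, atol=1e-8)
```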


+def svdvals(x, /):
+    _, S, _ = svd(x, full_matrices=False)
+    return S


def map_blocks_multiple_outputs(
    func,
    *args,
cubed/tests/test_linalg.py (44 changes: 44 additions & 0 deletions)
@@ -57,3 +57,47 @@ def test_qr_chunking():
        match=r"qr only supports tall-and-skinny \(single column chunk\) arrays.",
    ):
        xp.linalg.qr(A)
+
+
+def test_svd():
+    A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))
+
+    U, S, Vh = xp.linalg.svd(xp.asarray(A, chunks=(4, 2)), full_matrices=False)
+    U, S, Vh = cubed.compute(U, S, Vh)
+
+    assert_allclose(U * S @ Vh, A, atol=1e-08)
+    assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08)  # U must be orthonormal
+    assert_allclose(Vh @ Vh.T, np.eye(2, 2), atol=1e-08)  # Vh must be orthonormal
+
+
+def test_svd_recursion():
+    A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))
+
+    # find a memory setting where recursion happens
+    found = False
+    for factor in range(4, 16):
+        spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)
+
+        try:
+            U, S, Vh = xp.linalg.svd(
+                xp.asarray(A, chunks=(8, 2), spec=spec), full_matrices=False
+            )
+
+            found = True
+            plan_unopt = arrays_to_plan(U, S, Vh)._finalize()
+            assert plan_unopt.num_primitive_ops() > 4  # more than without recursion
+
+            U, S, Vh = cubed.compute(U, S, Vh)
+
+            assert_allclose(U * S @ Vh, A, atol=1e-08)
+            assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08)  # U must be orthonormal
+            assert_allclose(
+                Vh @ Vh.T, np.eye(2, 2), atol=1e-08
+            )  # Vh must be orthonormal
+
+            break
+
+        except ValueError:
+            pass  # not enough memory
+
+    assert found
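For completeness, a minimal end-to-end sketch of how the new functions are used (our example, mirroring the tests above and assuming the same imports the test module uses, `cubed` and `cubed.array_api as xp`; note `svdvals`, which the tests do not cover, simply discards U and Vh):

```python
import numpy as np
import cubed
import cubed.array_api as xp

A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))
Ax = xp.asarray(A, chunks=(4, 2))

U, S, Vh = xp.linalg.svd(Ax, full_matrices=False)  # lazy cubed arrays
sv = xp.linalg.svdvals(Ax)                         # singular values only

U, S, Vh, sv = cubed.compute(U, S, Vh, sv)
np.testing.assert_allclose(U * S @ Vh, A, atol=1e-8)
np.testing.assert_allclose(sv, S, atol=1e-8)
```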