Fix torch svd #28770

Status: Open. Wants to merge 43 commits into base branch main.

Commits
f7f499e  fixed the potentially wrong namedtuple definitions in the svd backend… (Jun 20, 2024)
3db7f17  try to fix the blas_and_lapack_ops.py.svd with correct output namedtu… (Jun 20, 2024)
46d180a  try to fix the blas_and_lapack_ops.py.svd with correct output namedtu… (Jun 20, 2024)
dce10a6  replaced the unimplemented tensor.mH with the implemented adjoint, … (Jun 30, 2024)
0c13ce6  update test of torch.blas_and_lapack_ops.svd to calculate the validit… (Jul 3, 2024)
4d0851d  small fix (Jul 3, 2024)
3b3670f  small fix (Jul 3, 2024)
71fed6b  updated the test for torch.linalg.svd (Jul 3, 2024)
3fdf4dd  found that jax.lax.linalg.svd has an argument "subset_by_index" missing (Jul 3, 2024)
c5b5904  fixed the skipping of torch svd tests according to suggestion, no longer… (Jul 3, 2024)
efe9a5a  tests are partially passing, though for torch backend, "RuntimeError:… (Jul 4, 2024)
26aeba0  fix test of numpy.linalg.decomposition.svd as it returns an svd object… (Jul 7, 2024)
b9cf1cd  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 14, 2024)
3bb5b66  now only the torch backend of jax.numpy.linalg.svd is failing due to "Run… (Jul 15, 2024)
8b29eb7  all tests for tensorflow.linalg.svd are passing (Jul 15, 2024)
bc30d7d  try to fix the two svd functions in torch frontend; now the only probl… (Jul 16, 2024)
bed8f77  applied the suggested fix to torch svd tests; they are all passing now (Jul 16, 2024)
316986e  made the namedtuple definition simpler as suggested (Jul 16, 2024)
e0268c6  tried to fix jax.lax.linalg.svd; p.s. there is no implementation of s… (Jul 16, 2024)
dcab2c1  fixed jax.numpy.linalg.svd; all tests are passing, but jax.lax.linalg… (Jul 16, 2024)
da4a78b  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 18, 2024)
ac7a60a  fixing numpy.linalg.decompositions.svd (Jul 18, 2024)
9b4c161  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 31, 2024)
6c1c39c  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Aug 12, 2024)
8e927a4  fixed ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py::te… (Aug 16, 2024)
37272a8  fixed ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_d… (Aug 16, 2024)
dc90073  Fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops… (Aug 18, 2024)
65c902f  fixed ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_lina… (Aug 20, 2024)
c92b6be  fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops… (Aug 20, 2024)
c1a6632  changed ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_la… (Aug 22, 2024)
9b75cd1  changed all of torch's test_torch_svd so that complex number inputs … (Aug 22, 2024)
5c9de15  try to update torch's and tensorflow's svd functions as they somehow re… (Aug 22, 2024)
2f76f66  seems like we should not use svdvals as it always returns a non-complex v… (Aug 22, 2024)
412e60c  fixed jax's svd to test for complex input, though only jax.lax.linalg… (Aug 23, 2024)
ee58d83  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Aug 28, 2024)
9433e91  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 3, 2024)
c7d9ddc  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 6, 2024)
c89223c  Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 28, 2024)
99b7266  small update on test_torch.test_tensor.test_torch_svd (Daniel4078, Sep 28, 2024)
00f7754  Update test_blas_and_lapack_ops.py (Daniel4078, Sep 28, 2024)
49db616  Update test_linalg.py (Daniel4078, Sep 28, 2024)
637652e  Update test_linalg.py (Daniel4078, Sep 30, 2024)
390b00b  Update test_tensor.py (Daniel4078, Sep 30, 2024)
Files changed
ivy/functional/backends/jax/linear_algebra.py (1 addition, 1 deletion)

@@ -361,7 +361,7 @@ def svd(
     x: JaxArray, /, *, compute_uv: bool = True, full_matrices: bool = True
 ) -> Union[JaxArray, Tuple[JaxArray, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", ['U', 'S', 'Vh'])
         U, D, VT = jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
         return results(U, D, VT)
     else:
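A side note on these namedtuple changes: per the Python documentation, collections.namedtuple accepts field names either as a whitespace/comma-separated string or as a sequence of strings, so "U S Vh" and ['U', 'S', 'Vh'] build identical types; the switch is for consistency rather than correctness. A minimal check:

    from collections import namedtuple

    # Both spellings of field_names are documented and equivalent.
    A = namedtuple("svd", "U S Vh")
    B = namedtuple("svd", ["U", "S", "Vh"])
    assert A._fields == B._fields == ("U", "S", "Vh")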
ivy/functional/backends/numpy/linear_algebra.py (1 addition, 1 deletion)

@@ -291,7 +291,7 @@ def svd(
     x: np.ndarray, /, *, compute_uv: bool = True, full_matrices: bool = True
 ) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", ['U', 'S', 'Vh'])
         U, D, VT = np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
         return results(U, D, VT)
     else:
ivy/functional/backends/paddle/linear_algebra.py (1 addition, 1 deletion)

@@ -521,7 +521,7 @@ def svd(
 ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
     ret = paddle.linalg.svd(x, full_matrices=full_matrices)
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", ['U', 'S', 'Vh'])
         return results(*ret)
     else:
         results = namedtuple("svd", "S")
ivy/functional/backends/tensorflow/linear_algebra.py (1 addition, 2 deletions)

@@ -545,8 +545,7 @@ def svd(
     compute_uv: bool = True,
 ) -> Union[Union[tf.Tensor, tf.Variable], Tuple[Union[tf.Tensor, tf.Variable], ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
-
+        results = namedtuple("svd", ['U', 'S', 'Vh'])
         batch_shape = tf.shape(x)[:-2]
         num_batch_dims = len(batch_shape)
         transpose_dims = list(range(num_batch_dims)) + [
ivy/functional/backends/torch/linear_algebra.py (3 additions, 6 deletions)

@@ -414,16 +414,13 @@ def svd(
     x: torch.Tensor, /, *, full_matrices: bool = True, compute_uv: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
-
+        results = namedtuple("svd", ['U', 'S', 'Vh'])
         U, D, VT = torch.linalg.svd(x, full_matrices=full_matrices)
         return results(U, D, VT)
     else:
         results = namedtuple("svd", "S")
-        svd = torch.linalg.svd(x, full_matrices=full_matrices)
-        # torch.linalg.svd returns a tuple with U, S, and Vh
-        D = svd[1]
-        return results(D)
+        s = torch.linalg.svdvals(x)
+        return results(s)


 @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
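This is the substantive backend change: when compute_uv is False, the torch backend now calls torch.linalg.svdvals instead of running a full SVD and discarding U and Vh. A minimal sketch of the equivalence, assuming a PyTorch version that provides torch.linalg.svdvals (1.9+):

    import torch

    x = torch.randn(4, 3)

    # Full SVD materializes U and Vh even when only S is wanted.
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)

    # svdvals computes the singular values alone, which is cheaper.
    s = torch.linalg.svdvals(x)

    assert torch.allclose(S, s)

This also makes the docstring caveat removed further below obsolete: the torch backend no longer computes singular vectors it does not return.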
ivy/functional/frontends/torch/blas_and_lapack_ops.py (11 additions, 4 deletions)

@@ -2,6 +2,7 @@
 import ivy
 from ivy.func_wrapper import with_unsupported_dtypes
 import ivy.functional.frontends.torch as torch_frontend
+from collections import namedtuple
 from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back

@@ -191,11 +192,17 @@ def slogdet(A, *, out=None):

 @to_ivy_arrays_and_back
 def svd(input, some=True, compute_uv=True, *, out=None):
-    # TODO: add compute_uv
-    if some:
-        ret = ivy.svd(input, full_matrices=False)
+    retu = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
+    results = namedtuple("svd", ['U', 'S', 'V'])
+    if compute_uv:
+        ret = results(retu[0], retu[1], ivy.adjoint(retu[2]))
     else:
-        ret = ivy.svd(input, full_matrices=True)
+        shape = list(input.shape)
+        shape1 = shape
+        shape2 = shape
+        shape1[-2] = shape[-1]
+        shape2[-1] = shape[-2]
+        ret = results(ivy.zeros(shape1, device=input.device, dtype=input.dtype), retu[0], ivy.zeros(shape2, device=input.device, dtype=input.dtype))
     if ivy.exists(out):
         return ivy.inplace_update(out, ret)
     return ret
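The rewritten frontend mirrors torch.svd semantics: the deprecated torch.svd returns V, whereas ivy.svd (like torch.linalg.svd) returns Vh, hence the ivy.adjoint call; and when compute_uv=False, torch.svd returns zero-filled U and V alongside S. A minimal illustration of the V-versus-Vh convention in plain PyTorch:

    import torch

    a = torch.randn(5, 3)

    U, S, V = torch.svd(a)  # deprecated API: returns V
    U2, S2, Vh = torch.linalg.svd(a, full_matrices=False)  # returns Vh

    # Both factorizations reconstruct the input; note V vs Vh placement.
    assert torch.allclose(U @ torch.diag(S) @ V.transpose(-2, -1), a, atol=1e-5)
    assert torch.allclose(U2 @ torch.diag(S2) @ Vh, a, atol=1e-5)

One reviewer-style caveat on the compute_uv=False branch: shape1 = shape and shape2 = shape bind the same list object as shape, so the two index assignments interact and all three names end up describing the same square shape; copies such as shape1 = list(shape) would be needed for the two zero-filled tensors to take distinct shapes on non-square inputs. The tests below pass because they only generate square matrices.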
ivy/functional/frontends/torch/linalg.py (5 additions, 2 deletions)

@@ -326,8 +326,11 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None):
     {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
 )
 def svd(A, /, *, full_matrices=True, driver=None, out=None):
-    # TODO: add handling for driver and out
-    return ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    # TODO: add handling for driver
+    ret = ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret)
+    return ret


 @to_ivy_arrays_and_back
ivy/functional/ivy/linear_algebra.py (1 addition, 4 deletions)

@@ -2130,15 +2130,12 @@ def svd(
         If ``True`` then left and right singular vectors will be computed and returned
         in ``U`` and ``Vh``, respectively. Otherwise, only the singular values will be
         computed, which can be significantly faster.
-    .. note::
-        with backend set as torch, svd with still compute left and right singular
-        vectors irrespective of the value of compute_uv, however Ivy will still
-        only return the singular values.

     Returns
     -------
     .. note::
         once complex numbers are supported, each square matrix must be Hermitian.
+        In addition, the return will be a namedtuple ``(S)`` when compute_uv is ``False``

     ret
         a namedtuple ``(U, S, Vh)`` whose
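With the torch-specific note gone (svdvals now backs the compute_uv=False path), the documented contract is uniform across backends. A minimal usage sketch of that contract, assuming a standard ivy setup; ivy.random_normal here is just an illustrative way to build an input:

    import ivy

    ivy.set_backend("torch")
    x = ivy.random_normal(shape=(4, 4))

    U, S, Vh = ivy.svd(x)                     # namedtuple (U, S, Vh)
    (S_only,) = ivy.svd(x, compute_uv=False)  # namedtuple with single field S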
ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py

@@ -4,6 +4,7 @@
 from hypothesis import strategies as st, assume

 # local
+import ivy
 import ivy_tests.test_ivy.helpers as helpers
 from ivy_tests.test_ivy.helpers import handle_frontend_test
 from ivy_tests.test_ivy.helpers.hypothesis_helpers.general_helpers import (

@@ -848,37 +849,62 @@ def test_torch_qr(
 @handle_frontend_test(
     fn_tree="torch.svd",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float", index=1),
-        min_num_dims=3,
-        max_num_dims=5,
-        min_dim_size=2,
-        max_dim_size=5,
+        available_dtypes=helpers.get_dtypes("float"),
+        min_value=0,
+        max_value=10,
+        shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
     ),
     some=st.booleans(),
-    compute=st.booleans(),
+    compute_uv=st.booleans(),
 )
 def test_torch_svd(
     dtype_and_x,
     some,
-    compute,
-    on_device,
-    fn_tree,
+    compute_uv,
     frontend,
     test_flags,
+    fn_tree,
     backend_fw,
+    on_device,
 ):
-    dtype, x = dtype_and_x
-    helpers.test_frontend_function(
-        input_dtypes=dtype,
+    input_dtype, x = dtype_and_x
+    x = np.asarray(x[0], dtype=input_dtype[0])
+    # make symmetric positive definite beforehand
+    x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
+    ret, frontend_ret = helpers.test_frontend_function(
+        input_dtypes=input_dtype,
         backend_to_test=backend_fw,
         frontend=frontend,
         test_flags=test_flags,
         fn_tree=fn_tree,
         on_device=on_device,
-        input=x[0],
+        test_values=False,
+        input=x,
         some=some,
-        compute_uv=compute,
+        compute_uv=compute_uv,
     )
+    ret = [ivy.to_numpy(x) for x in ret]
+    frontend_ret = [np.asarray(x) for x in frontend_ret]
+
+    u, s, v = ret
+    frontend_u, frontend_s, frontend_v = frontend_ret
+
+    if compute_uv:
+        helpers.assert_all_close(
+            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
+            ret_from_gt_np=u @ np.diag(s) @ v.T,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    else:
+        helpers.assert_all_close(
+            ret_np=frontend_s,
+            ret_from_gt_np=s,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )


 @handle_frontend_test(
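The key testing idea in this hunk: SVD factors are unique only up to per-column sign (phase, for complex inputs), so comparing U, S, V element-wise across backends can fail spuriously, while comparing the reconstruction is stable. A standalone numpy-only sketch of the property being asserted:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((4, 4))
    # Symmetric positive definite, mirroring the test's preprocessing.
    a = a.T @ a + 1e-3 * np.eye(4)

    u, s, vh = np.linalg.svd(a)

    # Flipping matched columns of u and rows of vh leaves the product intact,
    # which is why reconstruction comparison is robust to backend sign choices.
    assert np.allclose(u @ np.diag(s) @ vh, a, atol=1e-8)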
ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py (25 additions, 12 deletions)

@@ -1247,7 +1247,15 @@ def test_torch_solve_ex(
 # svd
 @handle_frontend_test(
     fn_tree="torch.linalg.svd",
-    dtype_and_x=_get_dtype_and_matrix(square=True),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_value=0,
+        max_value=10,
+        min_num_dims=2,
+        max_num_dims=5,
+        min_dim_size=1,
+        max_dim_size=5,
+    ),
     full_matrices=st.booleans(),
 )
 def test_torch_svd(

@@ -1272,25 +1280,30 @@ def test_torch_svd(
         fn_tree=fn_tree,
         on_device=on_device,
         test_values=False,
+        atol=1e-03,
+        rtol=1e-05,
         A=x,
         full_matrices=full_matrices,
     )
     ret = [ivy.to_numpy(x) for x in ret]
     frontend_ret = [np.asarray(x) for x in frontend_ret]

     u, s, vh = ret
     frontend_u, frontend_s, frontend_vh = frontend_ret
+    if full_matrices:
+        helpers.assert_all_close(
+            ret_np=frontend_u[..., :frontend_s.shape[0]] @ np.diag(frontend_s) @ frontend_vh.T,
+            ret_from_gt_np=u[..., :s.shape[0]] @ np.diag(s) @ vh.T,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    else:
+        helpers.assert_all_close(
+            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T,
+            ret_from_gt_np=u @ np.diag(s) @ vh.T,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
-
-    assert_all_close(
-        ret_np=u @ np.diag(s) @ vh,
-        ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh,
-        rtol=1e-2,
-        atol=1e-2,
-        ground_truth_backend=frontend,
-        backend=backend_fw,
-    )


 # svdvals
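Why the full_matrices branch slices U: with full_matrices=True, an m×n input yields U of shape m×m and Vh of shape n×n, but only the first k = min(m, n) columns of U are paired with singular values, so a reconstruction-style comparison must slice U first (the u[..., :s.shape[0]] indexing above). A numpy-only sketch of the shape bookkeeping:

    import numpy as np

    a = np.random.default_rng(1).standard_normal((5, 3))
    u, s, vh = np.linalg.svd(a, full_matrices=True)  # u: (5, 5), vh: (3, 3)

    # Only min(m, n) = 3 columns of u carry singular directions.
    assert np.allclose(u[:, : s.shape[0]] @ np.diag(s) @ vh, a, atol=1e-8)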