Fix torch svd #28770

Open · wants to merge 43 commits into base: main

Commits (43)
f7f499e: fixed the potentially wrong namedtuple definitions in the svd backend… (Jun 20, 2024)
3db7f17: try to fix the blas_and_lapack_ops.py.svd with correct output namedtu… (Jun 20, 2024)
46d180a: try to fix the blas_and_lapack_ops.py.svd with correct output namedtu… (Jun 20, 2024)
dce10a6: replace the unimplemented tensor.mH used to the implemented adjoint, … (Jun 30, 2024)
0c13ce6: update test of torch.blas_and_lapack_ops.svd to calculate the validit… (Jul 3, 2024)
4d0851d: small fix (Jul 3, 2024)
3b3670f: small fix (Jul 3, 2024)
71fed6b: updated the test for torch.linalg.svd (Jul 3, 2024)
3fdf4dd: find that jax.lax.linalg.svd has a argument "subset_by_index" missing (Jul 3, 2024)
c5b5904: fixed the skipping torch svd tests according to suggestion, no longer… (Jul 3, 2024)
efe9a5a: tests are partially passing, though for torch backend, "RuntimeError:… (Jul 4, 2024)
26aeba0: fix test of numpy.linalg.decomposition.svd as it returns a svd object… (Jul 7, 2024)
b9cf1cd: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 14, 2024)
3bb5b66: now only torch backend of jax.numpy.linalg.svd is failing due to "Run… (Jul 15, 2024)
8b29eb7: all tests for tesnorflow.linalg.svd are passing (Jul 15, 2024)
bc30d7d: try to fix the two svd function in torch frontend, now the only probl… (Jul 16, 2024)
bed8f77: applied the suggested fix to torch svd tests, they are all passing now (Jul 16, 2024)
316986e: make namedtuple definition more simple as suggested (Jul 16, 2024)
e0268c6: tried to fix jax.lac.linalg.svd. p.s. there is no implementation of s… (Jul 16, 2024)
dcab2c1: fixed jax.numpy.linalg.svd, all tests are passing, but jax.lax.linalg… (Jul 16, 2024)
da4a78b: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 18, 2024)
ac7a60a: fixing numpy.linalg.decompositions.svd (Jul 18, 2024)
9b4c161: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Jul 31, 2024)
6c1c39c: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Aug 12, 2024)
8e927a4: fixed ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py::te… (Aug 16, 2024)
37272a8: fixed ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_d… (Aug 16, 2024)
dc90073: Fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops… (Aug 18, 2024)
65c902f: fixed ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_lina… (Aug 20, 2024)
c92b6be: fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops… (Aug 20, 2024)
c1a6632: changed ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_la… (Aug 22, 2024)
9b75cd1: changed all the torch's test_torch_svd so that complex number inputs … (Aug 22, 2024)
5c9de15: try to update torch and tensorflow's svd functions as they somehow re… (Aug 22, 2024)
2f76f66: seems like should not use svdvals as it always return a not complex v… (Aug 22, 2024)
412e60c: fixed jax's svd to teat for complex input. though only jax.lax.linalg… (Aug 23, 2024)
ee58d83: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Aug 28, 2024)
9433e91: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 3, 2024)
c7d9ddc: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 6, 2024)
c89223c: Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o… (Daniel4078, Sep 28, 2024)
99b7266: small update on test_torch.test_tensor.test_torch_svd (Daniel4078, Sep 28, 2024)
00f7754: Update test_blas_and_lapack_ops.py (Daniel4078, Sep 28, 2024)
49db616: Update test_linalg.py (Daniel4078, Sep 28, 2024)
637652e: Update test_linalg.py (Daniel4078, Sep 30, 2024)
390b00b: Update test_tensor.py (Daniel4078, Sep 30, 2024)
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/linear_algebra.py
@@ -361,7 +361,7 @@ def svd(
     x: JaxArray, /, *, compute_uv: bool = True, full_matrices: bool = True
 ) -> Union[JaxArray, Tuple[JaxArray, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", 'U S Vh')
         U, D, VT = jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
         return results(U, D, VT)
     else:
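The quote-style change here (and in the numpy, paddle, tensorflow, and torch backends below) is purely cosmetic: both spellings of the field string define the same namedtuple type. A quick check:

```python
from collections import namedtuple

# Both field-string spellings produce identical namedtuple classes.
A = namedtuple("svd", "U S Vh")
B = namedtuple("svd", 'U S Vh')
assert A._fields == B._fields == ("U", "S", "Vh")
```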
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/linear_algebra.py
@@ -291,7 +291,7 @@ def svd(
     x: np.ndarray, /, *, compute_uv: bool = True, full_matrices: bool = True
 ) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", 'U S Vh')
         U, D, VT = np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
         return results(U, D, VT)
     else:
2 changes: 1 addition & 1 deletion ivy/functional/backends/paddle/linear_algebra.py
@@ -521,7 +521,7 @@ def svd(
 ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
     ret = paddle.linalg.svd(x, full_matrices=full_matrices)
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
+        results = namedtuple("svd", 'U S Vh')
         return results(*ret)
     else:
         results = namedtuple("svd", "S")
3 changes: 1 addition & 2 deletions ivy/functional/backends/tensorflow/linear_algebra.py
@@ -534,8 +534,7 @@ def svd(
     compute_uv: bool = True,
 ) -> Union[Union[tf.Tensor, tf.Variable], Tuple[Union[tf.Tensor, tf.Variable], ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
-
+        results = namedtuple("svd", 'U S Vh')
         batch_shape = tf.shape(x)[:-2]
         num_batch_dims = len(batch_shape)
         transpose_dims = list(range(num_batch_dims)) + [
9 changes: 3 additions & 6 deletions ivy/functional/backends/torch/linear_algebra.py
@@ -414,16 +414,13 @@ def svd(
     x: torch.Tensor, /, *, full_matrices: bool = True, compute_uv: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
     if compute_uv:
-        results = namedtuple("svd", "U S Vh")
-
+        results = namedtuple("svd", 'U S Vh')
         U, D, VT = torch.linalg.svd(x, full_matrices=full_matrices)
         return results(U, D, VT)
     else:
         results = namedtuple("svd", "S")
-        svd = torch.linalg.svd(x, full_matrices=full_matrices)
-        # torch.linalg.svd returns a tuple with U, S, and Vh
-        D = svd[1]
-        return results(D)
+        s = torch.linalg.svdvals(x)
+        return results(s)


 @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
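The substantive change in this backend is the `compute_uv=False` branch: `torch.linalg.svdvals` computes only the singular values instead of running a full decomposition and discarding U and Vh. A minimal sketch of the equivalence being relied on:

```python
import torch

x = torch.rand(5, 3, dtype=torch.float64)
# Full decomposition, keeping only the singular values (the old code path)...
s_full = torch.linalg.svd(x, full_matrices=False).S
# ...versus computing the singular values directly (the new code path).
s_only = torch.linalg.svdvals(x)
assert torch.allclose(s_full, s_only)
```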
32 changes: 27 additions & 5 deletions ivy/functional/frontends/jax/lax/linalg.py
@@ -1,6 +1,6 @@
 import ivy
 from ivy.functional.frontends.jax.func_wrapper import to_ivy_arrays_and_back
-from ivy.func_wrapper import with_unsupported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes


 @to_ivy_arrays_and_back
@@ -44,7 +44,29 @@ def qr(x, /, *, full_matrices=False):


 @to_ivy_arrays_and_back
-def svd(x, /, *, full_matrices=True, compute_uv=True):
-    if not compute_uv:
-        return ivy.svdvals(x)
-    return ivy.svd(x, full_matrices=full_matrices)
+@with_supported_dtypes(
+    {
+        "0.4.14 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "jax",
+)
+def svd(x, /, *, full_matrices=True, compute_uv=True, subset_by_index=None):
+    # TODO: handle subset_by_index
+    if ivy.is_complex_dtype(x.dtype):
+        d = ivy.complex128
+    else:
+        d = ivy.float64
+    if compute_uv:
+        svd = ivy.svd(x, compute_uv=compute_uv, full_matrices=full_matrices)
+        return tuple(
+            [ivy.astype(svd.U, d), ivy.astype(svd.S, d), ivy.astype(svd.Vh, d)]
+        )
+    else:
+        return ivy.astype(ivy.svdvals(x), ivy.float64)
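The `subset_by_index` parameter is accepted but unhandled (see the TODO). In JAX it is a two-element `[start, end)` index range restricting which singular values are computed. One hypothetical way to resolve the TODO by slicing the full decomposition; the helper name and the descending-order assumption are mine, not the PR's:

```python
def _apply_subset_by_index(U, S, Vh, subset_by_index):
    # Hypothetical helper (not part of the PR): restrict the result to the
    # [start, end) singular-value range, assuming S is sorted descending
    # as ivy.svd returns it.
    if subset_by_index is None:
        return U, S, Vh
    start, end = subset_by_index
    return U[..., :, start:end], S[..., start:end], Vh[..., start:end, :]
```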
28 changes: 25 additions & 3 deletions ivy/functional/frontends/jax/numpy/linalg.py
@@ -120,10 +120,32 @@ def solve(a, b):


 @to_ivy_arrays_and_back
+@with_supported_dtypes(
+    {
+        "0.4.24 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "jax",
+)
 def svd(a, /, *, full_matrices=True, compute_uv=True, hermitian=None):
-    if not compute_uv:
-        return ivy.svdvals(a)
-    return ivy.svd(a, full_matrices=full_matrices)
+    # TODO: handle hermitian
+    if ivy.is_complex_dtype(a.dtype):
+        d = ivy.complex128
+    else:
+        d = ivy.float64
+    if compute_uv:
+        svd = ivy.svd(a, compute_uv=compute_uv, full_matrices=full_matrices)
+        return tuple(
+            [ivy.astype(svd.U, d), ivy.astype(svd.S, d), ivy.astype(svd.Vh, d)]
+        )
+    else:
+        return ivy.astype(ivy.svdvals(a), ivy.float64)


 @to_ivy_arrays_and_back
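The asymmetry in the `compute_uv=False` branch (always `float64` rather than `d`) is deliberate: singular values are real even for complex inputs. NumPy shows the same convention:

```python
import numpy as np

a = np.random.rand(3, 3) + 1j * np.random.rand(3, 3)
s = np.linalg.svd(a, compute_uv=False)
# Singular values of a complex matrix are still real-valued.
assert s.dtype == np.float64
```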
21 changes: 19 additions & 2 deletions ivy/functional/frontends/numpy/linalg/decompositions.py
@@ -1,6 +1,7 @@
 # local
 import ivy
 from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back
+from ivy.func_wrapper import with_supported_dtypes


 @to_ivy_arrays_and_back
@@ -14,6 +15,22 @@ def qr(a, mode="reduced"):


 @to_ivy_arrays_and_back
+@with_supported_dtypes(
+    {
+        "1.26.3 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "numpy",
+)
 def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
-    # Todo: conpute_uv and hermitian handling
-    return ivy.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
+    # Todo: hermitian handling
+    if compute_uv:
+        return ivy.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
+    else:
+        return ivy.astype(ivy.svdvals(a), a.dtype)
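For reference, this mirrors NumPy's own contract: `compute_uv=True` yields a (U, S, Vh) triple (a named result object on recent NumPy versions, which is what the "svd object" commit above ran into), while `compute_uv=False` yields only the singular-value array. A sketch of the behavior being matched:

```python
import numpy as np

a = np.random.rand(4, 3)
U, S, Vh = np.linalg.svd(a, full_matrices=True, compute_uv=True)
# With full_matrices=True, U is (4, 4) and Vh is (3, 3); S has min(4, 3) entries.
assert np.allclose(a, (U[:, :3] * S) @ Vh)
assert np.linalg.svd(a, compute_uv=False).shape == (3,)
```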
26 changes: 25 additions & 1 deletion ivy/functional/frontends/tensorflow/linalg.py
@@ -357,8 +357,32 @@ def solve(matrix, rhs, /, *, adjoint=False, name=None):


 @to_ivy_arrays_and_back
+@with_supported_dtypes(
+    {
+        "2.15.0 and below": (
+            "float32",
+            "float64",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "tensorflow",
+)
 def svd(a, /, *, full_matrices=False, compute_uv=True, name=None):
-    return ivy.svd(a, compute_uv=compute_uv, full_matrices=full_matrices)
+    if ivy.is_complex_dtype(a.dtype):
+        d = ivy.complex128
+    else:
+        d = ivy.float64
+    if compute_uv:
+        svd = ivy.svd(a, compute_uv=compute_uv, full_matrices=full_matrices)
+        return tuple(
+            [ivy.astype(svd.S, d), ivy.astype(svd.U, d), ivy.astype(svd.Vh.T, d)]
+        )
+    else:
+        svd = ivy.svd(a, compute_uv=compute_uv, full_matrices=full_matrices)
+        return ivy.astype(svd.S, d)


 @to_ivy_arrays_and_back
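The reordered tuple matches TensorFlow's native convention, which differs from NumPy's: singular values come first, and the third output is V rather than its conjugate transpose. A minimal check against `tf.linalg.svd` itself:

```python
import tensorflow as tf

a = tf.random.uniform((4, 3), dtype=tf.float64)
s, u, v = tf.linalg.svd(a, full_matrices=False, compute_uv=True)
# Reconstruction transposes v, confirming the third output is V, not Vh.
recon = tf.matmul(u * s, v, transpose_b=True)
assert float(tf.reduce_max(tf.abs(a - recon))) < 1e-10
```

For complex inputs the conjugate transpose would be needed in place of the plain transpose, which is worth keeping in mind for the `svd.Vh.T` in the frontend above.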
18 changes: 15 additions & 3 deletions ivy/functional/frontends/tensorflow/raw_ops.py
@@ -832,12 +832,24 @@ def Sum(*, input, axis, keep_dims=False, name="Sum"):


 @with_supported_dtypes(
-    {"2.15.0 and below": ("float64", "float128", "halfcomplex64", "complex128")},
+    {
+        "2.15.0 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
     "tensorflow",
 )
 @to_ivy_arrays_and_back
-def Svd(*, input, full_matrices=False, compute_uv=True, name=None):
-    return ivy.svd(input, compute_uv=compute_uv, full_matrices=full_matrices)
+def Svd(*, input, full_matrices=False, compute_uv=True, name="Svd"):
+    ret = ivy.svd(input, compute_uv=compute_uv, full_matrices=full_matrices)
+    if not compute_uv:
+        return (ret.S, None, None)
+    return (ret.S, ret.U, ivy.adjoint(ret.Vh))


 @to_ivy_arrays_and_back
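`tf.raw_ops.Svd` follows the same (s, u, v) output order as `tf.linalg.svd`, which is what the frontend's return tuple reproduces; the `(S, None, None)` shape of the `compute_uv=False` branch is the PR's choice rather than documented raw-op behavior. A usage sketch of the native op, assuming the outputs unpack in declaration order:

```python
import tensorflow as tf

a = tf.random.uniform((4, 3), dtype=tf.float64)
# Raw ops take keyword arguments only; outputs arrive as (s, u, v).
s, u, v = tf.raw_ops.Svd(input=a, full_matrices=False, compute_uv=True)
assert s.shape == (3,) and u.shape == (4, 3) and v.shape == (3, 3)
```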
34 changes: 29 additions & 5 deletions ivy/functional/frontends/torch/blas_and_lapack_ops.py
@@ -1,7 +1,8 @@
 # global
 import ivy
-from ivy.func_wrapper import with_unsupported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
 import ivy.functional.frontends.torch as torch_frontend
+from collections import namedtuple
 from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back


@@ -189,13 +190,36 @@ def slogdet(A, *, out=None):
     return torch_frontend.linalg.slogdet(A, out=out)


+@with_supported_dtypes(
+    {
+        "2.2 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "torch",
+)
 @to_ivy_arrays_and_back
 def svd(input, some=True, compute_uv=True, *, out=None):
-    # TODO: add compute_uv
-    if some:
-        ret = ivy.svd(input, full_matrices=False)
+    retu = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
+    results = namedtuple("svd", "U S V")
+    if compute_uv:
+        ret = results(retu[0], retu[1], ivy.adjoint(retu[2]))
     else:
-        ret = ivy.svd(input, full_matrices=True)
+        shape = list(input.shape)
+        shape1 = shape
+        shape2 = shape
+        shape1[-2] = shape[-1]
+        shape2[-1] = shape[-2]
+        ret = results(
+            ivy.zeros(shape1, device=input.device, dtype=input.dtype),
+            ivy.astype(retu[0], input.dtype),
+            ivy.zeros(shape2, device=input.device, dtype=input.dtype),
+        )
     if ivy.exists(out):
         return ivy.inplace_update(out, ret)
     return ret
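This reproduces the contract of the long-deprecated `torch.svd`: the third output is V rather than Vh, and with `compute_uv=False` the U and V slots are zero-filled instead of omitted. One caution: `shape1 = shape` and `shape2 = shape` above bind the same list object, so the two in-place edits interact and the zero-filled shapes only come out as intended for square inputs; independent copies would be needed for rectangular ones. The reference behavior, sketched against torch itself:

```python
import torch

a = torch.rand(4, 3, dtype=torch.float64)
u, s, v = torch.svd(a, some=True, compute_uv=True)
# The third output is V, so reconstruction transposes it.
assert torch.allclose(a, u @ torch.diag(s) @ v.T)

u0, s0, v0 = torch.svd(a, compute_uv=False)
# U and V come back zero-filled when compute_uv=False.
assert not u0.any() and not v0.any()
```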
25 changes: 22 additions & 3 deletions ivy/functional/frontends/torch/linalg.py
@@ -347,11 +347,30 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None):

 @to_ivy_arrays_and_back
 @with_supported_dtypes(
-    {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
+    {
+        "2.2 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "torch",
 )
 def svd(A, /, *, full_matrices=True, driver=None, out=None):
-    # TODO: add handling for driver and out
-    return ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    # TODO: add handling for driver
+    USVh = ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    if ivy.is_complex_dtype(A.dtype):
+        d = ivy.complex64
+    else:
+        d = ivy.float32
+    nt = namedtuple("svd", "U S Vh")
+    ret = nt(ivy.astype(USVh.U, d), ivy.astype(USVh.S, d), ivy.astype(USVh.Vh, d))
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret)
+    return ret


 @to_ivy_arrays_and_back
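For contrast with the `torch.svd` frontend above, `torch.linalg.svd` returns a (U, S, Vh) namedtuple with Vh already conjugate-transposed, which is the layout the ivy result is repacked into here. A quick check of the native convention:

```python
import torch

a = torch.rand(3, 3, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(a, full_matrices=True)
# Vh is already the conjugate transpose, so no extra transpose is needed.
assert torch.allclose(a, U @ torch.diag(S) @ Vh)
```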
14 changes: 13 additions & 1 deletion ivy/functional/frontends/torch/tensor.py
@@ -2109,7 +2109,19 @@ def adjoint(self):
     def conj(self):
         return torch_frontend.conj(self)

-    @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch")
+    @with_supported_dtypes(
+        {
+            "2.2 and below": (
+                "float64",
+                "float32",
+                "half",
+                "complex32",
+                "complex64",
+                "complex128",
+            )
+        },
+        "torch",
+    )
     def svd(self, some=True, compute_uv=True, *, out=None):
         return torch_frontend.svd(self, some=some, compute_uv=compute_uv, out=out)
5 changes: 1 addition & 4 deletions ivy/functional/ivy/linear_algebra.py
@@ -2130,15 +2130,12 @@ def svd(
         If ``True`` then left and right singular vectors will be computed and returned
         in ``U`` and ``Vh``, respectively. Otherwise, only the singular values will be
         computed, which can be significantly faster.
-    .. note::
-        with backend set as torch, svd with still compute left and right singular
-        vectors irrespective of the value of compute_uv, however Ivy will still
-        only return the singular values.

     Returns
     -------
     .. note::
         once complex numbers are supported, each square matrix must be Hermitian.
+        In addition, the return will be a namedtuple ``(S)`` when compute_uv is ``False``

     ret
         a namedtuple ``(U, S, Vh)`` whose
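A short sketch of the amended contract, using the NumPy backend (the backend choice is illustrative):

```python
import ivy

ivy.set_backend("numpy")
x = ivy.asarray([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
res = ivy.svd(x, compute_uv=False)
# With compute_uv=False the result is a one-field namedtuple, (S,).
print(res._fields)  # ('S',)
```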