diff --git a/ivy/functional/backends/tensorflow/linear_algebra.py b/ivy/functional/backends/tensorflow/linear_algebra.py
index 0fcaa349572e..d1b5ed36872d 100644
--- a/ivy/functional/backends/tensorflow/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/linear_algebra.py
@@ -535,7 +535,6 @@ def svd(
 ) -> Union[Union[tf.Tensor, tf.Variable], Tuple[Union[tf.Tensor, tf.Variable], ...]]:
     if compute_uv:
         results = namedtuple("svd", "U S Vh")
-
         batch_shape = tf.shape(x)[:-2]
         num_batch_dims = len(batch_shape)
         transpose_dims = list(range(num_batch_dims)) + [
diff --git a/ivy/functional/backends/torch/linear_algebra.py b/ivy/functional/backends/torch/linear_algebra.py
index e8d960d313b3..9b708e697a26 100644
--- a/ivy/functional/backends/torch/linear_algebra.py
+++ b/ivy/functional/backends/torch/linear_algebra.py
@@ -415,15 +415,12 @@ def svd(
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
     if compute_uv:
         results = namedtuple("svd", "U S Vh")
-
         U, D, VT = torch.linalg.svd(x, full_matrices=full_matrices)
         return results(U, D, VT)
     else:
         results = namedtuple("svd", "S")
-        svd = torch.linalg.svd(x, full_matrices=full_matrices)
-        # torch.linalg.svd returns a tuple with U, S, and Vh
-        D = svd[1]
-        return results(D)
+        s = torch.linalg.svdvals(x)  # computes S only, skipping U and Vh
+        return results(s)


 @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py
index a1c8cd9bbe77..d4e68a6c8e81 100644
--- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py
+++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py
@@ -1,7 +1,8 @@
 # global
 import ivy
-from ivy.func_wrapper import with_unsupported_dtypes
+from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
 import ivy.functional.frontends.torch as torch_frontend
+from collections import namedtuple
 from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back


@@ -190,12 +191,35 @@ def slogdet(A, *, out=None):


 @to_ivy_arrays_and_back
+@with_supported_dtypes(
+    {
+        "2.2 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "torch",
+)
 def svd(input, some=True, compute_uv=True, *, out=None):
-    # TODO: add compute_uv
-    if some:
-        ret = ivy.svd(input, full_matrices=False)
+    retu = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
+    results = namedtuple("svd", "U S V")
+    if compute_uv:
+        ret = results(retu[0], retu[1], ivy.adjoint(retu[2]))
     else:
-        ret = ivy.svd(input, full_matrices=True)
+        shape = list(input.shape)
+        shape1 = list(shape)  # copy, so mutating one shape cannot alias the others
+        shape2 = list(shape)
+        shape1[-2] = shape[-1]
+        shape2[-1] = shape[-2]
+        ret = results(
+            ivy.zeros(shape1, device=input.device, dtype=input.dtype),
+            ivy.astype(retu[0], input.dtype),
+            ivy.zeros(shape2, device=input.device, dtype=input.dtype),
+        )
     if ivy.exists(out):
         return ivy.inplace_update(out, ret)
     return ret
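
Reviewer note (illustrative, not part of the patch): the `ivy.adjoint(retu[2])` above exists because legacy `torch.svd` returns `V`, while `torch.linalg.svd` and `ivy.svd` return `Vh`, the conjugate transpose of `V`. A minimal NumPy sketch of the two conventions:

```python
import numpy as np

a = np.random.rand(4, 3)
# NumPy, like torch.linalg.svd / ivy.svd, returns Vh
u, s, vh = np.linalg.svd(a, full_matrices=False)
v = vh.conj().T  # the V that legacy torch.svd would return
# both conventions reconstruct the same input matrix
assert np.allclose(a, u @ np.diag(s) @ vh)
assert np.allclose(a, u @ np.diag(s) @ v.conj().T)
```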
diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py
index 0174959d032b..ad951346831d 100644
--- a/ivy/functional/frontends/torch/linalg.py
+++ b/ivy/functional/frontends/torch/linalg.py
@@ -347,11 +347,30 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None):

 @to_ivy_arrays_and_back
 @with_supported_dtypes(
-    {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
+    {
+        "2.2 and below": (
+            "float64",
+            "float32",
+            "half",
+            "complex32",
+            "complex64",
+            "complex128",
+        )
+    },
+    "torch",
 )
 def svd(A, /, *, full_matrices=True, driver=None, out=None):
-    # TODO: add handling for driver and out
-    return ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    # TODO: add handling for driver
+    USVh = ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
+    if ivy.is_complex_dtype(A.dtype):
+        d = ivy.complex64
+    else:
+        d = ivy.float32
+    nt = namedtuple("svd", "U S Vh")
+    ret = nt(ivy.astype(USVh.U, d), ivy.astype(USVh.S, d), ivy.astype(USVh.Vh, d))
+    if ivy.exists(out):
+        return ivy.inplace_update(out, ret)
+    return ret


 @to_ivy_arrays_and_back
diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py
index ef09a2942b1d..0e74da12de75 100644
--- a/ivy/functional/frontends/torch/tensor.py
+++ b/ivy/functional/frontends/torch/tensor.py
@@ -2109,7 +2109,19 @@ def adjoint(self):
     def conj(self):
         return torch_frontend.conj(self)

-    @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch")
+    @with_supported_dtypes(
+        {
+            "2.2 and below": (
+                "float64",
+                "float32",
+                "half",
+                "complex32",
+                "complex64",
+                "complex128",
+            )
+        },
+        "torch",
+    )
     def svd(self, some=True, compute_uv=True, *, out=None):
         return torch_frontend.svd(self, some=some, compute_uv=compute_uv, out=out)

diff --git a/ivy/functional/ivy/linear_algebra.py b/ivy/functional/ivy/linear_algebra.py
index 4c61f9fdbbaa..2440918e7193 100644
--- a/ivy/functional/ivy/linear_algebra.py
+++ b/ivy/functional/ivy/linear_algebra.py
@@ -2130,15 +2130,12 @@ def svd(
         If ``True`` then left and right singular vectors will be computed and
         returned in ``U`` and ``Vh``, respectively. Otherwise, only the singular
         values will be computed, which can be significantly faster.
-    .. note::
-        with backend set as torch, svd with still compute left and right singular
-        vectors irrespective of the value of compute_uv, however Ivy will still
-        only return the singular values.

     Returns
     -------
     .. note::
         once complex numbers are supported, each square matrix must be Hermitian.
+        In addition, when ``compute_uv`` is ``False`` the return is a namedtuple ``(S,)`` holding only the singular values.

     ret
         a namedtuple ``(U, S, Vh)`` whose
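
Context for the `compute_uv=False` path (see the torch backend change above): `torch.linalg.svdvals` produces the same `S` as `torch.linalg.svd` without materialising `U` and `Vh`, which is what makes the singular-values-only path "significantly faster". A quick sanity check, assuming a local torch install (illustrative, not part of the patch):

```python
import torch

x = torch.randn(5, 3, dtype=torch.float64)
s_only = torch.linalg.svdvals(x)  # singular values only
_, s_full, _ = torch.linalg.svd(x, full_matrices=False)
assert torch.allclose(s_only, s_full)
```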
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
index f66853869c8e..2714fbfe7899 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
@@ -848,37 +848,75 @@ def test_torch_qr(
 @handle_frontend_test(
     fn_tree="torch.svd",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float", index=1),
-        min_num_dims=3,
-        max_num_dims=5,
-        min_dim_size=2,
-        max_dim_size=5,
+        available_dtypes=helpers.get_dtypes("valid"),
+        min_value=0,
+        max_value=10,
+        shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
     ),
     some=st.booleans(),
-    compute=st.booleans(),
+    compute_uv=st.booleans(),
 )
 def test_torch_svd(
     dtype_and_x,
     some,
-    compute,
+    compute_uv,
     on_device,
     fn_tree,
     frontend,
     test_flags,
     backend_fw,
 ):
-    dtype, x = dtype_and_x
-    helpers.test_frontend_function(
-        input_dtypes=dtype,
+    input_dtype, x = dtype_and_x
+    x = np.asarray(x[0], dtype=input_dtype[0])
+    # make symmetric positive definite
+    x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
+    ret, frontend_ret = helpers.test_frontend_function(
+        input_dtypes=input_dtype,
         backend_to_test=backend_fw,
         frontend=frontend,
         test_flags=test_flags,
         fn_tree=fn_tree,
         on_device=on_device,
-        input=x[0],
+        test_values=False,
+        input=x,
         some=some,
-        compute_uv=compute,
-    )
+        compute_uv=compute_uv,
+    )
+    if backend_fw == "torch":
+        frontend_ret = [x.detach() for x in frontend_ret]
+        ret = [x.detach() for x in ret]
+    ret = [np.asarray(x) for x in ret]
+    frontend_ret = [
+        np.asarray(x.resolve_conj()) for x in frontend_ret
+    ]
+    u, s, v = ret
+    frontend_u, frontend_s, frontend_v = frontend_ret
+    if not compute_uv:
+        helpers.assert_all_close(
+            ret_np=frontend_s,
+            ret_from_gt_np=s,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    elif not some:
+        helpers.assert_all_close(
+            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
+            ret_from_gt_np=u @ np.diag(s) @ v.T,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    else:
+        helpers.assert_all_close(
+            ret_np=frontend_u[..., : frontend_s.shape[0]]
+            @ np.diag(frontend_s)
+            @ frontend_v.T,
+            ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ v.T,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )


 @handle_frontend_test(
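
Why the assertions above compare reconstructions rather than raw factors: singular vectors are only unique up to a per-column sign (or complex phase), so two correct backends can disagree elementwise on `U` and `V` while agreeing on the product. Illustrative sketch (not part of the patch):

```python
import numpy as np

a = np.random.rand(4, 4)
u, s, vh = np.linalg.svd(a)
u2, vh2 = -u, -vh  # flip matching signs: still a valid SVD of a
assert not np.allclose(u, u2)
assert np.allclose(u @ np.diag(s) @ vh, u2 @ np.diag(s) @ vh2)
```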
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
index cda403aa5309..6710067a2678 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -1242,7 +1242,12 @@ def test_torch_solve_ex(
 # svd
 @handle_frontend_test(
     fn_tree="torch.linalg.svd",
-    dtype_and_x=_get_dtype_and_matrix(square=True),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        min_value=0,
+        max_value=10,
+        shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
+    ),
     full_matrices=st.booleans(),
 )
 def test_torch_svd(
@@ -1257,7 +1262,7 @@ def test_torch_svd(
 ):
     dtype, x = dtype_and_x
     x = np.asarray(x[0], dtype=dtype[0])
-    # make symmetric positive definite beforehand
+    # make symmetric positive definite
     x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
     ret, frontend_ret = helpers.test_frontend_function(
         input_dtypes=dtype,
@@ -1267,25 +1272,36 @@ def test_torch_svd(
         fn_tree=fn_tree,
         on_device=on_device,
         test_values=False,
-        atol=1e-03,
-        rtol=1e-05,
         A=x,
         full_matrices=full_matrices,
     )
-    ret = [ivy.to_numpy(x) for x in ret]
+    if backend_fw == "torch":
+        frontend_ret = [x.detach() for x in frontend_ret]
+        ret = [x.detach() for x in ret]
+    ret = [np.asarray(x) for x in ret]
     frontend_ret = [np.asarray(x) for x in frontend_ret]
-
     u, s, vh = ret
     frontend_u, frontend_s, frontend_vh = frontend_ret
-
-    assert_all_close(
-        ret_np=u @ np.diag(s) @ vh,
-        ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh,
-        rtol=1e-2,
-        atol=1e-2,
-        ground_truth_backend=frontend,
-        backend=backend_fw,
-    )
+    if full_matrices:
+        helpers.assert_all_close(
+            ret_np=(
+                frontend_u[..., : frontend_s.shape[0]]
+                @ np.diag(frontend_s)
+                @ frontend_vh
+            ),
+            ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ vh,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    else:
+        helpers.assert_all_close(
+            ret_np=(frontend_u @ np.diag(frontend_s) @ frontend_vh),
+            ret_from_gt_np=u @ np.diag(s) @ vh,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )


 # svdvals
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
index 1b4c8080695b..f7cd91bea4f8 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -13095,13 +13095,13 @@ def test_torch_sum(
         on_device=on_device,
     )

-
+# svd
 @handle_frontend_method(
     class_tree=CLASS_TREE,
     init_tree="torch.tensor",
     method_name="svd",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("valid"),
         min_value=0,
         max_value=10,
         shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
@@ -13122,7 +13122,8 @@ def test_torch_svd(
 ):
     input_dtype, x = dtype_and_x
     x = np.asarray(x[0], dtype=input_dtype[0])
-
+    # make symmetric positive-definite
+    x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
     ret, frontend_ret = helpers.test_frontend_method(
         init_input_dtypes=input_dtype,
         init_all_as_kwargs_np={
@@ -13141,28 +13142,33 @@ def test_torch_svd(
         on_device=on_device,
         test_values=False,
     )
-    with helpers.update_backend(backend_fw) as ivy_backend:
-        ret = [ivy_backend.to_numpy(x) for x in ret]
-    frontend_ret = [np.asarray(x) for x in frontend_ret]
-
-    u, s, vh = ret
-    frontend_u, frontend_s, frontend_vh = frontend_ret
-
-    if compute_uv:
+    ret = [np.asarray(x) for x in ret]
+    frontend_ret = [np.asarray(x.resolve_conj()) for x in frontend_ret]
+    u, s, v = ret
+    frontend_u, frontend_s, frontend_v = frontend_ret
+    if not compute_uv:
         helpers.assert_all_close(
-            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T,
-            ret_from_gt_np=u @ np.diag(s) @ vh,
-            rtol=1e-2,
-            atol=1e-2,
+            ret_np=frontend_s,
+            ret_from_gt_np=s,
+            atol=1e-04,
+            backend=backend_fw,
+            ground_truth_backend=frontend,
+        )
+    elif not some:
+        helpers.assert_all_close(
+            ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
+            ret_from_gt_np=u @ np.diag(s) @ v.T,
+            atol=1e-04,
             backend=backend_fw,
             ground_truth_backend=frontend,
         )
     else:
         helpers.assert_all_close(
-            ret_np=frontend_s,
-            ret_from_gt_np=s,
-            rtol=1e-2,
-            atol=1e-2,
+            ret_np=frontend_u[..., : frontend_s.shape[0]]
+            @ np.diag(frontend_s)
+            @ frontend_v.T,
+            ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ v.T,
+            atol=1e-04,
             backend=backend_fw,
             ground_truth_backend=frontend,
         )
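
A final note on the `full_matrices=True` branches: there `U` is m x m while `np.diag(s)` is k x k with k = min(m, n), so the tests slice `u[..., : s.shape[0]]` before multiplying. NumPy sketch of the shape bookkeeping (illustrative only, not part of the patch):

```python
import numpy as np

a = np.random.rand(5, 3)
u, s, vh = np.linalg.svd(a, full_matrices=True)  # u: (5, 5), s: (3,), vh: (3, 3)
recon = u[:, : s.shape[0]] @ np.diag(s) @ vh     # slice u to (5, 3) first
assert np.allclose(a, recon)
```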