From e2a57077ad876248512d03f9ae1074c7f82f01a0 Mon Sep 17 00:00:00 2001 From: AliTarekk Date: Sun, 27 Aug 2023 03:58:09 +0000 Subject: [PATCH 1/3] Add linalg function `eigh` to PyTorch frontend --- ivy/functional/frontends/torch/linalg.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 46b8609b82012..d33d0b661d481 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -370,3 +370,9 @@ def vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None): return ivy.vector_norm( input, axis=dim, keepdims=keepdim, ord=ord, out=out, dtype=dtype ) + +@with_supported_dtypes( + {"2.0.1 and below": ("float32", "float64", "complex32", "complex64", "complex128")}, "torch", +) +def eigh(A, UPLO="L", *, out=None): + return ivy.eigh(A, UPLO=UPLO, out=out) \ No newline at end of file From 1687c9a6611076d8fbcdf58d49ae6b1764faa22a Mon Sep 17 00:00:00 2001 From: AliTarekk Date: Sun, 27 Aug 2023 04:02:58 +0000 Subject: [PATCH 2/3] Add test for `eigh` PyTorch frontend function Add test for PyTorch linear algebra frontend function `eigh` --- .../test_frontends/test_torch/test_linalg.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 496668ac8e255..aad6b049b614a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1289,3 +1289,47 @@ def test_torch_cholesky_ex( input=x, upper=upper, ) + + +@handle_frontend_test( + fn_tree="torch.linalg.eigh", + dtype_and_x=_get_dtype_and_matrix(dtype="valid", square=True, invertible=True), + UPLO=st.sampled_from(("L", "U")), +) +def test_torch_eigh( + *, + dtype_and_x, + UPLO, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + x = np.array(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + a=x, + UPLO=UPLO, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + L, Q = ret + frontend_L, frontend_Q = frontend_ret + + assert_all_close( + ret_np=Q @ np.diag(L) @ Q.T, + ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, + atol=1e-02, + ) From 07f380d7cb38833de18de7c6c3d60c25066bcd85 Mon Sep 17 00:00:00 2001 From: AliTarekk Date: Sat, 9 Sep 2023 16:22:03 +0000 Subject: [PATCH 3/3] Fix linting issues --- ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index aad6b049b614a..e8aa7a8c9b733 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1332,4 +1332,5 @@ def test_torch_eigh( ret_np=Q @ np.diag(L) @ Q.T, ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, atol=1e-02, + )