From 868f4a9e6a6b3d6cc7b2cc679af97b28aa0909be Mon Sep 17 00:00:00 2001 From: Daniel4078 <45633544+Daniel4078@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:26:49 +0800 Subject: [PATCH 1/4] try casting output of ivy.cholesky to float64 to pass the test --- ivy/functional/frontends/torch/blas_and_lapack_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py index a1c8cd9bbe77b..22796ca9a5cf0 100644 --- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py +++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py @@ -91,7 +91,7 @@ def chain_matmul(*matrices, out=None): @to_ivy_arrays_and_back def cholesky(input, upper=False, *, out=None): - return ivy.cholesky(input, upper=upper, out=out) + return ivy.cholesky(input, upper=upper, out=out).astype(ivy.float64) @to_ivy_arrays_and_back From 7cf1108618194db2465873f34fa9582c473d5d47 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Wed, 19 Jun 2024 13:31:50 +0800 Subject: [PATCH 2/4] added a condition just in case if input dtype is complex, in that case it should not be casted to float64 --- ivy/functional/frontends/torch/blas_and_lapack_ops.py | 5 ++++- ivy/functional/frontends/torch/linalg.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py index 22796ca9a5cf0..443dc9cbc69eb 100644 --- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py +++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py @@ -91,7 +91,10 @@ def chain_matmul(*matrices, out=None): @to_ivy_arrays_and_back def cholesky(input, upper=False, *, out=None): - return ivy.cholesky(input, upper=upper, out=out).astype(ivy.float64) + temp = ivy.cholesky(input, upper=upper, out=out) + if input.dtype == ivy.float32: + return temp.astype(ivy.float64) + return temp @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 3df44b0ae3fa7..0e0158aec3a56 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -12,7 +12,10 @@ {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" ) def cholesky(input, *, upper=False, out=None): - return ivy.cholesky(input, upper=upper, out=out) + temp = ivy.cholesky(input, upper=upper, out=out) + if input.dtype == ivy.float32: + return temp.astype(ivy.float64) + return temp @to_ivy_arrays_and_back From 80839791ccb07caf2ead5678a8ae72fe41514e19 Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Wed, 19 Jun 2024 20:15:43 +0800 Subject: [PATCH 3/4] change the cholesky test of blas_and_lapack_ops to be like the similar one in test_linalg to fix the problem --- ivy/functional/frontends/torch/blas_and_lapack_ops.py | 5 +---- ivy/functional/frontends/torch/linalg.py | 5 +---- .../test_frontends/test_torch/test_blas_and_lapack_ops.py | 7 +++---- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py index 443dc9cbc69eb..a1c8cd9bbe77b 100644 --- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py +++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py @@ -91,10 +91,7 @@ def chain_matmul(*matrices, out=None): @to_ivy_arrays_and_back def cholesky(input, upper=False, *, out=None): - temp = ivy.cholesky(input, upper=upper, out=out) - if input.dtype == ivy.float32: - return temp.astype(ivy.float64) - return temp + return ivy.cholesky(input, upper=upper, out=out) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 0e0158aec3a56..3df44b0ae3fa7 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -12,10 +12,7 @@ {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" ) def cholesky(input, *, upper=False, out=None): - temp = ivy.cholesky(input, upper=upper, out=out) - if input.dtype == ivy.float32: - return temp.astype(ivy.float64) - return temp + return ivy.cholesky(input, upper=upper, out=out) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py index 8e6a654d79456..f66853869c8e4 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py @@ -540,10 +540,9 @@ def test_torch_cholesky( backend_fw, ): dtype, x = dtype_and_x - x = x[0] - x = ( - np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ) # make symmetric positive-definite + x = np.asarray(x[0], dtype=dtype[0]) + x = np.matmul(np.conjugate(x.T), x) + np.identity(x.shape[0], dtype=dtype[0]) + # make symmetric positive-definite helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, From d9b8187178782a851f81f72315d2780c945be15d Mon Sep 17 00:00:00 2001 From: Jin Wang Date: Wed, 19 Jun 2024 20:20:53 +0800 Subject: [PATCH 4/4] also include the fix to the test in test_linalg.py just in case --- ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 9686bdf183982..3fe907273c191 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -319,8 +319,8 @@ def test_torch_cholesky( ): dtype, x = dtype_and_x x = np.asarray(x[0], dtype=dtype[0]) - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - + x = np.matmul(np.conjugate(x.T), x) + np.identity(x.shape[0], dtype=dtype[0]) + # make symmetric positive-definite helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw,