diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h
index dd905e3b33d43..a15f0d7947ec2 100644
--- a/aten/src/ATen/cuda/Exceptions.h
+++ b/aten/src/ATen/cuda/Exceptions.h
@@ -69,6 +69,13 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
 
 namespace at::cuda::solver {
 C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
+
+constexpr const char* _cusolver_backend_suggestion =            \
+    "If you keep seeing this error, you may use "               \
+    "`torch.backends.cuda.preferred_linalg_library()` to try "  \
+    "linear algebra operators with other supported backends. "  \
+    "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
+
 } // namespace at::cuda::solver
 
 // When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
@@ -85,13 +92,15 @@ C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
         "cusolver error: ",                                      \
         at::cuda::solver::cusolverGetErrorMessage(__err),        \
         ", when calling `" #EXPR "`",                            \
-        ". This error may appear if the input matrix contains NaN."); \
+        ". This error may appear if the input matrix contains NaN. ", \
+        at::cuda::solver::_cusolver_backend_suggestion);         \
   } else {                                                       \
     TORCH_CHECK(                                                 \
         __err == CUSOLVER_STATUS_SUCCESS,                        \
         "cusolver error: ",                                      \
         at::cuda::solver::cusolverGetErrorMessage(__err),        \
-        ", when calling `" #EXPR "`");                           \
+        ", when calling `" #EXPR "`. ",                          \
+        at::cuda::solver::_cusolver_backend_suggestion);         \
   }                                                              \
 } while (0)
 
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 339c653992ec3..ad8dd36a0bdf0 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -988,6 +988,26 @@ def test_eigh_errors_and_warnings(self, device, dtype):
         with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
             torch.linalg.eigh(a, out=(out_w, out_v))
 
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double)
+    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
+    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
+        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
+        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
+        # but passes on cuda 12.1 update 1 or later.
+        a = torch.ones(512, 512, dtype=dtype, device=device)
+        a[0, 0] = 1.0e-5
+        a[-1, -1] = 1.0e5
+
+        eigh_out = torch.linalg.eigh(a)
+        svd_out = torch.linalg.svd(a)
+
+        # Matrix input a is too ill-conditioned.
+        # We'll just compare the first two singular values/eigenvalues, which are 1.0e5 and 511.0.
+        # The loose tolerance of atol=1.0 makes sense, since solvers are unlikely to converge
+        # to exact values on ill-conditioned inputs.
+        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
+        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 6275a7912930e..51924318b1c21 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -647,6 +647,13 @@
 :math:`\lambda_i` through the computation of
 :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`.
 
+.. warning:: Users may see PyTorch crash when running `eigh` on CUDA devices with CUDA versions before 12.1 update 1,
+    if the inputs are large, ill-conditioned matrices.
+    Refer to :ref:`Linear Algebra Numerical Stability` for more details.
+    If this is the case, users may (1) tune their matrix inputs to be less ill-conditioned,
+    or (2) use :func:`torch.backends.cuda.preferred_linalg_library` to
+    try other supported backends.
+
 .. seealso::
 
         :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix.
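
For context, the workaround the new error message points at looks like this in user code. A minimal sketch, assuming a CUDA build with MAGMA available; the choice of `"magma"` below is just an illustration (`"default"` and `"cusolver"` are also accepted):

```python
import torch

# Query the current preference; "default" lets ATen choose a backend heuristically.
print(torch.backends.cuda.preferred_linalg_library())

# Prefer MAGMA over cuSOLVER for linear algebra operators.
torch.backends.cuda.preferred_linalg_library("magma")

# The ill-conditioned input from the new test; on CUDA versions before
# 12.1 update 1 this could crash with CUSOLVER_STATUS_EXECUTION_FAILED
# when dispatched to cuSOLVER.
a = torch.ones(512, 512, device="cuda")
a[0, 0] = 1.0e-5
a[-1, -1] = 1.0e5
w, v = torch.linalg.eigh(a)
```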