diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index 974459f1e415..a1631e1af2f4 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -734,13 +734,15 @@ GEMM_EX(cublasSgemmEx, MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") GEMM_EX(cublasCgemmEx, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -GEMM_EX(cublasCgemm3mEx, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasCgemm3mEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") GEMM_EX(cublasGemmEx, 17) GEMM_EX(cublasSgemmEx_64, MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") GEMM_EX(cublasCgemmEx_64, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -GEMM_EX(cublasCgemm3mEx_64, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasCgemm3mEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") GEMM_EX(cublasGemmEx_64, 17) #undef GEMM_EX @@ -759,16 +761,20 @@ GEMM_EX(cublasGemmEx_64, 17) ARG(10), ARG(11), ARG(12), ARG(COMPUTE_TYPE))))) SYHERK(cublasCsyrkEx, false, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCsyrk3mEx, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCsyrk3mEx, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCherkEx, true, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCherk3mEx, true, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherk3mEx, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCsyrkEx_64, false, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCsyrk3mEx_64, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCsyrk3mEx_64, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCherkEx_64, true, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCherk3mEx_64, true, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherk3mEx_64, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") #undef SYHERK #define SYRK(NAME, TYPE, IS_COMPLEX) \ diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 4203d4ddc8c7..c5a6813d196e 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1477,14 +1477,13 @@ deduce_compute_mode(std::optional ct, math_mode mm, /// \param [in] c_type Data type of the matrix C. /// \param [in] ldc Leading dimension of C. /// \param [in] ct Compute type. -inline void -gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, - std::int64_t k, const void *alpha, const void *a, library_data_t a_type, - std::int64_t lda, const void *b, library_data_t b_type, std::int64_t ldb, - const void *beta, void *c, library_data_t c_type, std::int64_t ldc, - std::variant - ct) { +inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, + std::int64_t k, const void *alpha, const void *a, + library_data_t a_type, std::int64_t lda, const void *b, + library_data_t b_type, std::int64_t ldb, const void *beta, + void *c, library_data_t c_type, std::int64_t ldc, + std::variant ct) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -1989,13 +1988,12 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, /// \param [in] ldc Leading dimension of the matrix c. /// \param [in] ct Compute type. template -inline void syherk( - descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, - const void *alpha, const void *a, library_data_t a_type, std::int64_t lda, - const void *beta, void *c, library_data_t c_type, std::int64_t ldc, - std::variant - ct) { +inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + const void *alpha, const void *a, library_data_t a_type, + std::int64_t lda, const void *beta, void *c, + library_data_t c_type, std::int64_t ldc, + std::variant ct) { sycl::queue q = desc_ptr->get_queue(); #ifdef __INTEL_MKL__ oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index 72d91153a513..840292ad2744 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -81,16 +81,16 @@ void foo4() { void *A, *B, *C; int lda, ldb, ldc; cudaDataType_t a_type, b_type, c_type; - // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, dpct::library_data_t::complex_float); cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc); cublasFillMode_t uplo; cublasOperation_t trans; float *alpha_s, *beta_s; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); cublasCherkEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index e1498f4f2151..0be9808ce43f 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -519,7 +519,7 @@ void foo() { // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); @@ -528,9 +528,9 @@ void foo() { cublasOperation_t trans; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index cbde8ee81446..49372405d899 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -521,7 +521,7 @@ void foo() { // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); @@ -530,9 +530,9 @@ void foo() { cublasOperation_t trans; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc);