From e136e615002e36d1eee284d637a5817912eafcbb Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 26 Aug 2024 11:07:47 +0800 Subject: [PATCH 01/21] WIP Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 108 ++++++++++-------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index c2da28a5f9bf..2a5d5dca9e8f 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -381,8 +381,8 @@ class working_memory { #endif template -inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx, - void *result) { +inline void nrm2_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *result) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -1292,51 +1292,6 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[], #endif } -/// Computes the Euclidean norm of a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. 
-inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, void *result, library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, result_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::nrm2_impl, float>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::nrm2_impl, double>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::nrm2_impl( - q, n, x, incx, result); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - /// Computes the dot product of two vectors. /// \param [in] q The queue where the routine should be executed. /// \param [in] n Number of elements in vector x. @@ -2280,6 +2235,50 @@ inline void trmm(descriptor_ptr desc_ptr, oneapi::mkl::side left_right, data_c, ldc DPCT_COMPUTE_MODE_ARG); } +/// Computes the Euclidean norm of a vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. 
+inline void nrm2(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *result, + library_data_t result_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type, result_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::nrm2_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::nrm2_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + detail::nrm2_impl, float>(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + detail::nrm2_impl, double>(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + detail::nrm2_impl(q, n, x, incx, result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + /// Finds the least squares solutions for a batch of overdetermined linear /// systems. Uses the QR factorization to solve a grouped batch of linear /// systems with full rank matrices. @@ -2580,6 +2579,21 @@ trmm(sycl::queue &q, oneapi::mkl::side left_right, blas::trmm(&desc, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, ldb, c, ldc); } + +/// Computes the Euclidean norm of a vector. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. 
+inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, + int incx, void *result, library_data_t result_type) { + blas::descriptor desc; + desc.set_queue(&q); + blas::nrm2(&desc, n, x, x_type, incx, result, result_type); +} } // namespace dpct #undef DPCT_COMPUTE_MODE_ARG #undef DPCT_COMPUTE_MODE_PARAM From d7fa71d1c1d457cc8109d350778e8cb45e2eaecd Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 26 Aug 2024 15:17:39 +0800 Subject: [PATCH 02/21] Update Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 78 ++- clang/lib/DPCT/APINames_cuBLAS.inc | 12 +- clang/lib/DPCT/ASTTraversal.cpp | 12 +- clang/lib/DPCT/MapNames.cpp | 92 ++- .../dpct-rt/include/dpct/blas_utils.hpp | 592 +++++++++++------- clang/test/dpct/cublas_64.cu | 23 + clang/test/dpct/cublas_64_usm.cu | 23 + 7 files changed, 533 insertions(+), 299 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index aa4e2c543e1d..7ed63ac71f67 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -122,43 +122,85 @@ GEMM(cublasZgemm_v2_64, "std::complex", true) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY("cublasNrm2Ex", - CALL(MapNames::getLibraryHelperNamespace() + "nrm2", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) + CALL_FACTORY_ENTRY( + "cublasNrm2Ex", + CALL(MapNames::getLibraryHelperNamespace() + "blas::nrm2", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasNrm2Ex_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::nrm2", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasDotEx", - CALL(MapNames::getLibraryHelperNamespace() + "dot", MEMBER_CALL(ARG(0), 
true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + "blas::dot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasDotEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::dot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasDotcEx", - CALL(MapNames::getLibraryHelperNamespace() + "dotc", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + + "blas::dotc", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasDotcEx_64", + CALL(MapNames::getLibraryHelperNamespace() + + "blas::dotc", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY("cublasScalEx", - CALL(MapNames::getLibraryHelperNamespace() + "scal", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) + CALL_FACTORY_ENTRY( + "cublasScalEx", + CALL(MapNames::getLibraryHelperNamespace() + "blas::scal", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasScalEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::scal", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasAxpyEx", - 
CALL(MapNames::getLibraryHelperNamespace() + "axpy", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + + "blas::axpy", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasAxpyEx_64", + CALL(MapNames::getLibraryHelperNamespace() + + "blas::axpy", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasRotEx", - CALL(MapNames::getLibraryHelperNamespace() + "rot", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9), ARG(10))))) + CALL(MapNames::getLibraryHelperNamespace() + "blas::rot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9), ARG(10))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasRotEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::rot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9), ARG(10))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 6a60861a5d48..312c04a7f91d 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -745,17 +745,17 @@ ENTRY(cublasSetVectorAsync_64, cublasSetVectorAsync_64, true, NO_FLAG, P4, "DPCT ENTRY(cublasGetVectorAsync_64, cublasGetVectorAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") ENTRY(cublasSetMatrixAsync_64, cublasSetMatrixAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") ENTRY(cublasGetMatrixAsync_64, cublasGetMatrixAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") 
-ENTRY(cublasNrm2Ex_64, cublasNrm2Ex_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasDotEx_64, cublasDotEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasDotcEx_64, cublasDotcEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasScalEx_64, cublasScalEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasAxpyEx_64, cublasAxpyEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasNrm2Ex_64, cublasNrm2Ex_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasDotEx_64, cublasDotEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasDotcEx_64, cublasDotcEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasScalEx_64, cublasScalEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasAxpyEx_64, cublasAxpyEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCopyEx_64, cublasCopyEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasSwapEx_64, cublasSwapEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasIamaxEx_64, cublasIamaxEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasIaminEx_64, cublasIaminEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasAsumEx_64, cublasAsumEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasRotEx_64, cublasRotEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasRotEx_64, cublasRotEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotmEx_64, cublasRotmEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasSgemvBatched_64, cublasSgemvBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasDgemvBatched_64, cublasDgemvBatched_64, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 655e9e5cf32d..204424eb0161 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4135,11 +4135,13 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasSgeqrfBatched", "cublasDgeqrfBatched", "cublasCgeqrfBatched", "cublasZgeqrfBatched", "cublasSgelsBatched", "cublasDgelsBatched", "cublasCgelsBatched", "cublasZgelsBatched", "cublasGemmEx", - "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasDotEx", - 
"cublasDotcEx", "cublasScalEx", "cublasAxpyEx", "cublasRotEx", - "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", "cublasSdgmm", - "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", - "cublasDgeam", "cublasCgeam", "cublasZgeam", + "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasNrm2Ex_64", + "cublasDotEx", "cublasDotEx_64", "cublasDotcEx", "cublasDotcEx_64", + "cublasScalEx", "cublasScalEx_64", "cublasAxpyEx", "cublasAxpyEx_64", + "cublasRotEx", "cublasRotEx_64", "cublasGemmBatchedEx", + "cublasGemmStridedBatchedEx", "cublasSdgmm", "cublasDdgmm", + "cublasCdgmm", "cublasZdgmm", "cublasSgeam", "cublasDgeam", + "cublasCgeam", "cublasZgeam", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 642d24463e20..1079f1ccca16 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -1902,12 +1902,18 @@ void MapNames::setExplicitNamespaceMap( "oneapi::mkl::blas::column_major::gemm_batch"}, {"cublasZgemmStridedBatched", "oneapi::mkl::blas::column_major::gemm_batch"}, - {"cublasNrm2Ex", getLibraryHelperNamespace() + "nrm2_ex"}, - {"cublasDotEx", getLibraryHelperNamespace() + "dot_ex"}, - {"cublasDotcEx", getLibraryHelperNamespace() + "dotc_ex"}, - {"cublasScalEx", getLibraryHelperNamespace() + "scal_ex"}, - {"cublasAxpyEx", getLibraryHelperNamespace() + "axpy_ex"}, - {"cublasRotEx", getLibraryHelperNamespace() + "rot_ex"}, + {"cublasNrm2Ex", getLibraryHelperNamespace() + "blas::nrm2"}, + {"cublasNrm2Ex_64", getLibraryHelperNamespace() + "blas::nrm2"}, + {"cublasDotEx", getLibraryHelperNamespace() + "blas::dot"}, + {"cublasDotEx_64", getLibraryHelperNamespace() + "blas::dot"}, + {"cublasDotcEx", getLibraryHelperNamespace() + "blas::dotc"}, + {"cublasDotcEx_64", getLibraryHelperNamespace() + "blas::dotc"}, + {"cublasScalEx", getLibraryHelperNamespace() + "blas::scal"}, + {"cublasScalEx_64", 
getLibraryHelperNamespace() + "blas::scal"}, + {"cublasAxpyEx", getLibraryHelperNamespace() + "blas::axpy"}, + {"cublasAxpyEx_64", getLibraryHelperNamespace() + "blas::axpy"}, + {"cublasRotEx", getLibraryHelperNamespace() + "blas::rot"}, + {"cublasRotEx_64", getLibraryHelperNamespace() + "blas::rot"}, {"cublasGemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasSgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, @@ -1940,26 +1946,46 @@ void MapNames::setExplicitNamespaceMap( {"cublasDtrmm_v2", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasCtrmm_v2", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasZtrmm_v2", getLibraryHelperNamespace() + "blas::trmm"}, - {"cublasSgetrfBatched", getLibraryHelperNamespace() + "getrf_batch_wrapper"}, - {"cublasDgetrfBatched", getLibraryHelperNamespace() + "getrf_batch_wrapper"}, - {"cublasCgetrfBatched", getLibraryHelperNamespace() + "getrf_batch_wrapper"}, - {"cublasZgetrfBatched", getLibraryHelperNamespace() + "getrf_batch_wrapper"}, - {"cublasSgetrsBatched", getLibraryHelperNamespace() + "getrs_batch_wrapper"}, - {"cublasDgetrsBatched", getLibraryHelperNamespace() + "getrs_batch_wrapper"}, - {"cublasCgetrsBatched", getLibraryHelperNamespace() + "getrs_batch_wrapper"}, - {"cublasZgetrsBatched", getLibraryHelperNamespace() + "getrs_batch_wrapper"}, - {"cublasSgetriBatched", getLibraryHelperNamespace() + "getri_batch_wrapper"}, - {"cublasDgetriBatched", getLibraryHelperNamespace() + "getri_batch_wrapper"}, - {"cublasCgetriBatched", getLibraryHelperNamespace() + "getri_batch_wrapper"}, - {"cublasZgetriBatched", getLibraryHelperNamespace() + "getri_batch_wrapper"}, - {"cublasSgeqrfBatched", getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, - {"cublasDgeqrfBatched", getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, - {"cublasCgeqrfBatched", getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, - {"cublasZgeqrfBatched", 
getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, - {"cublasSgelsBatched", getLibraryHelperNamespace() + "gels_batch_wrapper"}, - {"cublasDgelsBatched", getLibraryHelperNamespace() + "gels_batch_wrapper"}, - {"cublasCgelsBatched", getLibraryHelperNamespace() + "gels_batch_wrapper"}, - {"cublasZgelsBatched", getLibraryHelperNamespace() + "gels_batch_wrapper"}, + {"cublasSgetrfBatched", + getLibraryHelperNamespace() + "getrf_batch_wrapper"}, + {"cublasDgetrfBatched", + getLibraryHelperNamespace() + "getrf_batch_wrapper"}, + {"cublasCgetrfBatched", + getLibraryHelperNamespace() + "getrf_batch_wrapper"}, + {"cublasZgetrfBatched", + getLibraryHelperNamespace() + "getrf_batch_wrapper"}, + {"cublasSgetrsBatched", + getLibraryHelperNamespace() + "getrs_batch_wrapper"}, + {"cublasDgetrsBatched", + getLibraryHelperNamespace() + "getrs_batch_wrapper"}, + {"cublasCgetrsBatched", + getLibraryHelperNamespace() + "getrs_batch_wrapper"}, + {"cublasZgetrsBatched", + getLibraryHelperNamespace() + "getrs_batch_wrapper"}, + {"cublasSgetriBatched", + getLibraryHelperNamespace() + "getri_batch_wrapper"}, + {"cublasDgetriBatched", + getLibraryHelperNamespace() + "getri_batch_wrapper"}, + {"cublasCgetriBatched", + getLibraryHelperNamespace() + "getri_batch_wrapper"}, + {"cublasZgetriBatched", + getLibraryHelperNamespace() + "getri_batch_wrapper"}, + {"cublasSgeqrfBatched", + getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, + {"cublasDgeqrfBatched", + getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, + {"cublasCgeqrfBatched", + getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, + {"cublasZgeqrfBatched", + getLibraryHelperNamespace() + "geqrf_batch_wrapper"}, + {"cublasSgelsBatched", + getLibraryHelperNamespace() + "gels_batch_wrapper"}, + {"cublasDgelsBatched", + getLibraryHelperNamespace() + "gels_batch_wrapper"}, + {"cublasCgelsBatched", + getLibraryHelperNamespace() + "gels_batch_wrapper"}, + {"cublasZgelsBatched", + getLibraryHelperNamespace() + "gels_batch_wrapper"}, 
{"cublasGetStatusString", ""}, {"cublasCgemm3m", "oneapi::mkl::blas::column_major::gemm"}, {"cublasZgemm3m", "oneapi::mkl::blas::column_major::gemm"}, @@ -2169,12 +2195,13 @@ void MapNames::setExplicitNamespaceMap( {"cublasCtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasZtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, // cublasLt - {"cublasLtCreate", - "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, - {"cublasLtDestroy", - "delete " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, + {"cublasLtCreate", "new " + getLibraryHelperNamespace() + + "blas_gemm::experimental::descriptor"}, + {"cublasLtDestroy", "delete " + getLibraryHelperNamespace() + + "blas_gemm::experimental::descriptor"}, {"cublasLtMatmulDescCreate", - "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t"}, + "new " + getLibraryHelperNamespace() + + "blas_gemm::experimental::matmul_desc_t"}, {"cublasLtMatmulDescDestroy", "delete " + getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t"}, @@ -2216,7 +2243,8 @@ void MapNames::setExplicitNamespaceMap( getLibraryHelperNamespace() + "blas_gemm::experimental::transform_desc_t::get_attribute"}, {"cublasLtMatrixTransform", - getLibraryHelperNamespace() + "blas_gemm::experimental::matrix_transform"}, + getLibraryHelperNamespace() + + "blas_gemm::experimental::matrix_transform"}, {"cublasLtGetVersion", getLibraryHelperNamespace() + "dnnl::get_version"}, }; diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 2a5d5dca9e8f..a06a41b7c408 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -403,8 +403,9 @@ inline void nrm2_impl(sycl::queue &q, std::int64_t n, const void *x, } template -inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, - const Txy *y, int incy, Tr *result) { +inline void 
dotuc_impl(sycl::queue &q, std::int64_t n, const Txy *x, + std::int64_t incx, const Txy *y, std::int64_t incy, + Tr *result) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -431,41 +432,45 @@ inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, if constexpr (std::is_same_v> || std::is_same_v>) { if constexpr (is_conjugate) - oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, res_mem.get_ptr()); + oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, + res_mem.get_ptr()); else - oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, res_mem.get_ptr()); + oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, + res_mem.get_ptr()); } else - oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr()); + oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, + res_mem.get_ptr()); #endif #endif } template -inline void dotuc(sycl::queue &q, int n, const void *x, - library_data_t x_type, int incx, const void *y, - library_data_t y_type, int incy, void *result, - library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, y_type, result_type); +inline void dotuc(sycl::queue &q, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, const void *y, + library_data_t y_type, std::int64_t incy, void *result, + library_data_t result_type) { + std::uint64_t key = + detail::get_type_combination_id(x_type, y_type, result_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::dotuc_impl(q, n, reinterpret_cast(x), + 
incx, reinterpret_cast(y), + incy, reinterpret_cast(result)); break; } - case detail::get_type_combination_id(library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::dotuc_impl(q, n, reinterpret_cast(x), + incx, reinterpret_cast(y), + incy, reinterpret_cast(result)); break; } case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float, - library_data_t::complex_float): { + library_data_t::complex_float, + library_data_t::complex_float): { detail::dotuc_impl( q, n, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy, @@ -473,16 +478,17 @@ inline void dotuc(sycl::queue &q, int n, const void *x, break; } case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double, - library_data_t::complex_double): { + library_data_t::complex_double, + library_data_t::complex_double): { detail::dotuc_impl( q, n, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy, reinterpret_cast *>(result)); break; } - case detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half): { + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half): { detail::dotuc_impl( q, n, reinterpret_cast(x), incx, reinterpret_cast(y), incy, @@ -495,22 +501,22 @@ inline void dotuc(sycl::queue &q, int n, const void *x, } template -inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x, - int incx) { +inline void scal_impl(sycl::queue &q, std::int64_t n, const void *alpha, + void *x, std::int64_t incx) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces 
" "Project does not support this API."); #else Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); auto data_x = get_memory(x); - oneapi::mkl::blas::column_major::scal(q, n, alpha_val, - data_x, incx); + oneapi::mkl::blas::column_major::scal(q, n, alpha_val, data_x, incx); #endif } template -inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x, - int incx, void *y, int incy) { +inline void axpy_impl(sycl::queue &q, std::int64_t n, const void *alpha, + const void *x, std::int64_t incx, void *y, + std::int64_t incy) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -518,15 +524,14 @@ inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x, Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); auto data_x = get_memory(x); auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, - data_x, incx, - data_y, incy); + oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, data_x, incx, data_y, + incy); #endif } template -inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y, - int incy, const void *c, const void *s) { +inline void rot_impl(sycl::queue &q, std::int64_t n, void *x, std::int64_t incx, + void *y, std::int64_t incy, const void *c, const void *s) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -535,9 +540,8 @@ inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y, Ts s_value = dpct::get_value(reinterpret_cast(s), q); auto data_x = get_memory(x); auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, - data_y, incy, c_value, - s_value); + oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, data_y, incy, + c_value, s_value); #endif } @@ -1292,200 +1296,6 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int 
m, int n, T *a[], #endif } -/// Computes the dot product of two vectors. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dot(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the dot product of two vectors, conjugating the first vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the product of a vector by a scalar. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. 
-/// \param [in] incx Stride of vector x. -inline void scal(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, void *x, library_data_t x_type, - int incx) { - std::uint64_t key = detail::get_type_combination_id(x_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_double): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float): { - detail::scal_impl, std::complex>(q, n, alpha, - x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double): { - detail::scal_impl, std::complex>( - q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_half): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::scal_impl(q, n, &alaph_half, x, incx); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes a vector-scalar product and adds the result to a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. 
-inline void axpy(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, const void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy) { - std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::axpy_impl(q, n, &alaph_half, x, incx, y, incy); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Performs rotation of points in the plane. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [in] c Scaling factor. -/// \param [in] s Scaling factor. -/// \param [in] cs_type Data type of the scaling factors. 
-inline void rot(sycl::queue &q, int n, void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy, - const void *c, const void *s, library_data_t cs_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::rot_impl, float, float>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::rot_impl, double, double>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::rot_impl, float, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::rot_impl, double, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_bfloat16, - library_data_t::real_bfloat16): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - namespace blas { namespace detail { inline library_data_t compute_type_to_library_data_t(compute_type ct) { @@ -2279,6 +2089,214 @@ inline void nrm2(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } +/// 
Computes the dot product of two vectors.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+inline void dot(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                library_data_t x_type, std::int64_t incx, const void *y,
+                library_data_t y_type, std::int64_t incy, void *result,
+                library_data_t result_type) {
+  sycl::queue q = desc_ptr->get_queue();
+  detail::dotuc<false>(q, n, x, x_type, incx, y, y_type, incy, result,
+                       result_type);
+}
+
+/// Computes the dot product of two vectors, conjugating the first vector.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+inline void dotc(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                 library_data_t x_type, std::int64_t incx, const void *y,
+                 library_data_t y_type, std::int64_t incy, void *result,
+                 library_data_t result_type) {
+  sycl::queue q = desc_ptr->get_queue();
+  detail::dotuc<true>(q, n, x, x_type, incx, y, y_type, incy, result,
+                      result_type);
+}
+
+/// Computes the product of a vector by a scalar.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+inline void scal(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha,
+                 library_data_t alpha_type, void *x, library_data_t x_type,
+                 std::int64_t incx) {
+  sycl::queue q = desc_ptr->get_queue();
+  std::uint64_t key = detail::get_type_combination_id(x_type);
+  switch (key) {
+  case detail::get_type_combination_id(library_data_t::real_float): {
+    detail::scal_impl<float, float>(q, n, alpha, x, incx);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_double): {
+    detail::scal_impl<double, double>(q, n, alpha, x, incx);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_float): {
+    detail::scal_impl<std::complex<float>, std::complex<float>>(q, n, alpha, x,
+                                                                incx);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_double): {
+    detail::scal_impl<std::complex<double>, std::complex<double>>(q, n, alpha,
+                                                                  x, incx);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_half): {
+    float alpha_value =
+        dpct::get_value(reinterpret_cast<const float *>(alpha), q);
+    sycl::half alpha_half(alpha_value);
+    detail::scal_impl<sycl::half, sycl::half>(q, n, &alpha_half, x, incx);
+    break;
+  }
+  default:
+    throw std::runtime_error("the combination of data type is unsupported");
+  }
+}
+
+/// Computes a vector-scalar product and adds the result to a vector.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha,
+                 library_data_t alpha_type, const void *x,
+                 library_data_t x_type, std::int64_t incx, void *y,
+                 library_data_t y_type, std::int64_t incy) {
+  sycl::queue q = desc_ptr->get_queue();
+  std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type);
+  switch (key) {
+  case detail::get_type_combination_id(library_data_t::real_float,
+                                       library_data_t::real_float): {
+    detail::axpy_impl<float, float>(q, n, alpha, x, incx, y, incy);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_double,
+                                       library_data_t::real_double): {
+    detail::axpy_impl<double, double>(q, n, alpha, x, incx, y, incy);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_float,
+                                       library_data_t::complex_float): {
+    detail::axpy_impl<std::complex<float>, std::complex<float>>(q, n, alpha, x,
+                                                                incx, y, incy);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_double,
+                                       library_data_t::complex_double): {
+    detail::axpy_impl<std::complex<double>, std::complex<double>>(
+        q, n, alpha, x, incx, y, incy);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_half,
+                                       library_data_t::real_float): {
+    float alpha_value =
+        dpct::get_value(reinterpret_cast<const float *>(alpha), q);
+    sycl::half alpha_half(alpha_value);
+    detail::axpy_impl<sycl::half, sycl::half>(q, n, &alpha_half, x, incx, y,
+                                              incy);
+    break;
+  }
+  default:
+    throw std::runtime_error("the combination of data type is unsupported");
+  }
+}
+
+/// Performs rotation of points in the plane.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [in] c Scaling factor.
+/// \param [in] s Scaling factor.
+/// \param [in] cs_type Data type of the scaling factors.
+inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x,
+                library_data_t x_type, std::int64_t incx, void *y,
+                library_data_t y_type, std::int64_t incy, const void *c,
+                const void *s, library_data_t cs_type) {
+  sycl::queue q = desc_ptr->get_queue();
+  std::uint64_t key = detail::get_type_combination_id(x_type, cs_type);
+  switch (key) {
+  case detail::get_type_combination_id(library_data_t::real_float,
+                                       library_data_t::real_float): {
+    detail::rot_impl<float, float, float>(q, n, x, incx, y, incy, c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_double,
+                                       library_data_t::real_double): {
+    detail::rot_impl<double, double, double>(q, n, x, incx, y, incy, c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_float,
+                                       library_data_t::real_float): {
+    detail::rot_impl<std::complex<float>, float, float>(q, n, x, incx, y, incy,
+                                                        c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_double,
+                                       library_data_t::real_double): {
+    detail::rot_impl<std::complex<double>, double, double>(q, n, x, incx, y,
+                                                           incy, c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_float,
+                                       library_data_t::complex_float): {
+    detail::rot_impl<std::complex<float>, float, std::complex<float>>(
+        q, n, x, incx, y, incy, c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::complex_double,
+                                       library_data_t::complex_double): {
+    detail::rot_impl<std::complex<double>, double, std::complex<double>>(
+        q, n, x, incx, y, incy, c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_half,
+                                       library_data_t::real_half): {
+    detail::rot_impl<sycl::half, sycl::half, sycl::half>(q, n, x, incx, y, incy,
+                                                         c, s);
+    break;
+  }
+  case detail::get_type_combination_id(library_data_t::real_bfloat16,
+                                       library_data_t::real_bfloat16): {
+    detail::rot_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
+                     oneapi::mkl::bfloat16>(q, n, x, incx, y, incy, c, s);
+    break;
+  }
+  default:
+    throw std::runtime_error("the combination of data type is unsupported");
+  }
+}
+
 /// Finds the least squares solutions for a batch of overdetermined linear
 /// systems.
Uses the QR factorization to solve a grouped batch of linear
 /// systems with full rank matrices.
@@ -2588,12 +2606,110 @@ trmm(sycl::queue &q, oneapi::mkl::side left_right,
 /// \param [in] incx Stride of vector x.
 /// \param [out] result The result scalar.
 /// \param [in] result_type Data type of the result.
-inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type,
-                 int incx, void *result, library_data_t result_type) {
+[[deprecated("Please use dpct::blas::nrm2(...) instead.")]] inline void
+nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+     void *result, library_data_t result_type) {
   blas::descriptor desc;
   desc.set_queue(&q);
   blas::nrm2(&desc, n, x, x_type, incx, result, result_type);
 }
+
+/// Computes the dot product of two vectors.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+[[deprecated("Please use dpct::blas::dot(...) instead.")]] inline void
+dot(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+    const void *y, library_data_t y_type, int incy, void *result,
+    library_data_t result_type) {
+  blas::descriptor desc;
+  desc.set_queue(&q);
+  blas::dot(&desc, n, x, x_type, incx, y, y_type, incy, result, result_type);
+}
+
+/// Computes the dot product of two vectors, conjugating the first vector.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+[[deprecated("Please use dpct::blas::dotc(...) instead.")]] inline void
+dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+     const void *y, library_data_t y_type, int incy, void *result,
+     library_data_t result_type) {
+  blas::descriptor desc;
+  desc.set_queue(&q);
+  blas::dotc(&desc, n, x, x_type, incx, y, y_type, incy, result, result_type);
+}
+
+/// Computes the product of a vector by a scalar.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+[[deprecated("Please use dpct::blas::scal(...) instead.")]] inline void
+scal(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type,
+     void *x, library_data_t x_type, int incx) {
+  blas::descriptor desc;
+  desc.set_queue(&q);
+  blas::scal(&desc, n, alpha, alpha_type, x, x_type, incx);
+}
+
+/// Computes a vector-scalar product and adds the result to a vector.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+[[deprecated("Please use dpct::blas::axpy(...) instead.")]] inline void
+axpy(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type,
+     const void *x, library_data_t x_type, int incx, void *y,
+     library_data_t y_type, int incy) {
+  blas::descriptor desc;
+  desc.set_queue(&q);
+  blas::axpy(&desc, n, alpha, alpha_type, x, x_type, incx, y, y_type, incy);
+}
+
+/// Performs rotation of points in the plane.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [in] c Scaling factor.
+/// \param [in] s Scaling factor.
+/// \param [in] cs_type Data type of the scaling factors.
+[[deprecated("Please use dpct::blas::rot(...) instead.")]] inline void
+rot(sycl::queue &q, int n, void *x, library_data_t x_type, int incx, void *y,
+    library_data_t y_type, int incy, const void *c, const void *s,
+    library_data_t cs_type) {
+  blas::descriptor desc;
+  desc.set_queue(&q);
+  blas::rot(&desc, n, x, x_type, incx, y, y_type, incy, c, s, cs_type);
+}
 } // namespace dpct
 #undef DPCT_COMPUTE_MODE_ARG
 #undef DPCT_COMPUTE_MODE_PARAM
diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu
index 3684143426a0..adbc7ed7207f 100644
--- a/clang/test/dpct/cublas_64.cu
+++ b/clang/test/dpct/cublas_64.cu
@@ -474,4 +474,27 @@ void foo() {
   // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::herk(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc));
   status = cublasCherkx_64(handle, uplo, transa, n, k, alpha_c, A_c, lda, B_c, ldb, beta_s, C_c, ldc);
   status = cublasZherkx_64(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc);
+
+  cudaDataType type_x;
+  cudaDataType type_y;
+  cudaDataType type_res;
+  cudaDataType type_exec;
+
cudaDataType type_alpha; + cudaDataType type_cs; + void *res; + void *x; + void *y; + void *alpha; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::nrm2(handle, n, x, type_x, incx, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dot(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dotc(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::scal(handle, n, alpha, type_alpha, x, type_x, incx)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::axpy(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::rot(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs)); + status = cublasNrm2Ex_64(handle, n, x, type_x, incx, res, type_res, type_exec); + status = cublasDotEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasDotcEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec); + status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec); + status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec); } diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index c73ecb197bad..bd517c63877f 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -474,4 +474,27 @@ void foo() { // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::herk(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc)); status = cublasCherkx_64(handle, uplo, transa, n, k, alpha_c, A_c, lda, B_c, ldb, beta_s, C_c, ldc); status = cublasZherkx_64(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc); + + cudaDataType type_x; + cudaDataType type_y; + cudaDataType 
type_res; + cudaDataType type_exec; + cudaDataType type_alpha; + cudaDataType type_cs; + void *res; + void *x; + void *y; + void *alpha; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::nrm2(handle, n, x, type_x, incx, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dot(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dotc(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::scal(handle, n, alpha, type_alpha, x, type_x, incx)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::axpy(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::rot(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs)); + status = cublasNrm2Ex_64(handle, n, x, type_x, incx, res, type_res, type_exec); + status = cublasDotEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasDotcEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec); + status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec); + status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec); } From e51b181fa1e7fb3df9ce2aeb9bc03b733f859951 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Tue, 27 Aug 2024 08:33:34 +0800 Subject: [PATCH 03/21] Support 8 cublas 64-bit API migration Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 27 +++++++++++++++++++ clang/lib/DPCT/APINames_cuBLAS.inc | 4 +-- clang/lib/DPCT/ASTTraversal.cpp | 6 +++-- clang/lib/DPCT/MapNames.cpp | 2 ++ clang/test/dpct/cublas-usm-11.cu | 12 ++++----- clang/test/dpct/cublas_64.cu | 18 +++++++++++++ clang/test/dpct/cublas_64_usm.cu | 20 ++++++++++++++ .../query_api_mapping/cuBLAS/blas_10_1.cu 
| 2 +- .../query_api_mapping/cuBLAS/blas_part1.cu | 2 +- .../query_api_mapping/cuBLAS/blas_part2.cu | 2 +- .../query_api_mapping/cuBLAS/blas_part3.cu | 2 +- .../query_api_mapping/cuBLAS/blas_part4.cu | 2 +- .../query_api_mapping/cuBLAS/blas_part8.cu | 2 +- 13 files changed, 85 insertions(+), 16 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index a9155b497a57..b4c9820180bc 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -234,6 +234,23 @@ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(14), BOOL(false), BOOL(false)), ARG(15), ARG(16), ARG(17), ARG(18))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasGemmBatchedEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm_batch", ARG(0), + BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + ARG(3), ARG(4), ARG(5), ARG(6), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(7), BOOL(true), + BOOL(false)), + ARG(8), ARG(9), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(10), BOOL(true), + BOOL(false)), + ARG(11), ARG(12), ARG(13), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(14), + BOOL(false), BOOL(false)), + ARG(15), ARG(16), ARG(17), ARG(18))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -245,6 +262,16 @@ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18), ARG(19), ARG(20), ARG(21))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasGemmStridedBatchedEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm_batch", ARG(0), + BLAS_ENUM_ARG(1, 
clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), + ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), + ARG(18), ARG(19), ARG(20), ARG(21))))) #define SYRKX(FUNC) \ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 312c04a7f91d..98e8fbae6b11 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -802,8 +802,8 @@ ENTRY(cublasDgemmStridedBatched_64, cublasDgemmStridedBatched_64, false, NO_FLAG ENTRY(cublasCgemmStridedBatched_64, cublasCgemmStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemm3mStridedBatched_64, cublasCgemm3mStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasZgemmStridedBatched_64, cublasZgemmStridedBatched_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmBatchedEx_64, cublasGemmBatchedEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmStridedBatchedEx_64, cublasGemmStridedBatchedEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasGemmBatchedEx_64, cublasGemmBatchedEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasGemmStridedBatchedEx_64, cublasGemmStridedBatchedEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSgemmGroupedBatched, cublasSgemmGroupedBatched, false, NO_FLAG, P4, "comment") ENTRY(cublasSgemmGroupedBatched_64, cublasSgemmGroupedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasDgemmGroupedBatched, cublasDgemmGroupedBatched, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 204424eb0161..5b5f9cb4326e 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4139,7 +4139,8 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasDotEx", "cublasDotEx_64", "cublasDotcEx", "cublasDotcEx_64", "cublasScalEx", "cublasScalEx_64", 
"cublasAxpyEx", "cublasAxpyEx_64", "cublasRotEx", "cublasRotEx_64", "cublasGemmBatchedEx", - "cublasGemmStridedBatchedEx", "cublasSdgmm", "cublasDdgmm", + "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx", + "cublasGemmStridedBatchedEx_64", "cublasSdgmm", "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", "cublasDgeam", "cublasCgeam", "cublasZgeam", /*Legacy API*/ @@ -4404,7 +4405,8 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) { FuncName == "cublasDgemmBatched" || FuncName == "cublasCgemmBatched" || FuncName == "cublasZgemmBatched" || FuncName == "cublasStrsmBatched" || FuncName == "cublasDtrsmBatched" || FuncName == "cublasCtrsmBatched" || - FuncName == "cublasZtrsmBatched" || FuncName == "cublasGemmBatchedEx")) { + FuncName == "cublasZtrsmBatched" || FuncName == "cublasGemmBatchedEx" || + FuncName == "cublasGemmBatchedEx_64")) { report(FuncNameBegin, Diagnostics::API_NOT_MIGRATED, false, FuncName); return; } diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 1079f1ccca16..ec3f669c02ea 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -1918,7 +1918,9 @@ void MapNames::setExplicitNamespaceMap( {"cublasSgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasGemmBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmBatchedEx_64", getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasGemmStridedBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmStridedBatchedEx_64", getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasSsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasDsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasCsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index 11522b38c4f3..75c92f358e11 100644 --- 
a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -13,12 +13,12 @@ void foo1() { const void **a_array; const void **b_array; void **c_array; - //CHECK:dpct::nrm2(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); - //CHECK-NEXT:dpct::dot(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); - //CHECK-NEXT:dpct::dotc(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); - //CHECK-NEXT:dpct::scal(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1); - //CHECK-NEXT:dpct::axpy(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1); - //CHECK-NEXT:dpct::rot(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, cos, sin, dpct::library_data_t::real_float); + //CHECK:dpct::blas::nrm2(handle, 4, x, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::dot(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::dotc(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::scal(handle, 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1); + //CHECK-NEXT:dpct::blas::axpy(handle, 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1); + //CHECK-NEXT:dpct::blas::rot(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, cos, sin, 
dpct::library_data_t::real_float);
   //CHECK-NEXT:dpct::blas::gemm(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, b, dpct::library_data_t::real_half, 4, beta, c, dpct::library_data_t::real_half, 4, dpct::compute_type::f16);
   //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a_array, dpct::library_data_t::real_half, 4, b_array, dpct::library_data_t::real_half, 4, beta, c_array, dpct::library_data_t::real_half, 4, 2, dpct::compute_type::f16);
   //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, 16, b, dpct::library_data_t::real_half, 4, 16, beta, c, dpct::library_data_t::real_half, 4, 16, 2, dpct::compute_type::f16);
diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu
index 334d5f7067be..5f08020b6678 100644
--- a/clang/test/dpct/cublas_64.cu
+++ b/clang/test/dpct/cublas_64.cu
@@ -497,4 +497,22 @@ void foo() {
   status = cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec);
   status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec);
   status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec);
+
+  void **a_array;
+  void **b_array;
+  void **c_array;
+  void *aa;
+  void *bb;
+  void *cc;
+  cudaDataType type_a;
+  cudaDataType type_b;
+  cudaDataType type_c;
+  void *beta;
+  cublasGemmAlgo_t algo;
+  int64_t batch;
+  int64_t stride_a;
+  int64_t stride_b;
+  int64_t stride_c;
+  // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec));
+  status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo);
 }
diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu
index d496101deca5..57c4c5b1c280 100644
--- a/clang/test/dpct/cublas_64_usm.cu
+++ b/clang/test/dpct/cublas_64_usm.cu
@@ -497,4 +497,24 @@ void foo() {
   status = cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec);
   status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec);
   status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec);
+
+  void **a_array;
+  void **b_array;
+  void **c_array;
+  void *aa;
+  void *bb;
+  void *cc;
+  cudaDataType type_a;
+  cudaDataType type_b;
+  cudaDataType type_c;
+  void *beta;
+  cublasGemmAlgo_t algo;
+  int64_t batch;
+  int64_t stride_a;
+  int64_t stride_b;
+  int64_t stride_c;
+  // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, const_cast(a_array), type_a, lda, const_cast(b_array), type_b, ldb, beta, c_array, type_c, ldc, batch, type_exec));
+  // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec));
+  status = cublasGemmBatchedEx_64(handle, transa, transb, m, n, k, alpha, a_array, type_a, lda, b_array, type_b, ldb, beta, c_array, type_c, ldc, batch, type_exec, algo);
+  status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo);
 }
diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu
index 487daa3ebb1d..9591816f9b9d 100644
--- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu
+++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu
@@ -9,4 +9,4 @@
 // cublasRotEx-NEXT: s /*const void **/,
cstype /*cudaDataType*/, // cublasRotEx-NEXT: computetype /*cudaDataType*/); // cublasRotEx-NEXT: Is migrated to: -// cublasRotEx-NEXT: dpct::rot(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, c, s, cstype); +// cublasRotEx-NEXT: dpct::blas::rot(handle, n, x, xtype, incx, y, ytype, incy, c, s, cstype); diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu index 6d7b54ae38e4..1054027cc400 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu @@ -86,7 +86,7 @@ // cublasAxpyEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, // cublasAxpyEx-NEXT: computetype /*cudaDataType*/); // cublasAxpyEx-NEXT: Is migrated to: -// cublasAxpyEx-NEXT: dpct::axpy(handle->get_queue(), n, alpha, alphatype, x, xtype, incx, y, ytype, incy); +// cublasAxpyEx-NEXT: dpct::blas::axpy(handle, n, alpha, alphatype, x, xtype, incx, y, ytype, incy); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDtpmv | FileCheck %s -check-prefix=cublasDtpmv // cublasDtpmv: CUDA API: diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu index 85bbb574d3ab..b646ba45df10 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu @@ -5,7 +5,7 @@ // cublasDotEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, res /*void **/, // cublasDotEx-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasDotEx-NEXT: Is migrated to: -// cublasDotEx-NEXT: dpct::dot(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, res, restype); +// cublasDotEx-NEXT: dpct::blas::dot(handle, n, x, xtype, incx, y, ytype, incy, res, restype); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDtbmv | FileCheck %s -check-prefix=cublasDtbmv // cublasDtbmv: CUDA API: diff --git 
a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu index bbe98f8ce5f2..21f1ee6ac93a 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu @@ -15,7 +15,7 @@ // cublasDotcEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, res /*void **/, // cublasDotcEx-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasDotcEx-NEXT: Is migrated to: -// cublasDotcEx-NEXT: dpct::dotc(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, res, restype); +// cublasDotcEx-NEXT: dpct::blas::dotc(handle, n, x, xtype, incx, y, ytype, incy, res, restype); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasSsyr2 | FileCheck %s -check-prefix=cublasSsyr2 // cublasSsyr2: CUDA API: diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu index 6bad86f59749..5fdd590d8d79 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu @@ -236,4 +236,4 @@ // cublasNrm2Ex-NEXT: xtype /*cudaDataType*/, incx /*int*/, res /*void **/, // cublasNrm2Ex-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasNrm2Ex-NEXT: Is migrated to: -// cublasNrm2Ex-NEXT: dpct::nrm2(handle->get_queue(), n, x, xtype, incx, res, restype); +// cublasNrm2Ex-NEXT: dpct::blas::nrm2(handle, n, x, xtype, incx, res, restype); diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu index 04027172cb09..ddf0daaffbc6 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu @@ -16,7 +16,7 @@ // cublasScalEx-NEXT: alphatype /*cudaDataType*/, x /*void **/, xtype /*cudaDataType*/, // cublasScalEx-NEXT: incx /*int*/, computetype /*cudaDataType*/); // cublasScalEx-NEXT: Is 
migrated to: -// cublasScalEx-NEXT: dpct::scal(handle->get_queue(), n, alpha, alphatype, x, xtype, incx); +// cublasScalEx-NEXT: dpct::blas::scal(handle, n, alpha, alphatype, x, xtype, incx); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDger | FileCheck %s -check-prefix=cublasDger // cublasDger: CUDA API: From 6ea8f54f62ddce027491aacdeba1daab358b4e1f Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Thu, 29 Aug 2024 15:03:06 +0800 Subject: [PATCH 04/21] Add more support Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 46 ++++ clang/lib/DPCT/APINames_cuBLAS.inc | 24 +- clang/lib/DPCT/ASTTraversal.cpp | 5 +- clang/lib/DPCT/MapNames.cpp | 23 +- .../dpct-rt/include/dpct/blas_utils.hpp | 242 ++++++++++++++++++ 5 files changed, 324 insertions(+), 16 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index b4c9820180bc..be89f1eaa64b 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -1674,6 +1674,52 @@ GEMM_BATCH(cublasCgemmStridedBatched, "std::complex", true) GEMM_BATCH(cublasZgemmStridedBatched, "std::complex", true) #undef GEMM_BATCH +#define COPY_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::copy", ARG(0), \ + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7)))) +COPY_EX(cublasCopyEx) +COPY_EX(cublasCopyEx_64) +#undef COPY_EX + +#define SWAP_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::swap", ARG(0), \ + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7)))) +SWAP_EX(cublasSwapEx) +SWAP_EX(cublasSwapEx_64) +#undef SWAP_EX + +#define ASUM_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::asum", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6)))) +ASUM_EX(cublasAsumEx) +ASUM_EX(cublasAsumEx_64) +#undef 
ASUM_EX + +#define ROTM_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::rotm", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), \ + ARG(7), ARG(8), ARG(9)))) +ROTM_EX(cublasRotmEx) +ROTM_EX(cublasRotmEx_64) +#undef ROTM_EX + +#define IAMAXMIN_EX(NAME, IS_MAX) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamaxmin<" + \ + #IS_MAX + ">", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)))) +IAMAXMIN_EX(cublasIamaxEx, true) +IAMAXMIN_EX(cublasIamaxEx_64, true) +IAMAXMIN_EX(cublasIaminEx, false) +IAMAXMIN_EX(cublasIaminEx_64, false) +#undef IAMAXMIN_EX + ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY( "cublasLtCreate", DEREF(0), NEW(MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"))) diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 98e8fbae6b11..2873c23b589e 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -525,19 +525,19 @@ ENTRY(cublasDtrmm, cublasDtrmm, true, NO_FLAG, P4, "Successful") ENTRY(cublasCtrmm, cublasCtrmm, true, NO_FLAG, P4, "Successful") ENTRY(cublasZtrmm, cublasZtrmm, true, NO_FLAG, P4, "Successful") -ENTRY(cublasAsumEx, cublasAsumEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasAsumEx, cublasAsumEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCgemm3mBatched, cublasCgemm3mBatched, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemm3mEx, cublasCgemm3mEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCopyEx, cublasCopyEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasCopyEx, cublasCopyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotcEx, cublasDotcEx, true, NO_FLAG, P4, "Successful") ENTRY(cublasGetCudartVersion, cublasGetCudartVersion, false, NO_FLAG, P4, "comment") -ENTRY(cublasIamaxEx, cublasIamaxEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasIaminEx, cublasIaminEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasIamaxEx, 
cublasIamaxEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIaminEx, cublasIaminEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotEx, cublasRotEx, true, NO_FLAG, P4, "Successful") ENTRY(cublasRotgEx, cublasRotgEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasRotmEx, cublasRotmEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasRotmEx, cublasRotmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotmgEx, cublasRotmgEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasSwapEx, cublasSwapEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasSwapEx, cublasSwapEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasUint8gemmBias, cublasUint8gemmBias, false, NO_FLAG, P4, "comment") ENTRY(cublasXerbla, cublasXerbla, false, NO_FLAG, P4, "comment") ENTRY(cublasXtGetNumBoards, cublasXtGetNumBoards, false, NO_FLAG, P4, "comment") @@ -750,13 +750,13 @@ ENTRY(cublasDotEx_64, cublasDotEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotcEx_64, cublasDotcEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasScalEx_64, cublasScalEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasAxpyEx_64, cublasAxpyEx_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCopyEx_64, cublasCopyEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasSwapEx_64, cublasSwapEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasIamaxEx_64, cublasIamaxEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasIaminEx_64, cublasIaminEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasAsumEx_64, cublasAsumEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasCopyEx_64, cublasCopyEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasSwapEx_64, cublasSwapEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIamaxEx_64, cublasIamaxEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIaminEx_64, cublasIaminEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasAsumEx_64, cublasAsumEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotEx_64, cublasRotEx_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasRotmEx_64, cublasRotmEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasRotmEx_64, 
cublasRotmEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSgemvBatched_64, cublasSgemvBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasDgemvBatched_64, cublasDgemvBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemvBatched_64, cublasCgemvBatched_64, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 89c8800214cc..a6c12481fa61 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4175,7 +4175,10 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx", "cublasGemmStridedBatchedEx_64", "cublasSdgmm", "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", "cublasDgeam", - "cublasCgeam", "cublasZgeam", + "cublasCgeam", "cublasZgeam", "cublasCopyEx", "cublasSwapEx", + "cublasIamaxEx", "cublasIaminEx", "cublasAsumEx", "cublasRotmEx", + "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", + "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 10ddfbc544da..2ce96af5f730 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2171,9 +2171,12 @@ void MapNames::setExplicitNamespaceMap( {"cublasSgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasGemmBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, - {"cublasGemmBatchedEx_64", getLibraryHelperNamespace() + "blas::gemm_batch"}, - {"cublasGemmStridedBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, - {"cublasGemmStridedBatchedEx_64", getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmBatchedEx_64", + getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmStridedBatchedEx", + getLibraryHelperNamespace() + 
"blas::gemm_batch"}, + {"cublasGemmStridedBatchedEx_64", + getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasSsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasDsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasCsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, @@ -2449,6 +2452,20 @@ void MapNames::setExplicitNamespaceMap( {"cublasDtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasCtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasZtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, + {"cublasCopyEx", getLibraryHelperNamespace() + "blas::copy"}, + {"cublasSwapEx", getLibraryHelperNamespace() + "blas::swap"}, + {"cublasIamaxEx", getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasIaminEx", getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasAsumEx", getLibraryHelperNamespace() + "blas::asum"}, + {"cublasRotmEx", getLibraryHelperNamespace() + "blas::rotm"}, + {"cublasCopyEx_64", getLibraryHelperNamespace() + "blas::copy"}, + {"cublasSwapEx_64", getLibraryHelperNamespace() + "blas::swap"}, + {"cublasIamaxEx_64", + getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasIaminEx_64", + getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasAsumEx_64", getLibraryHelperNamespace() + "blas::asum"}, + {"cublasRotmEx_64", getLibraryHelperNamespace() + "blas::rotm"}, // cublasLt {"cublasLtCreate", "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index eaac27f845c8..48da9c3ad96a 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -548,6 +548,91 @@ inline void rot_impl(sycl::queue &q, std::int64_t n, void *x, std::int64_t incx, #endif } +template +inline void rotm_impl(sycl::queue &q, std::int64_t n, void *x, int64_t incx, + void *y, int64_t incy, const void *param) { 
+#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_y = get_memory(y); + auto data_param = get_memory(param); + oneapi::mkl::blas::column_major::rotm(q, n, data_x, incx, data_y, incy, + data_param); +#endif +} + +template +inline void copy_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *y, std::int64_t incy) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::copy(q, n, data_x, incx, data_y, incy); +#endif +} + +template +inline void swap_impl(sycl::queue &q, std::int64_t n, void *x, + std::int64_t incx, void *y, std::int64_t incy) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::swap(q, n, data_x, incx, data_y, incy); +#endif +} + +template +inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *res) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_res = get_memory(res); + oneapi::mkl::blas::column_major::asum(q, n, data_x, incx, data_res); +#endif +} + +template +inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *res) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_res = get_memory(res); + 
oneapi::mkl::blas::column_major::asum(q, n, data_x, incx, data_res); +#endif +} + +template +inline void iamaxmin_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, std::int64_t *res) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_res = get_memory(res); + if constexpr (is_max) + oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); + else + oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); +#endif +} + #ifdef __INTEL_MKL__ template inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, @@ -2300,6 +2385,163 @@ inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, } } +inline void rotm(descriptor_ptr desc_ptr, std::int64_t n, void *x, + library_data_t x_type, int64_t incx, void *y, + library_data_t y_type, int64_t incy, const void *param, + library_data_t param_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::rotm_impl(q, n, x, incx, y, incy, param); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::rotm_impl(q, n, x, incx, y, incy, param); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +inline void copy(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type, y_type); + switch (key) { + 
case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::copy_impl(q, n, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::copy_impl(q, n, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + detail::copy_impl, std::complex>(q, n, x, incx, + y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double): { + detail::copy_impl, std::complex>(q, n, x, incx, + y, incy); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +inline void swap(descriptor_ptr desc_ptr, std::int64_t n, void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type, y_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::swap_impl(q, n, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::swap_impl(q, n, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + detail::swap_impl, std::complex>(q, n, x, incx, + y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double): { + detail::swap_impl, std::complex>(q, n, x, incx, + y, incy); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *result, + library_data_t 
result_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type, result_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::asum_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::asum_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + detail::asum_impl, std::complex>(q, n, x, incx, + result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double): { + detail::asum_impl, std::complex>(q, n, x, incx, + result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +template +inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, std::int64_t *result) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = detail::get_type_combination_id(x_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float): { + detail::iamaxmin_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_double): { + detail::iamaxmin_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float): { + detail::iamaxmin_impl>(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double): { + detail::iamaxmin_impl>(q, n, x, incx, result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +template +inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, int *result) { + 
dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result); + iamaxmin(desc_ptr, n, x, x_type, incx, wrapper.get()); +} + /// Finds the least squares solutions for a batch of overdetermined linear /// systems. Uses the QR factorization to solve a grouped batch of linear /// systems with full rank matrices. From 66f38e8e093eeddaec0c1bd0731d0e65d4087624 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Thu, 29 Aug 2024 16:22:53 +0800 Subject: [PATCH 05/21] Add tests Signed-off-by: Jiang, Zhiwei --- clang/test/dpct/cublas-usm-11.cu | 22 ++++++++++++++++++++++ clang/test/dpct/cublas_64.cu | 22 ++++++++++++++++++++++ clang/test/dpct/cublas_64_usm.cu | 22 ++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index 75c92f358e11..e0c4c0fbc068 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -53,3 +53,25 @@ void foo3() { cublasGetMathMode(handle, &Mathmode); cublasSetMathMode(handle, Mathmode); } + +void foo4() { + cublasHandle_t handle; + int n; + void *x, *y; + int incx, incy; + void *res; + int *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + 
cublasSwapEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasIamaxEx(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx(handle, n, x, CUDA_R_32F, incx, idx); + cublasAsumEx(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); +} diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index 5f08020b6678..edd52f0fe4da 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -516,3 +516,25 @@ void foo() { // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec)); status = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo); } + +void foo2() { + cublasHandle_t handle; + int n; + void *x, *y; + int incx, incy; + void *res; + int64_t *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasSwapEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + 
cublasIamaxEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasAsumEx_64(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); +} diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index 57c4c5b1c280..820e928b6bf6 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -518,3 +518,25 @@ void foo() { status = cublasGemmBatchedEx(handle, transa, transb, m, n, k, alpha, a_array, type_a, lda, b_array, type_b, ldb, beta, c_array, type_c, ldc, batch, type_exec, algo); status = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo); } + +void foo2() { + cublasHandle_t handle; + int n; + void *x, *y; + int incx, incy; + void *res; + int64_t *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasSwapEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasIamaxEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx_64(handle, 
n, x, CUDA_R_32F, incx, idx); + cublasAsumEx_64(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); +} From 6d8c25e590855ae26d3d71030cc5bd11ed566b75 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 30 Aug 2024 10:06:53 +0800 Subject: [PATCH 06/21] Adjust the code Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/ASTTraversal.cpp | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index a6c12481fa61..7d53e6499bd2 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4168,17 +4168,13 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasSgeqrfBatched", "cublasDgeqrfBatched", "cublasCgeqrfBatched", "cublasZgeqrfBatched", "cublasSgelsBatched", "cublasDgelsBatched", "cublasCgelsBatched", "cublasZgelsBatched", "cublasGemmEx", - "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasNrm2Ex_64", - "cublasDotEx", "cublasDotEx_64", "cublasDotcEx", "cublasDotcEx_64", - "cublasScalEx", "cublasScalEx_64", "cublasAxpyEx", "cublasAxpyEx_64", - "cublasRotEx", "cublasRotEx_64", "cublasGemmBatchedEx", - "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx", - "cublasGemmStridedBatchedEx_64", "cublasSdgmm", "cublasDdgmm", - "cublasCdgmm", "cublasZdgmm", "cublasSgeam", "cublasDgeam", - "cublasCgeam", "cublasZgeam", "cublasCopyEx", "cublasSwapEx", - "cublasIamaxEx", "cublasIaminEx", "cublasAsumEx", "cublasRotmEx", - "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", - "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", + "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasDotEx", + "cublasDotcEx", "cublasScalEx", "cublasAxpyEx", "cublasRotEx", + "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", "cublasSdgmm", + "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", + "cublasDgeam", 
"cublasCgeam", "cublasZgeam", "cublasCopyEx", + "cublasSwapEx", "cublasIamaxEx", "cublasIaminEx", "cublasAsumEx", + "cublasRotmEx", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", @@ -4273,6 +4269,12 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasCsyrkx_64", "cublasZsyrkx_64", "cublasCherkx_64", "cublasZherkx_64", "cublasHgemm_64", "cublasCgemm3m_64", "cublasZgemm3m_64", + /*extension*/ + "cublasNrm2Ex_64", "cublasDotEx_64", "cublasDotcEx_64", + "cublasScalEx_64", "cublasAxpyEx_64", "cublasRotEx_64", + "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx_64", + "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", + "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", /*cublasLt*/ "cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate", "cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute", From 4746e49d5b82aedff0d04b31daf1cf480d3c9369 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 30 Aug 2024 12:52:09 +0800 Subject: [PATCH 07/21] Support more API migration Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 59 +++++++++++++----------------- clang/lib/DPCT/APINames_cuBLAS.inc | 10 ++--- clang/lib/DPCT/ASTTraversal.cpp | 16 ++++---- clang/lib/DPCT/MapNames.cpp | 5 +++ clang/test/dpct/cublas-usm-11.cu | 11 +++++- clang/test/dpct/cublas_64.cu | 14 ++++++- clang/test/dpct/cublas_64_usm.cu | 18 +++++++-- 7 files changed, 81 insertions(+), 52 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index be89f1eaa64b..fc6b104b7684 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -207,16 +207,6 @@ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10))))) -ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY( - "cublasGemmEx", - 
CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), - BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17))))) - ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY( @@ -727,29 +717,32 @@ WARNING_FACTORY_ENTRY( Diagnostics::TRNA_WARNING_ERROR_HANDLING_API_COMMENTED, ARG("The call was replaced by a placeholder string")) -ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY( - "cublasSgemmEx", - CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), - BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), - ARG(MapNames::getLibraryHelperNamespace() + - "library_data_t::real_float"))))) - -ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY( - "cublasCgemmEx", - CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), - BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), - ARG(MapNames::getLibraryHelperNamespace() + - "library_data_t::complex_float"))))) +#define GEMM_EX(NAME, COMPUTE_TYPE) \ + ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ + HelperFeatureEnum::device_ext, \ + CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), \ + BLAS_ENUM_ARG(1, \ + clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + BLAS_ENUM_ARG(2, \ + 
clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), \ + ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), \ + ARG(COMPUTE_TYPE))))) +GEMM_EX(cublasSgemmEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") +GEMM_EX(cublasCgemmEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasCgemm3mEx, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasGemmEx, 17) +GEMM_EX(cublasSgemmEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") +GEMM_EX(cublasCgemmEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasCgemm3mEx_64, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasGemmEx_64, 17) +#undef GEMM_EX #define SYRK(NAME, TYPE, IS_COMPLEX) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 2873c23b589e..afb8dfbd0b36 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -527,7 +527,7 @@ ENTRY(cublasZtrmm, cublasZtrmm, true, NO_FLAG, P4, "Successful") ENTRY(cublasAsumEx, cublasAsumEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCgemm3mBatched, cublasCgemm3mBatched, false, NO_FLAG, P4, "comment") -ENTRY(cublasCgemm3mEx, cublasCgemm3mEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasCgemm3mEx, cublasCgemm3mEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCopyEx, cublasCopyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotcEx, cublasDotcEx, true, NO_FLAG, P4, "Successful") ENTRY(cublasGetCudartVersion, cublasGetCudartVersion, false, NO_FLAG, P4, "comment") @@ -774,12 +774,12 @@ ENTRY(cublasHSSgemvStridedBatched_64, cublasHSSgemvStridedBatched_64, false, NO_ ENTRY(cublasTSTgemvStridedBatched_64, cublasTSTgemvStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasTSSgemvStridedBatched_64, cublasTSSgemvStridedBatched_64, false, NO_FLAG, P4, 
"comment") ENTRY(cublasCgemm3m_64, cublasCgemm3m_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCgemm3mEx_64, cublasCgemm3mEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasCgemm3mEx_64, cublasCgemm3mEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasZgemm3m_64, cublasZgemm3m_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasHgemm_64, cublasHgemm_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasSgemmEx_64, cublasSgemmEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmEx_64, cublasGemmEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCgemmEx_64, cublasCgemmEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasSgemmEx_64, cublasSgemmEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasGemmEx_64, cublasGemmEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCgemmEx_64, cublasCgemmEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCsyrkEx_64, cublasCsyrkEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCsyrk3mEx_64, cublasCsyrk3mEx_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCherkEx_64, cublasCherkEx_64, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 7d53e6499bd2..019ec40a1b91 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4168,13 +4168,13 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasSgeqrfBatched", "cublasDgeqrfBatched", "cublasCgeqrfBatched", "cublasZgeqrfBatched", "cublasSgelsBatched", "cublasDgelsBatched", "cublasCgelsBatched", "cublasZgelsBatched", "cublasGemmEx", - "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasDotEx", - "cublasDotcEx", "cublasScalEx", "cublasAxpyEx", "cublasRotEx", - "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", "cublasSdgmm", - "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", - "cublasDgeam", "cublasCgeam", "cublasZgeam", "cublasCopyEx", - "cublasSwapEx", "cublasIamaxEx", "cublasIaminEx", "cublasAsumEx", - "cublasRotmEx", + "cublasSgemmEx", "cublasCgemmEx", "cublasCgemm3mEx", 
"cublasNrm2Ex", + "cublasDotEx", "cublasDotcEx", "cublasScalEx", "cublasAxpyEx", + "cublasRotEx", "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", + "cublasSdgmm", "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", + "cublasSgeam", "cublasDgeam", "cublasCgeam", "cublasZgeam", + "cublasCopyEx", "cublasSwapEx", "cublasIamaxEx", "cublasIaminEx", + "cublasAsumEx", "cublasRotmEx", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", @@ -4275,6 +4275,8 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx_64", "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", + "cublasSgemmEx_64", "cublasCgemmEx_64", "cublasCgemm3mEx_64", + "cublasGemmEx_64", /*cublasLt*/ "cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate", "cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute", diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 2ce96af5f730..3d074d7a3245 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2170,6 +2170,7 @@ void MapNames::setExplicitNamespaceMap( {"cublasGemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasSgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCgemm3mEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasGemmBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasGemmBatchedEx_64", getLibraryHelperNamespace() + "blas::gemm_batch"}, @@ -2466,6 +2467,10 @@ void MapNames::setExplicitNamespaceMap( getLibraryHelperNamespace() + "blas::iamaxmin"}, {"cublasAsumEx_64", getLibraryHelperNamespace() + "blas::asum"}, {"cublasRotmEx_64", getLibraryHelperNamespace() + "blas::rotm"}, + {"cublasSgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, 
+ {"cublasCgemm3mEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasGemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, // cublasLt {"cublasLtCreate", "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index e0c4c0fbc068..a5a25f11a195 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -56,7 +56,7 @@ void foo3() { void foo4() { cublasHandle_t handle; - int n; + int m, n, k; void *x, *y; int incx, incy; void *res; @@ -74,4 +74,13 @@ void foo4() { cublasIaminEx(handle, n, x, CUDA_R_32F, incx, idx); cublasAsumEx(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); cublasRotmEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); + + cublasOperation_t transa; + cublasOperation_t transb; + cuComplex *alpha, *beta; + void *A, *B, *C; + int lda, ldb, ldc; + cudaDataType_t a_type, b_type, c_type; + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc); } diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index edd52f0fe4da..383a9eb0f385 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -513,8 +513,18 @@ void foo() { int64_t stride_a; int64_t stride_b; int64_t stride_c; - // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec)); - status = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo); + cublasComputeType_t type_compute; + // 
CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute)); + status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute, algo); + + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); + cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); + cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasCgemm3mEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); } void foo2() { diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index 820e928b6bf6..841defaac05a 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -513,10 +513,20 @@ void foo() { int64_t stride_a; int64_t stride_b; int64_t stride_c; - // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, 
alpha, const_cast(a_array), type_a, lda, const_cast(b_array), type_b, ldb, beta, c_array, type_c, ldc, batch, type_exec)); - // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec)); - status = cublasGemmBatchedEx(handle, transa, transb, m, n, k, alpha, a_array, type_a, lda, b_array, type_b, ldb, beta, c_array, type_c, ldc, batch, type_exec, algo); - status = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_exec, algo); + cublasComputeType_t type_compute; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, const_cast(a_array), type_a, lda, const_cast(b_array), type_b, ldb, beta, c_array, type_c, ldc, batch, type_compute)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute)); + status = cublasGemmBatchedEx_64(handle, transa, transb, m, n, k, alpha, a_array, type_a, lda, b_array, type_b, ldb, beta, c_array, type_c, ldc, batch, type_compute, algo); + status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute, algo); + + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, 
beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); + cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); + cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasCgemm3mEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); } void foo2() { From d298d3259c346df31cf39b93c3ef3a1135559a25 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 2 Sep 2024 16:29:06 +0800 Subject: [PATCH 08/21] Support 4 APIs Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 27 +++++++ clang/lib/DPCT/APINames_cuBLAS.inc | 16 ++--- clang/lib/DPCT/ASTTraversal.cpp | 6 +- clang/lib/DPCT/MapNames.cpp | 10 +++ .../dpct-rt/include/dpct/blas_utils.hpp | 70 +++++++++++++++++++ clang/test/dpct/cublas-usm-11.cu | 12 ++++ clang/test/dpct/cublas_64.cu | 10 +++ clang/test/dpct/cublas_64_usm.cu | 10 +++ 8 files changed, 151 insertions(+), 10 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index fc6b104b7684..974459f1e415 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -744,6 +744,33 @@ GEMM_EX(cublasCgemm3mEx_64, "oneapi::mkl::blas::compute_mode::complex_3m") GEMM_EX(cublasGemmEx_64, 17) #undef GEMM_EX +#define SYHERK(NAME, IS_HERMITIAN, COMPUTE_TYPE) \ + ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ + HelperFeatureEnum::device_ext, \ + CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + \ + "blas::syherk<" + #IS_HERMITIAN + ">", \ + ARG(0), \ + BLAS_ENUM_ARG( \ 
+ 1, clang::dpct::BLASEnumExpr::BLASEnumType::Uplo), \ + BLAS_ENUM_ARG( \ + 2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), \ + ARG(10), ARG(11), ARG(12), ARG(COMPUTE_TYPE))))) +SYHERK(cublasCsyrkEx, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCsyrk3mEx, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherkEx, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherk3mEx, true, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCsyrkEx_64, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCsyrk3mEx_64, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherkEx_64, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherk3mEx_64, true, "oneapi::mkl::blas::compute_mode::complex_3m") +#undef SYHERK + #define SYRK(NAME, TYPE, IS_COMPLEX) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ #NAME, \ diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index afb8dfbd0b36..b5b2ecf07c6b 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -279,10 +279,10 @@ ENTRY(cublasCgemmEx, cublasCgemmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmEx, cublasGemmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmBatchedEx, cublasGemmBatchedEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmStridedBatchedEx, cublasGemmStridedBatchedEx, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCsyrkEx, cublasCsyrkEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCsyrk3mEx, cublasCsyrk3mEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherkEx, cublasCherkEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherk3mEx, cublasCherk3mEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasCsyrkEx, cublasCsyrkEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCsyrk3mEx, 
cublasCsyrk3mEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherkEx, cublasCherkEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherk3mEx, cublasCherk3mEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasNrm2Ex, cublasNrm2Ex, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasAxpyEx, cublasAxpyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotEx, cublasDotEx, true, NO_FLAG, P4, "DPCT1020") @@ -780,10 +780,10 @@ ENTRY(cublasHgemm_64, cublasHgemm_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSgemmEx_64, cublasSgemmEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmEx_64, cublasGemmEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCgemmEx_64, cublasCgemmEx_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCsyrkEx_64, cublasCsyrkEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCsyrk3mEx_64, cublasCsyrk3mEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherkEx_64, cublasCherkEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherk3mEx_64, cublasCherk3mEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasCsyrkEx_64, cublasCsyrkEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCsyrk3mEx_64, cublasCsyrk3mEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherkEx_64, cublasCherkEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherk3mEx_64, cublasCherk3mEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSsyrkx_64, cublasSsyrkx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDsyrkx_64, cublasDsyrkx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCsyrkx_64, cublasCsyrkx_64, true, NO_FLAG, P4, "DPCT1020") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 019ec40a1b91..409ce70b305a 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4174,7 +4174,8 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasSdgmm", "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", "cublasDgeam", "cublasCgeam", "cublasZgeam", "cublasCopyEx", "cublasSwapEx", "cublasIamaxEx", "cublasIaminEx", - "cublasAsumEx", 
"cublasRotmEx", + "cublasAsumEx", "cublasRotmEx", "cublasCsyrkEx", "cublasCsyrk3mEx", + "cublasCherkEx", "cublasCherk3mEx", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", @@ -4276,7 +4277,8 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", "cublasSgemmEx_64", "cublasCgemmEx_64", "cublasCgemm3mEx_64", - "cublasGemmEx_64", + "cublasGemmEx_64", "cublasCsyrkEx_64", "cublasCsyrk3mEx_64", + "cublasCherkEx_64", "cublasCherk3mEx_64", /*cublasLt*/ "cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate", "cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute", diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 3d074d7a3245..20ba30377e26 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2471,6 +2471,16 @@ void MapNames::setExplicitNamespaceMap( {"cublasCgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemm3mEx_64", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasGemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCsyrkEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrk3mEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherkEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherk3mEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrkEx_64", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrk3mEx_64", + getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherkEx_64", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherk3mEx_64", + getLibraryHelperNamespace() + "blas::syherk"}, // cublasLt {"cublasLtCreate", "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 
48da9c3ad96a..47282ecf9486 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -725,6 +725,26 @@ inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, batch_size DPCT_COMPUTE_MODE_ARG); } +template +inline void syherk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, int n, int k, + const void *alpha, const void *a, int lda, + const void *beta, void *c, + int ldc DPCT_COMPUTE_MODE_PARAM) { + T alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + T beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_c = get_memory(c); + if constexpr (is_hermitian) + oneapi::mkl::blas::column_major::herk(q, uplo, trans, n, k, alpha_value, + data_a, lda, beta_value, data_c, + ldc DPCT_COMPUTE_MODE_ARG); + else + oneapi::mkl::blas::column_major::syrk(q, uplo, trans, n, k, alpha_value, + data_a, lda, beta_value, data_c, + ldc DPCT_COMPUTE_MODE_ARG); +} + template inline void rk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, oneapi::mkl::transpose trans, int n, int k, const T *alpha, @@ -1956,6 +1976,56 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, } } +template +inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + const void *alpha, const void *a, library_data_t a_type, + std::int64_t lda, const void *beta, void *c, + library_data_t c_type, std::int64_t ldc, + std::variant ct) { + sycl::queue q = desc_ptr->get_queue(); +#ifdef __INTEL_MKL__ + oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; + if (auto ct_p = std::get_if(&ct)) { + cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), + a_type == library_data_t::complex_float || + a_type == library_data_t::complex_double); + } else { + cm = deduce_compute_mode(std::nullopt, desc_ptr->get_math_mode(), + a_type == 
library_data_t::complex_float || + a_type == library_data_t::complex_double); + } +#endif + std::uint64_t key = dpct::detail::get_type_combination_id(a_type, c_type); + if (!is_hermitian && + dpct::detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float) == key) { + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, + a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else if (!is_hermitian && dpct::detail::get_type_combination_id( + library_data_t::real_double, + library_data_t::real_double) == key) { + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, + a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else if (dpct::detail::get_type_combination_id( + library_data_t::complex_float, + library_data_t::complex_float) == key) { + dpct::detail::syherk_impl>( + q, uplo, trans, n, k, alpha, a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else if (dpct::detail::get_type_combination_id( + library_data_t::complex_double, + library_data_t::complex_double) == key) { + dpct::detail::syherk_impl>( + q, uplo, trans, n, k, alpha, a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else { + throw std::runtime_error("the combination of data type is unsupported"); + } +} + /// This routines perform a special rank-k update of a symmetric matrix C by /// general matrices A and B. /// \param [in] desc_ptr Descriptor. 
diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index a5a25f11a195..72d91153a513 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -83,4 +83,16 @@ void foo4() { cudaDataType_t a_type, b_type, c_type; // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc); + + cublasFillMode_t uplo; + cublasOperation_t trans; + float *alpha_s, *beta_s; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); + cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); + cublasCherkEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); + cublasCherk3mEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); } diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index 383a9eb0f385..e1498f4f2151 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -525,6 +525,16 @@ void foo() { cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); cublasCgemm3mEx_64(handle, transa, transb, m, n, k, 
alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); + + cublasOperation_t trans; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); + cublasCherk3mEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); } void foo2() { diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index 841defaac05a..cbde8ee81446 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -527,6 +527,16 @@ void foo() { cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); cublasCgemm3mEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); + + cublasOperation_t trans; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, 
alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); + cublasCherk3mEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); } void foo2() { From 6cbf703651dd3042d01b5aa4600936d02812f760 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Thu, 12 Sep 2024 15:29:07 +0800 Subject: [PATCH 09/21] Update Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 75 ++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 47282ecf9486..cb587844c290 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1976,6 +1976,23 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, } } +/// Performs a symmetric/hermitian rank-k update. +/// \tparam is_hermitian True means current matrix is hermitian. +/// \param [in] desc_ptr Descriptor. +/// \param [in] uplo Specifies whether matrix c is upper or lower triangular. +/// \param [in] trans Specifies op(a), the transposition operation applied to +/// matrix a. 
+/// \param [in] n Number of rows and columns of matrix c. +/// \param [in] k Number of columns of matrix op(a). +/// \param [in] alpha Scaling factor for the rank-k update. +/// \param [in] a Input matrix a. +/// \param [in] a_type Data type of the matrix a. +/// \param [in] lda Leading dimension of the matrix a. +/// \param [in] beta Scaling factor for the rank-k update. +/// \param [in, out] c Input/Output matrix c. +/// \param [in] c_type Data type of the matrix c. +/// \param [in] ldc Leading dimension of the matrix c. +/// \param [in] ct Compute type. template inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, @@ -2455,6 +2472,17 @@ inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, } } +/// Performs modified Givens rotation of points in the plane. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [in] param Array of 5 parameters. +/// \param [in] param_type Data type of \p param. inline void rotm(descriptor_ptr desc_ptr, std::int64_t n, void *x, library_data_t x_type, int64_t incx, void *y, library_data_t y_type, int64_t incy, const void *param, @@ -2479,6 +2507,15 @@ inline void rotm(descriptor_ptr desc_ptr, std::int64_t n, void *x, } } +/// Copies a vector to another vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] y Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. 
inline void copy(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, void *y, library_data_t y_type, std::int64_t incy) { @@ -2512,6 +2549,15 @@ inline void copy(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } +/// Swaps a vector with another vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. inline void swap(descriptor_ptr desc_ptr, std::int64_t n, void *x, library_data_t x_type, std::int64_t incx, void *y, library_data_t y_type, std::int64_t incy) { @@ -2545,6 +2591,14 @@ inline void swap(descriptor_ptr desc_ptr, std::int64_t n, void *x, } } +/// Computes the sum of magnitudes of the vector elements. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The scalar result. +/// \param [in] result_type Data type of \p result. inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, void *result, library_data_t result_type) { @@ -2578,9 +2632,19 @@ inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } +/// Finds the index of the element with the largest/smallest absolute value in a +/// vector. +/// \tparam is_max True means finding the largest absolute value index. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x.
+/// \param [out] result The index of the maximum/minimum element. template inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, - library_data_t x_type, std::int64_t incx, std::int64_t *result) { + library_data_t x_type, std::int64_t incx, + std::int64_t *result) { sycl::queue q = desc_ptr->get_queue(); std::uint64_t key = detail::get_type_combination_id(x_type); switch (key) { @@ -2605,6 +2669,15 @@ inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } +/// Finds the index of the element with the largest/smallest absolute value in a +/// vector. +/// \tparam is_max True means finding the largest absolute value index. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The index of the maximum/minimum element. template inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, int *result) { From 1eee1951b0e356a88b69795010963bc8217e3ea0 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 13 Sep 2024 08:29:56 +0800 Subject: [PATCH 10/21] Fix helper Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 331 +++++++++--------- 1 file changed, 167 insertions(+), 164 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index cb587844c290..7a74c76f2191 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -602,19 +602,6 @@ inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, #endif } -template -inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, - std::int64_t incx, void *res) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) 
Interfaces " - "Project does not support this API."); -#else - auto data_x = get_memory(x); - auto data_res = get_memory(res); - oneapi::mkl::blas::column_major::asum(q, n, data_x, incx, data_res); -#endif -} - template inline void iamaxmin_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { @@ -622,7 +609,7 @@ inline void iamaxmin_impl(sycl::queue &q, std::int64_t n, const void *x, throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); #else - auto data_x = get_memory(x); + auto data_x = get_memory(x); auto data_res = get_memory(res); if constexpr (is_max) oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, @@ -2232,31 +2219,34 @@ inline void nrm2(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, void *result, library_data_t result_type) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, result_type); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, result_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::nrm2_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::nrm2_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::nrm2_impl, float>(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + 
library_data_t::real_float): { + ::dpct::detail::nrm2_impl, float>(q, n, x, incx, + result); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::nrm2_impl, double>(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::nrm2_impl, double>(q, n, x, incx, + result); break; } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::nrm2_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); break; } default: @@ -2280,8 +2270,8 @@ inline void dot(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t y_type, std::int64_t incy, void *result, library_data_t result_type) { sycl::queue q = desc_ptr->get_queue(); - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); + ::dpct::detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); } /// Computes the dot product of two vectors, conjugating the first vector. @@ -2300,8 +2290,8 @@ inline void dotc(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t y_type, std::int64_t incy, void *result, library_data_t result_type) { sycl::queue q = desc_ptr->get_queue(); - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); + ::dpct::detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); } /// Computes the product of a vector by a scalar. 
@@ -2316,31 +2306,33 @@ inline void scal(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, library_data_t alpha_type, void *x, library_data_t x_type, std::int64_t incx) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float): { - detail::scal_impl(q, n, alpha, x, incx); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float): { + ::dpct::detail::scal_impl(q, n, alpha, x, incx); break; } - case detail::get_type_combination_id(library_data_t::real_double): { - detail::scal_impl(q, n, alpha, x, incx); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double): { + ::dpct::detail::scal_impl(q, n, alpha, x, incx); break; } - case detail::get_type_combination_id(library_data_t::complex_float): { - detail::scal_impl, std::complex>(q, n, alpha, x, - incx); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): { + ::dpct::detail::scal_impl, std::complex>( + q, n, alpha, x, incx); break; } - case detail::get_type_combination_id(library_data_t::complex_double): { - detail::scal_impl, std::complex>(q, n, alpha, - x, incx); + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double): { + ::dpct::detail::scal_impl, std::complex>( + q, n, alpha, x, incx); break; } - case detail::get_type_combination_id(library_data_t::real_half): { + case ::dpct::detail::get_type_combination_id(library_data_t::real_half): { float alpha_value = dpct::get_value(reinterpret_cast(alpha), q); sycl::half alaph_half(alpha_value); - detail::scal_impl(q, n, &alaph_half, x, incx); + ::dpct::detail::scal_impl(q, n, &alaph_half, x, + incx); break; } default: @@ -2364,37 +2356,38 @@ inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, library_data_t x_type, std::int64_t incx, void *y, 
library_data_t y_type, std::int64_t incy) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, alpha_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::axpy_impl(q, n, alpha, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::axpy_impl(q, n, alpha, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::axpy_impl, std::complex>(q, n, alpha, x, - incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::axpy_impl, std::complex>( + q, n, alpha, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::axpy_impl, std::complex>( + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::axpy_impl, std::complex>( q, n, alpha, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_float): { + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_float): { float alpha_value = dpct::get_value(reinterpret_cast(alpha), q); sycl::half alaph_half(alpha_value); - detail::axpy_impl(q, n, &alaph_half, x, incx, y, 
- incy); + ::dpct::detail::axpy_impl(q, n, &alaph_half, x, + incx, y, incy); break; } default: @@ -2419,52 +2412,55 @@ inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, library_data_t y_type, std::int64_t incy, const void *c, const void *s, library_data_t cs_type) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, cs_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, s); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, + s); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::rot_impl, float, float>(q, n, x, incx, y, incy, - c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + ::dpct::detail::rot_impl, float, float>(q, n, x, incx, + y, incy, c, s); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::rot_impl, double, double>(q, n, x, incx, y, - incy, c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::rot_impl, double, double>( + q, n, x, incx, y, incy, c, s); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - 
detail::rot_impl, float, std::complex>( + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::rot_impl, float, std::complex>( q, n, x, incx, y, incy, c, s); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::rot_impl, double, std::complex>( - q, n, x, incx, y, incy, c, s); + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::rot_impl, double, + std::complex>(q, n, x, incx, y, incy, c, + s); break; } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::rot_impl(q, n, x, incx, y, incy, - c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + ::dpct::detail::rot_impl(q, n, x, incx, + y, incy, c, s); break; } - case detail::get_type_combination_id(library_data_t::real_bfloat16, - library_data_t::real_bfloat16): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); + case ::dpct::detail::get_type_combination_id(library_data_t::real_bfloat16, + library_data_t::real_bfloat16): { + ::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, + s); break; } default: @@ -2488,18 +2484,21 @@ inline void rotm(descriptor_ptr desc_ptr, std::int64_t n, void *x, library_data_t y_type, int64_t incy, const void *param, library_data_t param_type) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, param_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float, - library_data_t::real_float): { - detail::rotm_impl(q, n, x, incx, y, incy, param); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + 
library_data_t::real_float): { + ::dpct::detail::rotm_impl(q, n, x, incx, y, incy, + param); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double, - library_data_t::real_double): { - detail::rotm_impl(q, n, x, incx, y, incy, param); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::rotm_impl(q, n, x, incx, y, incy, + param); break; } default: @@ -2520,28 +2519,28 @@ inline void copy(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, void *y, library_data_t y_type, std::int64_t incy) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, y_type); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, y_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::copy_impl(q, n, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::copy_impl(q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::copy_impl(q, n, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::copy_impl(q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::copy_impl, std::complex>(q, n, x, incx, - y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::copy_impl, std::complex>( + q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - 
library_data_t::complex_double): { - detail::copy_impl, std::complex>(q, n, x, incx, - y, incy); + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::copy_impl, std::complex>( + q, n, x, incx, y, incy); break; } default: @@ -2562,28 +2561,28 @@ inline void swap(descriptor_ptr desc_ptr, std::int64_t n, void *x, library_data_t x_type, std::int64_t incx, void *y, library_data_t y_type, std::int64_t incy) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, y_type); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, y_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::swap_impl(q, n, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::swap_impl(q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::swap_impl(q, n, x, incx, y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::swap_impl(q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::swap_impl, std::complex>(q, n, x, incx, - y, incy); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::swap_impl, std::complex>( + q, n, x, incx, y, incy); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::swap_impl, std::complex>(q, n, x, incx, - y, incy); + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + 
::dpct::detail::swap_impl, std::complex>( + q, n, x, incx, y, incy); break; } default: @@ -2603,28 +2602,29 @@ inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, void *result, library_data_t result_type) { sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type, result_type); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, result_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::asum_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::asum_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::asum_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::asum_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::asum_impl, std::complex>(q, n, x, incx, - result); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + ::dpct::detail::asum_impl, float>(q, n, x, incx, + result); break; } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::asum_impl, std::complex>(q, n, x, incx, - result); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::asum_impl, double>(q, n, x, incx, + result); break; } default: @@ -2646,22 +2646,25 @@ inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, std::int64_t *result) { 
sycl::queue q = desc_ptr->get_queue(); - std::uint64_t key = detail::get_type_combination_id(x_type); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float): { - detail::iamaxmin_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_float): { + ::dpct::detail::iamaxmin_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::real_double): { - detail::iamaxmin_impl(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::real_double): { + ::dpct::detail::iamaxmin_impl(q, n, x, incx, result); break; } - case detail::get_type_combination_id(library_data_t::complex_float): { - detail::iamaxmin_impl>(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): { + ::dpct::detail::iamaxmin_impl, is_max>(q, n, x, incx, + result); break; } - case detail::get_type_combination_id(library_data_t::complex_double): { - detail::iamaxmin_impl>(q, n, x, incx, result); + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double): { + ::dpct::detail::iamaxmin_impl, is_max>(q, n, x, incx, + result); break; } default: @@ -2682,7 +2685,7 @@ template inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, library_data_t x_type, std::int64_t incx, int *result) { dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result); - iamaxmin(desc_ptr, n, x, x_type, incx, wrapper.get()); + iamaxmin(desc_ptr, n, x, x_type, incx, wrapper.get_ptr()); } /// Finds the least squares solutions for a batch of overdetermined linear @@ -3019,7 +3022,7 @@ dot(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx, library_data_t result_type) { blas::descriptor desc; desc.set_queue(&q); - blas::dot(q, n, x, x_type, incx, y, y_type, incy, result, result_type); + blas::dot(&desc, n, x, x_type, 
incx, y, y_type, incy, result, result_type); } /// Computes the dot product of two vectors, conjugating the first vector. @@ -3039,7 +3042,7 @@ dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx, library_data_t result_type) { blas::descriptor desc; desc.set_queue(&q); - blas::dotc(q, n, x, x_type, incx, y, y_type, incy, result, result_type); + blas::dotc(&desc, n, x, x_type, incx, y, y_type, incy, result, result_type); } /// Computes the product of a vector by a scalar. @@ -3055,7 +3058,7 @@ scal(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type, void *x, library_data_t x_type, int incx) { blas::descriptor desc; desc.set_queue(&q); - blas::scal(q, n, alpha, alpha_type, x, x_type, incx); + blas::scal(&desc, n, alpha, alpha_type, x, x_type, incx); } /// Computes a vector-scalar product and adds the result to a vector. @@ -3075,7 +3078,7 @@ axpy(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type, library_data_t y_type, int incy) { blas::descriptor desc; desc.set_queue(&q); - blas::axpy(q, n, alpha, alpha_type, x, x_type, incx, y, y_type, incy); + blas::axpy(&desc, n, alpha, alpha_type, x, x_type, incx, y, y_type, incy); } /// Performs rotation of points in the plane. 
@@ -3096,7 +3099,7 @@ rot(sycl::queue &q, int n, void *x, library_data_t x_type, int incx, void *y, library_data_t cs_type) { blas::descriptor desc; desc.set_queue(&q); - blas::rot(q, n, x, x_type, incx, y, y_type, incy, c, s, cs_type); + blas::rot(&desc, n, x, x_type, incx, y, y_type, incy, c, s, cs_type); } } // namespace dpct #undef DPCT_COMPUTE_MODE_ARG From 5c8ff1de64f9856e7a8ca86df3e079ba5caf9758 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Thu, 19 Sep 2024 10:55:26 +0800 Subject: [PATCH 11/21] Fix helper Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 7a74c76f2191..4203d4ddc8c7 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -718,18 +718,23 @@ inline void syherk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, const void *alpha, const void *a, int lda, const void *beta, void *c, int ldc DPCT_COMPUTE_MODE_PARAM) { - T alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - T beta_value = dpct::get_value(reinterpret_cast(beta), q); auto data_a = get_memory(a); auto data_c = get_memory(c); - if constexpr (is_hermitian) + if constexpr (is_hermitian) { + auto alpha_value = dpct::get_value( + reinterpret_cast(alpha), q); + auto beta_value = dpct::get_value( + reinterpret_cast(beta), q); oneapi::mkl::blas::column_major::herk(q, uplo, trans, n, k, alpha_value, data_a, lda, beta_value, data_c, ldc DPCT_COMPUTE_MODE_ARG); - else + } else { + T alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + T beta_value = dpct::get_value(reinterpret_cast(beta), q); oneapi::mkl::blas::column_major::syrk(q, uplo, trans, n, k, alpha_value, data_a, lda, beta_value, data_c, ldc DPCT_COMPUTE_MODE_ARG); + } } template @@ -1472,13 +1477,14 @@ deduce_compute_mode(std::optional ct, 
math_mode mm, /// \param [in] c_type Data type of the matrix C. /// \param [in] ldc Leading dimension of C. /// \param [in] ct Compute type. -inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, - std::int64_t k, const void *alpha, const void *a, - library_data_t a_type, std::int64_t lda, const void *b, - library_data_t b_type, std::int64_t ldb, const void *beta, - void *c, library_data_t c_type, std::int64_t ldc, - std::variant ct) { +inline void +gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, + std::int64_t k, const void *alpha, const void *a, library_data_t a_type, + std::int64_t lda, const void *b, library_data_t b_type, std::int64_t ldb, + const void *beta, void *c, library_data_t c_type, std::int64_t ldc, + std::variant + ct) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -1486,7 +1492,9 @@ inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, sycl::queue q = desc_ptr->get_queue(); oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; library_data_t scaling_type; - if (auto ct_p = std::get_if(&ct)) { + if (auto ct_p = std::get_if(&ct)) { + cm = *ct_p; + } else if (auto ct_p = std::get_if(&ct)) { cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), a_type == library_data_t::complex_float || a_type == library_data_t::complex_double); @@ -1981,16 +1989,19 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, /// \param [in] ldc Leading dimension of the matrix c. /// \param [in] ct Compute type. 
template -inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, - const void *alpha, const void *a, library_data_t a_type, - std::int64_t lda, const void *beta, void *c, - library_data_t c_type, std::int64_t ldc, - std::variant ct) { +inline void syherk( + descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + const void *alpha, const void *a, library_data_t a_type, std::int64_t lda, + const void *beta, void *c, library_data_t c_type, std::int64_t ldc, + std::variant + ct) { sycl::queue q = desc_ptr->get_queue(); #ifdef __INTEL_MKL__ oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; - if (auto ct_p = std::get_if(&ct)) { + if (auto ct_p = std::get_if(&ct)) { + cm = *ct_p; + } else if (auto ct_p = std::get_if(&ct)) { cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), a_type == library_data_t::complex_float || a_type == library_data_t::complex_double); @@ -2004,15 +2015,14 @@ inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, if (!is_hermitian && dpct::detail::get_type_combination_id( library_data_t::real_float, library_data_t::real_float) == key) { - dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, - a, lda, beta, c, - ldc DPCT_COMPUTE_MODE_ARG); + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, a, lda, + beta, c, ldc DPCT_COMPUTE_MODE_ARG); } else if (!is_hermitian && dpct::detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double) == key) { - dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, - a, lda, beta, c, - ldc DPCT_COMPUTE_MODE_ARG); + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, a, + lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); } else if (dpct::detail::get_type_combination_id( library_data_t::complex_float, library_data_t::complex_float) == key) { From f441cd8a2f4fa15d3abff95143dd744182a29f2b Mon Sep 17 00:00:00 
2001 From: "Jiang, Zhiwei" Date: Fri, 20 Sep 2024 10:09:58 +0800 Subject: [PATCH 12/21] Update1 Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 18 ++++++++---- .../dpct-rt/include/dpct/blas_utils.hpp | 28 +++++++++---------- clang/test/dpct/cublas-usm-11.cu | 6 ++-- clang/test/dpct/cublas_64.cu | 6 ++-- clang/test/dpct/cublas_64_usm.cu | 6 ++-- 5 files changed, 34 insertions(+), 30 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index 974459f1e415..a1631e1af2f4 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -734,13 +734,15 @@ GEMM_EX(cublasSgemmEx, MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") GEMM_EX(cublasCgemmEx, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -GEMM_EX(cublasCgemm3mEx, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasCgemm3mEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") GEMM_EX(cublasGemmEx, 17) GEMM_EX(cublasSgemmEx_64, MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") GEMM_EX(cublasCgemmEx_64, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -GEMM_EX(cublasCgemm3mEx_64, "oneapi::mkl::blas::compute_mode::complex_3m") +GEMM_EX(cublasCgemm3mEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") GEMM_EX(cublasGemmEx_64, 17) #undef GEMM_EX @@ -759,16 +761,20 @@ GEMM_EX(cublasGemmEx_64, 17) ARG(10), ARG(11), ARG(12), ARG(COMPUTE_TYPE))))) SYHERK(cublasCsyrkEx, false, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCsyrk3mEx, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCsyrk3mEx, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCherkEx, true, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCherk3mEx, true, 
"oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherk3mEx, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCsyrkEx_64, false, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCsyrk3mEx_64, false, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCsyrk3mEx_64, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") SYHERK(cublasCherkEx_64, true, MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") -SYHERK(cublasCherk3mEx_64, true, "oneapi::mkl::blas::compute_mode::complex_3m") +SYHERK(cublasCherk3mEx_64, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") #undef SYHERK #define SYRK(NAME, TYPE, IS_COMPLEX) \ diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 4203d4ddc8c7..c5a6813d196e 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1477,14 +1477,13 @@ deduce_compute_mode(std::optional ct, math_mode mm, /// \param [in] c_type Data type of the matrix C. /// \param [in] ldc Leading dimension of C. /// \param [in] ct Compute type. 
-inline void -gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, - std::int64_t k, const void *alpha, const void *a, library_data_t a_type, - std::int64_t lda, const void *b, library_data_t b_type, std::int64_t ldb, - const void *beta, void *c, library_data_t c_type, std::int64_t ldc, - std::variant - ct) { +inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n, + std::int64_t k, const void *alpha, const void *a, + library_data_t a_type, std::int64_t lda, const void *b, + library_data_t b_type, std::int64_t ldb, const void *beta, + void *c, library_data_t c_type, std::int64_t ldc, + std::variant ct) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); @@ -1989,13 +1988,12 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, /// \param [in] ldc Leading dimension of the matrix c. /// \param [in] ct Compute type. 
template -inline void syherk( - descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, - const void *alpha, const void *a, library_data_t a_type, std::int64_t lda, - const void *beta, void *c, library_data_t c_type, std::int64_t ldc, - std::variant - ct) { +inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + const void *alpha, const void *a, library_data_t a_type, + std::int64_t lda, const void *beta, void *c, + library_data_t c_type, std::int64_t ldc, + std::variant ct) { sycl::queue q = desc_ptr->get_queue(); #ifdef __INTEL_MKL__ oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index 72d91153a513..840292ad2744 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -81,16 +81,16 @@ void foo4() { void *A, *B, *C; int lda, ldb, ldc; cudaDataType_t a_type, b_type, c_type; - // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, dpct::library_data_t::complex_float); cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc); cublasFillMode_t uplo; cublasOperation_t trans; float *alpha_s, *beta_s; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, 
dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); cublasCherkEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index e1498f4f2151..0be9808ce43f 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -519,7 +519,7 @@ void foo() { // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); cublasCgemmEx_64(handle, transa, transb, m, n, k, 
alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); @@ -528,9 +528,9 @@ void foo() { cublasOperation_t trans; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index cbde8ee81446..49372405d899 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -521,7 +521,7 @@ void foo() { // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, 
alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); @@ -530,9 +530,9 @@ void foo() { cublasOperation_t trans; // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); - // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m); cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); From 
a93d4f4d12426acc083b872f143ea44a1c38ca43 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 23 Sep 2024 09:38:45 +0800 Subject: [PATCH 13/21] Fix Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index c5a6813d196e..a9171b50a45e 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -1491,9 +1491,7 @@ inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, sycl::queue q = desc_ptr->get_queue(); oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; library_data_t scaling_type; - if (auto ct_p = std::get_if(&ct)) { - cm = *ct_p; - } else if (auto ct_p = std::get_if(&ct)) { + if (auto ct_p = std::get_if(&ct)) { cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), a_type == library_data_t::complex_float || a_type == library_data_t::complex_double); @@ -1997,9 +1995,7 @@ inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, sycl::queue q = desc_ptr->get_queue(); #ifdef __INTEL_MKL__ oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; - if (auto ct_p = std::get_if(&ct)) { - cm = *ct_p; - } else if (auto ct_p = std::get_if(&ct)) { + if (auto ct_p = std::get_if(&ct)) { cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), a_type == library_data_t::complex_float || a_type == library_data_t::complex_double); From 2d39d506607efccb87bcc836b1e603f0f578d0d7 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 11 Oct 2024 09:18:57 +0800 Subject: [PATCH 14/21] Refine Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/APINamesCUBLAS.inc | 20 ++-- clang/lib/DPCT/MapNames.cpp | 10 +- .../dpct-rt/include/dpct/blas_utils.hpp | 110 +++++++++++++----- clang/test/dpct/cublas-usm-11.cu | 4 +- clang/test/dpct/cublas_64.cu 
| 4 +- clang/test/dpct/cublas_64_usm.cu | 4 +- 6 files changed, 101 insertions(+), 51 deletions(-) diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index 5405dcc52966..13b0ae743522 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -1735,16 +1735,20 @@ ROTM_EX(cublasRotmEx) ROTM_EX(cublasRotmEx_64) #undef ROTM_EX -#define IAMAXMIN_EX(NAME, IS_MAX) \ +#define IAMAX_EX(NAME) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ - #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamaxmin<" + \ - #IS_MAX + ">", \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamax", \ ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)))) -IAMAXMIN_EX(cublasIamaxEx, true) -IAMAXMIN_EX(cublasIamaxEx_64, true) -IAMAXMIN_EX(cublasIaminEx, false) -IAMAXMIN_EX(cublasIaminEx_64, false) -#undef IAMAXMIN_EX +IAMAX_EX(cublasIamaxEx) +IAMAX_EX(cublasIamaxEx_64) + +#define IAMIN_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamin", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)))) +IAMIN_EX(cublasIaminEx) +IAMIN_EX(cublasIaminEx_64) +#undef IAMIN_EX #define TBSV(NAME, TYPE) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index 42d76a0b1101..c4f89eba62a0 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2542,16 +2542,14 @@ void MapNames::setExplicitNamespaceMap( {"cublasZtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasCopyEx", getLibraryHelperNamespace() + "blas::copy"}, {"cublasSwapEx", getLibraryHelperNamespace() + "blas::swap"}, - {"cublasIamaxEx", getLibraryHelperNamespace() + "blas::iamaxmin"}, - {"cublasIaminEx", getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasIamaxEx", getLibraryHelperNamespace() + "blas::iamax"}, + {"cublasIaminEx", getLibraryHelperNamespace() + "blas::iamin"}, {"cublasAsumEx", 
getLibraryHelperNamespace() + "blas::asum"}, {"cublasRotmEx", getLibraryHelperNamespace() + "blas::rotm"}, {"cublasCopyEx_64", getLibraryHelperNamespace() + "blas::copy"}, {"cublasSwapEx_64", getLibraryHelperNamespace() + "blas::swap"}, - {"cublasIamaxEx_64", - getLibraryHelperNamespace() + "blas::iamaxmin"}, - {"cublasIaminEx_64", - getLibraryHelperNamespace() + "blas::iamaxmin"}, + {"cublasIamaxEx_64", getLibraryHelperNamespace() + "blas::iamax"}, + {"cublasIaminEx_64", getLibraryHelperNamespace() + "blas::iamin"}, {"cublasAsumEx_64", getLibraryHelperNamespace() + "blas::asum"}, {"cublasRotmEx_64", getLibraryHelperNamespace() + "blas::rotm"}, {"cublasSgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index a9171b50a45e..d66077d2dc96 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -602,21 +602,29 @@ inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, #endif } -template -inline void iamaxmin_impl(sycl::queue &q, std::int64_t n, const void *x, - std::int64_t incx, std::int64_t *res) { +inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, std::int64_t *res) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); #else auto data_x = get_memory(x); auto data_res = get_memory(res); - if constexpr (is_max) - oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, - oneapi::mkl::index_base::one); - else - oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, - oneapi::mkl::index_base::one); + oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); +#endif +} + +inline void iamin_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, std::int64_t *res) 
{ +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + auto data_x = get_memory(x); + auto data_res = get_memory(res); + oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); #endif } @@ -2636,39 +2644,69 @@ inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } -/// Finds the index of the element with the largest/smallest absolute value in a -/// vector. -/// \tparam is_max True means finding the the largest absolute value index. +/// Finds the index of the element with the largest absolute value in a vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The index of the maximal element. +inline void iamax(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, + std::int64_t *result) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float): { + ::dpct::detail::iamax_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double): { + ::dpct::detail::iamax_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): { + ::dpct::detail::iamax_impl>(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double): { + ::dpct::detail::iamax_impl>(q, n, x, incx, result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Finds the index of the element with the smallest absolute value in a 
vector. /// \param [in] desc_ptr Descriptor. /// \param [in] n Number of elements in vector x. /// \param [in] x Input vector x. /// \param [in] x_type Data type of the vector x. /// \param [in] incx Stride of vector x. -/// \param [out] result The index of the maximal/minimum element. -template -inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, - library_data_t x_type, std::int64_t incx, - std::int64_t *result) { +/// \param [out] result The index of the minimum element. +inline void iamin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, + std::int64_t *result) { sycl::queue q = desc_ptr->get_queue(); std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type); switch (key) { case ::dpct::detail::get_type_combination_id(library_data_t::real_float): { - ::dpct::detail::iamaxmin_impl(q, n, x, incx, result); + ::dpct::detail::iamin_impl(q, n, x, incx, result); break; } case ::dpct::detail::get_type_combination_id(library_data_t::real_double): { - ::dpct::detail::iamaxmin_impl(q, n, x, incx, result); + ::dpct::detail::iamin_impl(q, n, x, incx, result); break; } case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): { - ::dpct::detail::iamaxmin_impl, is_max>(q, n, x, incx, - result); + ::dpct::detail::iamin_impl>(q, n, x, incx, result); break; } case ::dpct::detail::get_type_combination_id( library_data_t::complex_double): { - ::dpct::detail::iamaxmin_impl, is_max>(q, n, x, incx, - result); + ::dpct::detail::iamin_impl>(q, n, x, incx, result); break; } default: @@ -2676,20 +2714,30 @@ inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, } } -/// Finds the index of the element with the largest/smallest absolute value in a -/// vector. -/// \tparam is_max True means finding the the largest absolute value index. +/// Finds the index of the element with the largest absolute value in a vector. +/// \param [in] desc_ptr Descriptor. 
+/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The index of the maximal element. +inline void iamax(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, int *result) { + dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result); + iamax(desc_ptr, n, x, x_type, incx, wrapper.get_ptr()); +} + +/// Finds the index of the element with the smallest absolute value in a vector. /// \param [in] desc_ptr Descriptor. /// \param [in] n Number of elements in vector x. /// \param [in] x Input vector x. /// \param [in] x_type Data type of the vector x. /// \param [in] incx Stride of vector x. -/// \param [out] result The index of the maximal/minimum element. -template -inline void iamaxmin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, - library_data_t x_type, std::int64_t incx, int *result) { +/// \param [out] result The index of the minimum element. 
+inline void iamin(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, int *result) { dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result); - iamaxmin(desc_ptr, n, x, x_type, incx, wrapper.get_ptr()); + iamin(desc_ptr, n, x, x_type, incx, wrapper.get_ptr()); } /// Finds the least squares solutions for a batch of overdetermined linear diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu index 840292ad2744..c0273881e833 100644 --- a/clang/test/dpct/cublas-usm-11.cu +++ b/clang/test/dpct/cublas-usm-11.cu @@ -64,8 +64,8 @@ void foo4() { void *param; // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); cublasCopyEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index 9e919b841195..0ed93e69a664 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -651,8 +651,8 @@ void foo2() { void *param; // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, 
dpct::library_data_t::real_float, incy); // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index 3693468fdeb7..e90621a3847f 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -653,8 +653,8 @@ void foo2() { void *param; // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); - // CHECK-NEXT: dpct::blas::iamaxmin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); // CHECK-NEXT: dpct::blas::rotm(handle, n, x, 
dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); From e46a78479dab26bdfad245de3285a63fe2f5dca6 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 11 Oct 2024 12:47:08 +0800 Subject: [PATCH 15/21] Fix helper Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index d66077d2dc96..b2f5eaca1b7d 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -602,6 +602,7 @@ inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, #endif } +template inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { #ifndef __INTEL_MKL__ @@ -615,6 +616,7 @@ inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, #endif } +template inline void iamin_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { #ifndef __INTEL_MKL__ From c95361822553e33789619a197f2fe2d770093b67 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 11 Oct 2024 15:58:48 +0800 Subject: [PATCH 16/21] Remove macro guard Signed-off-by: Jiang, Zhiwei --- .../dpct-rt/include/dpct/blas_utils.hpp | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index b2f5eaca1b7d..7a2ba0fd221c 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -386,10 +386,6 @@ class working_memory { template inline void nrm2_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, void *result) { -#ifndef __INTEL_MKL__ - throw 
std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else #ifdef DPCT_USM_LEVEL_NONE auto x_buffer = dpct::get_buffer(x); auto r_buffer = @@ -402,17 +398,12 @@ inline void nrm2_impl(sycl::queue &q, std::int64_t n, const void *x, oneapi::mkl::blas::column_major::nrm2(q, n, reinterpret_cast(x), incx, res_mem.get_ptr()); #endif -#endif } template inline void dotuc_impl(sycl::queue &q, std::int64_t n, const Txy *x, std::int64_t incx, const Txy *y, std::int64_t incy, Tr *result) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else #ifdef DPCT_USM_LEVEL_NONE auto x_buffer = dpct::get_buffer(x); auto y_buffer = dpct::get_buffer(y); @@ -444,7 +435,6 @@ inline void dotuc_impl(sycl::queue &q, std::int64_t n, const Txy *x, oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr()); #endif -#endif } template @@ -506,128 +496,83 @@ inline void dotuc(sycl::queue &q, std::int64_t n, const void *x, template inline void scal_impl(sycl::queue &q, std::int64_t n, const void *alpha, void *x, std::int64_t incx) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); auto data_x = get_memory(x); oneapi::mkl::blas::column_major::scal(q, n, alpha_val, data_x, incx); -#endif } template inline void axpy_impl(sycl::queue &q, std::int64_t n, const void *alpha, const void *x, std::int64_t incx, void *y, std::int64_t incy) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); auto data_x = get_memory(x); auto data_y = get_memory(y); oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, data_x, incx, 
data_y, incy); -#endif } template inline void rot_impl(sycl::queue &q, std::int64_t n, void *x, std::int64_t incx, void *y, std::int64_t incy, const void *c, const void *s) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else Tc c_value = dpct::get_value(reinterpret_cast(c), q); Ts s_value = dpct::get_value(reinterpret_cast(s), q); auto data_x = get_memory(x); auto data_y = get_memory(y); oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, data_y, incy, c_value, s_value); -#endif } template inline void rotm_impl(sycl::queue &q, std::int64_t n, void *x, int64_t incx, void *y, int64_t incy, const void *param) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_y = get_memory(y); auto data_param = get_memory(param); oneapi::mkl::blas::column_major::rotm(q, n, data_x, incx, data_y, incy, data_param); -#endif } template inline void copy_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, void *y, std::int64_t incy) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_y = get_memory(y); oneapi::mkl::blas::column_major::copy(q, n, data_x, incx, data_y, incy); -#endif } template inline void swap_impl(sycl::queue &q, std::int64_t n, void *x, std::int64_t incx, void *y, std::int64_t incy) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_y = get_memory(y); oneapi::mkl::blas::column_major::swap(q, n, data_x, incx, data_y, incy); -#endif } template inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, 
void *res) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_res = get_memory(res); oneapi::mkl::blas::column_major::asum(q, n, data_x, incx, data_res); -#endif } template inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_res = get_memory(res); oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, oneapi::mkl::index_base::one); -#endif } template inline void iamin_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else auto data_x = get_memory(x); auto data_res = get_memory(res); oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, oneapi::mkl::index_base::one); -#endif } #ifdef __INTEL_MKL__ From cde4bae0783194e0f892495fea7c67aa5053fcca Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 14 Oct 2024 08:26:52 +0800 Subject: [PATCH 17/21] Add guard for min and max Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 7a2ba0fd221c..bf505b7cc9c4 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -560,19 +560,29 @@ inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, template inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) 
{ +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else auto data_x = get_memory(x); auto data_res = get_memory(res); oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, oneapi::mkl::index_base::one); +#endif } template inline void iamin_impl(sycl::queue &q, std::int64_t n, const void *x, std::int64_t incx, std::int64_t *res) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else auto data_x = get_memory(x); auto data_res = get_memory(res); oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, oneapi::mkl::index_base::one); +#endif } #ifdef __INTEL_MKL__ From b11e1999ad9855c70b2dc240bcac4b977d069905 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 14 Oct 2024 08:37:31 +0800 Subject: [PATCH 18/21] Update Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/memory.hpp | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/memory.hpp b/clang/runtime/dpct-rt/include/dpct/memory.hpp index ce8aa699cc81..be5950d61d72 100644 --- a/clang/runtime/dpct-rt/include/dpct/memory.hpp +++ b/clang/runtime/dpct-rt/include/dpct/memory.hpp @@ -915,7 +915,7 @@ static buffer_t get_buffer(const void *ptr) { } /// A wrapper class contains an accessor and an offset. -template class access_wrapper { sycl::accessor accessor; @@ -931,11 +931,17 @@ class access_wrapper { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); offset = (byte_t *)ptr - alloc.alloc_ptr; } + template + access_wrapper( + PtrT ptr, sycl::handler &cgh, + typename std::enable_if_t>, void *>> * = 0) + : access_wrapper((const void *)ptr, cgh) {} /// Get the device pointer. /// /// \returns a device pointer with offset. 
- dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); } + PtrT get_raw_pointer() const { return (PtrT)(&accessor[0] + offset); } }; /// Get the accessor for memory pointed by \p ptr. @@ -944,12 +950,17 @@ class access_wrapper { /// If NULL is passed as an argument, an exception will be thrown. /// \param cgh The command group handler. /// \returns an accessor. -template -static sycl::accessor -get_access(const void *ptr, sycl::handler &cgh) { +template +static auto get_access(const T *ptr, sycl::handler &cgh) { if (ptr) { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); - return alloc.buffer.get_access(cgh); + if constexpr (std::is_same_v, void>) + return alloc.buffer.template get_access(cgh); + else + return alloc.buffer + .template reinterpret(sycl::range<1>(alloc.size / sizeof(T))) + .template get_access(cgh); } else { throw std::runtime_error( "NULL pointer argument in get_access function is invalid"); From 7a9804f58a69e82edadf073241e0750bb9de1d46 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 14 Oct 2024 08:45:29 +0800 Subject: [PATCH 19/21] Revert "Update" This reverts commit b11e1999ad9855c70b2dc240bcac4b977d069905. --- clang/runtime/dpct-rt/include/dpct/memory.hpp | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/memory.hpp b/clang/runtime/dpct-rt/include/dpct/memory.hpp index be5950d61d72..ce8aa699cc81 100644 --- a/clang/runtime/dpct-rt/include/dpct/memory.hpp +++ b/clang/runtime/dpct-rt/include/dpct/memory.hpp @@ -915,7 +915,7 @@ static buffer_t get_buffer(const void *ptr) { } /// A wrapper class contains an accessor and an offset. 
-template class access_wrapper { sycl::accessor accessor; @@ -931,17 +931,11 @@ class access_wrapper { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); offset = (byte_t *)ptr - alloc.alloc_ptr; } - template - access_wrapper( - PtrT ptr, sycl::handler &cgh, - typename std::enable_if_t>, void *>> * = 0) - : access_wrapper((const void *)ptr, cgh) {} /// Get the device pointer. /// /// \returns a device pointer with offset. - PtrT get_raw_pointer() const { return (PtrT)(&accessor[0] + offset); } + dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); } }; /// Get the accessor for memory pointed by \p ptr. @@ -950,17 +944,12 @@ class access_wrapper { /// If NULL is passed as an argument, an exception will be thrown. /// \param cgh The command group handler. /// \returns an accessor. -template -static auto get_access(const T *ptr, sycl::handler &cgh) { +template +static sycl::accessor +get_access(const void *ptr, sycl::handler &cgh) { if (ptr) { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); - if constexpr (std::is_same_v, void>) - return alloc.buffer.template get_access(cgh); - else - return alloc.buffer - .template reinterpret(sycl::range<1>(alloc.size / sizeof(T))) - .template get_access(cgh); + return alloc.buffer.get_access(cgh); } else { throw std::runtime_error( "NULL pointer argument in get_access function is invalid"); From d9367bc21bdff91d60d8f6bf8293e7bf8756d4e1 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Mon, 14 Oct 2024 11:18:38 +0800 Subject: [PATCH 20/21] Fix helper Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index bf505b7cc9c4..841c63570b7b 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -479,6 +479,7 @@ inline 
void dotuc(sycl::queue &q, std::int64_t n, const void *x, reinterpret_cast *>(result)); break; } +#ifdef __INTEL_MKL__ case detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { @@ -488,6 +489,7 @@ inline void dotuc(sycl::queue &q, std::int64_t n, const void *x, reinterpret_cast(result)); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } @@ -528,7 +530,7 @@ inline void rotm_impl(sycl::queue &q, std::int64_t n, void *x, int64_t incx, void *y, int64_t incy, const void *param) { auto data_x = get_memory(x); auto data_y = get_memory(y); - auto data_param = get_memory(param); + auto data_param = get_memory(const_cast(param)); oneapi::mkl::blas::column_major::rotm(q, n, data_x, incx, data_y, incy, data_param); } @@ -2213,11 +2215,13 @@ inline void nrm2(descriptor_ptr desc_ptr, std::int64_t n, const void *x, result); break; } +#ifdef __INTEL_MKL__ case ::dpct::detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half): { ::dpct::detail::nrm2_impl(q, n, x, incx, result); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } @@ -2296,6 +2300,7 @@ inline void scal(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, q, n, alpha, x, incx); break; } +#ifdef __INTEL_MKL__ case ::dpct::detail::get_type_combination_id(library_data_t::real_half): { float alpha_value = dpct::get_value(reinterpret_cast(alpha), q); @@ -2304,6 +2309,7 @@ inline void scal(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, incx); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } @@ -2350,8 +2356,9 @@ inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, q, n, alpha, x, incx, y, incy); break; } +#ifdef __INTEL_MKL__ case ::dpct::detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_float): { + 
library_data_t::real_half): { float alpha_value = dpct::get_value(reinterpret_cast(alpha), q); sycl::half alaph_half(alpha_value); @@ -2359,6 +2366,7 @@ inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, incx, y, incy); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } @@ -2406,6 +2414,7 @@ inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, q, n, x, incx, y, incy, c, s); break; } +#ifdef __INTEL_MKL__ case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, library_data_t::complex_float): { ::dpct::detail::rot_impl, float, std::complex>( @@ -2432,6 +2441,7 @@ inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, s); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } From 6f31fa1b81e263ca6f009f514aea551532991313 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Tue, 15 Oct 2024 09:39:57 +0800 Subject: [PATCH 21/21] Fix case Signed-off-by: Jiang, Zhiwei --- clang/runtime/dpct-rt/include/dpct/blas_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 841c63570b7b..5fe6c533d1a1 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -2358,7 +2358,7 @@ inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, } #ifdef __INTEL_MKL__ case ::dpct::detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { + library_data_t::real_float): { float alpha_value = dpct::get_value(reinterpret_cast(alpha), q); sycl::half alaph_half(alpha_value);