diff --git a/clang/lib/DPCT/APINamesCUBLAS.inc b/clang/lib/DPCT/APINamesCUBLAS.inc index 0e7ac16938b2..13b0ae743522 100644 --- a/clang/lib/DPCT/APINamesCUBLAS.inc +++ b/clang/lib/DPCT/APINamesCUBLAS.inc @@ -127,58 +127,107 @@ GEMM(cublasZgemm_v2_64, "std::complex", true) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY("cublasNrm2Ex", - CALL(MapNames::getLibraryHelperNamespace() + "nrm2", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) + CALL_FACTORY_ENTRY( + "cublasNrm2Ex", + CALL(MapNames::getLibraryHelperNamespace() + "blas::nrm2", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasNrm2Ex_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::nrm2", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasDotEx", - CALL(MapNames::getLibraryHelperNamespace() + "dot", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + "blas::dot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasDotEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::dot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasDotcEx", - CALL(MapNames::getLibraryHelperNamespace() + "dotc", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + + "blas::dotc", + ARG(0), ARG(1), 
ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasDotcEx_64", + CALL(MapNames::getLibraryHelperNamespace() + + "blas::dotc", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY("cublasScalEx", - CALL(MapNames::getLibraryHelperNamespace() + "scal", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) + CALL_FACTORY_ENTRY( + "cublasScalEx", + CALL(MapNames::getLibraryHelperNamespace() + "blas::scal", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasScalEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::scal", ARG(0), + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasAxpyEx", - CALL(MapNames::getLibraryHelperNamespace() + "axpy", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9))))) + CALL(MapNames::getLibraryHelperNamespace() + + "blas::axpy", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasAxpyEx_64", + CALL(MapNames::getLibraryHelperNamespace() + + "blas::axpy", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY("cublasRotEx", - CALL(MapNames::getLibraryHelperNamespace() + "rot", MEMBER_CALL(ARG(0), true, "get_queue"), - ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), - ARG(7), ARG(8), ARG(9), ARG(10))))) + 
CALL(MapNames::getLibraryHelperNamespace() + "blas::rot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9), ARG(10))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("cublasRotEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::rot", + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), + ARG(6), ARG(7), ARG(8), ARG(9), ARG(10))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY( - "cublasGemmEx", - CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), + "cublasGemmBatchedEx", + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm_batch", ARG(0), BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17))))) - + ARG(3), ARG(4), ARG(5), ARG(6), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(7), BOOL(true), + BOOL(false)), + ARG(8), ARG(9), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(10), BOOL(true), + BOOL(false)), + ARG(11), ARG(12), ARG(13), + DOUBLE_POINTER_CONST_CAST(makeLiteral("void"), ARG(14), + BOOL(false), BOOL(false)), + ARG(15), ARG(16), ARG(17), ARG(18))))) ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, CALL_FACTORY_ENTRY( - "cublasGemmBatchedEx", + "cublasGemmBatchedEx_64", CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm_batch", ARG(0), BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), @@ -203,6 +252,16 @@ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), ARG(18), ARG(19), ARG(20), ARG(21))))) +ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( + 
HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY( + "cublasGemmStridedBatchedEx_64", + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm_batch", ARG(0), + BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), + ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), ARG(17), + ARG(18), ARG(19), ARG(20), ARG(21))))) #define SYRKX(FUNC) \ ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ @@ -658,29 +717,65 @@ WARNING_FACTORY_ENTRY( Diagnostics::TRNA_WARNING_ERROR_HANDLING_API_COMMENTED, ARG("The call was replaced by a placeholder string")) -ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY( - "cublasSgemmEx", - CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), - BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), - ARG(MapNames::getLibraryHelperNamespace() + - "library_data_t::real_float"))))) - -ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY( - "cublasCgemmEx", - CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), - BLAS_ENUM_ARG(1, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - BLAS_ENUM_ARG(2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), - ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), ARG(10), - ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), - ARG(MapNames::getLibraryHelperNamespace() + - "library_data_t::complex_float"))))) +#define GEMM_EX(NAME, COMPUTE_TYPE) \ + ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ + HelperFeatureEnum::device_ext, \ + CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::gemm", ARG(0), \ + 
BLAS_ENUM_ARG(1, \ + clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + BLAS_ENUM_ARG(2, \ + clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), \ + ARG(10), ARG(11), ARG(12), ARG(13), ARG(14), ARG(15), ARG(16), \ + ARG(COMPUTE_TYPE))))) +GEMM_EX(cublasSgemmEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") +GEMM_EX(cublasCgemmEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasCgemm3mEx, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasGemmEx, 17) +GEMM_EX(cublasSgemmEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::real_float") +GEMM_EX(cublasCgemmEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasCgemm3mEx_64, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +GEMM_EX(cublasGemmEx_64, 17) +#undef GEMM_EX + +#define SYHERK(NAME, IS_HERMITIAN, COMPUTE_TYPE) \ + ASSIGNABLE_FACTORY(FEATURE_REQUEST_FACTORY( \ + HelperFeatureEnum::device_ext, \ + CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + \ + "blas::syherk<" + #IS_HERMITIAN + ">", \ + ARG(0), \ + BLAS_ENUM_ARG( \ + 1, clang::dpct::BLASEnumExpr::BLASEnumType::Uplo), \ + BLAS_ENUM_ARG( \ + 2, clang::dpct::BLASEnumExpr::BLASEnumType::Trans), \ + ARG(3), ARG(4), ARG(5), ARG(6), ARG(7), ARG(8), ARG(9), \ + ARG(10), ARG(11), ARG(12), ARG(COMPUTE_TYPE))))) +SYHERK(cublasCsyrkEx, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCsyrk3mEx, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherkEx, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherk3mEx, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCsyrkEx_64, false, + MapNames::getLibraryHelperNamespace() + 
"library_data_t::complex_float") +SYHERK(cublasCsyrk3mEx_64, false, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherkEx_64, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +SYHERK(cublasCherk3mEx_64, true, + MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float") +#undef SYHERK #define SYRK(NAME, TYPE, IS_COMPLEX) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ @@ -1605,6 +1700,56 @@ GEMM_BATCH(cublasCgemmStridedBatched, "std::complex", true) GEMM_BATCH(cublasZgemmStridedBatched, "std::complex", true) #undef GEMM_BATCH +#define COPY_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::copy", ARG(0), \ + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7)))) +COPY_EX(cublasCopyEx) +COPY_EX(cublasCopyEx_64) +#undef COPY_EX + +#define SWAP_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, \ + CALL(MapNames::getLibraryHelperNamespace() + "blas::swap", ARG(0), \ + ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), ARG(7)))) +SWAP_EX(cublasSwapEx) +SWAP_EX(cublasSwapEx_64) +#undef SWAP_EX + +#define ASUM_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::asum", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6)))) +ASUM_EX(cublasAsumEx) +ASUM_EX(cublasAsumEx_64) +#undef ASUM_EX + +#define ROTM_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::rotm", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5), ARG(6), \ + ARG(7), ARG(8), ARG(9)))) +ROTM_EX(cublasRotmEx) +ROTM_EX(cublasRotmEx_64) +#undef ROTM_EX + +#define IAMAX_EX(NAME) \ + ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamax", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)))) +IAMAX_EX(cublasIamaxEx) +IAMAX_EX(cublasIamaxEx_64) + +#define IAMIN_EX(NAME) \ 
+ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ + #NAME, CALL(MapNames::getLibraryHelperNamespace() + "blas::iamin", \ + ARG(0), ARG(1), ARG(2), ARG(3), ARG(4), ARG(5)))) +IAMIN_EX(cublasIaminEx) +IAMIN_EX(cublasIaminEx_64) +#undef IAMIN_EX + #define TBSV(NAME, TYPE) \ ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY( \ #NAME, \ diff --git a/clang/lib/DPCT/APINames_cuBLAS.inc b/clang/lib/DPCT/APINames_cuBLAS.inc index 7154930ac2e4..66f5d64cfd48 100644 --- a/clang/lib/DPCT/APINames_cuBLAS.inc +++ b/clang/lib/DPCT/APINames_cuBLAS.inc @@ -279,10 +279,10 @@ ENTRY(cublasCgemmEx, cublasCgemmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmEx, cublasGemmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmBatchedEx, cublasGemmBatchedEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasGemmStridedBatchedEx, cublasGemmStridedBatchedEx, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCsyrkEx, cublasCsyrkEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCsyrk3mEx, cublasCsyrk3mEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherkEx, cublasCherkEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherk3mEx, cublasCherk3mEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasCsyrkEx, cublasCsyrkEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCsyrk3mEx, cublasCsyrk3mEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherkEx, cublasCherkEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherk3mEx, cublasCherk3mEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasNrm2Ex, cublasNrm2Ex, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasAxpyEx, cublasAxpyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotEx, cublasDotEx, true, NO_FLAG, P4, "DPCT1020") @@ -525,19 +525,19 @@ ENTRY(cublasDtrmm, cublasDtrmm, true, NO_FLAG, P4, "Successful") ENTRY(cublasCtrmm, cublasCtrmm, true, NO_FLAG, P4, "Successful") ENTRY(cublasZtrmm, cublasZtrmm, true, NO_FLAG, P4, "Successful") -ENTRY(cublasAsumEx, cublasAsumEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasAsumEx, cublasAsumEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCgemm3mBatched, cublasCgemm3mBatched, 
false, NO_FLAG, P4, "comment") -ENTRY(cublasCgemm3mEx, cublasCgemm3mEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasCopyEx, cublasCopyEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasCgemm3mEx, cublasCgemm3mEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCopyEx, cublasCopyEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDotcEx, cublasDotcEx, true, NO_FLAG, P4, "Successful") ENTRY(cublasGetCudartVersion, cublasGetCudartVersion, false, NO_FLAG, P4, "comment") -ENTRY(cublasIamaxEx, cublasIamaxEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasIaminEx, cublasIaminEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasIamaxEx, cublasIamaxEx, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIaminEx, cublasIaminEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotEx, cublasRotEx, true, NO_FLAG, P4, "Successful") ENTRY(cublasRotgEx, cublasRotgEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasRotmEx, cublasRotmEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasRotmEx, cublasRotmEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasRotmgEx, cublasRotmgEx, false, NO_FLAG, P4, "comment") -ENTRY(cublasSwapEx, cublasSwapEx, false, NO_FLAG, P4, "comment") +ENTRY(cublasSwapEx, cublasSwapEx, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasUint8gemmBias, cublasUint8gemmBias, false, NO_FLAG, P4, "comment") ENTRY(cublasXerbla, cublasXerbla, false, NO_FLAG, P4, "comment") ENTRY(cublasXtGetNumBoards, cublasXtGetNumBoards, false, NO_FLAG, P4, "comment") @@ -745,18 +745,18 @@ ENTRY(cublasSetVectorAsync_64, cublasSetVectorAsync_64, true, NO_FLAG, P4, "DPCT ENTRY(cublasGetVectorAsync_64, cublasGetVectorAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") ENTRY(cublasSetMatrixAsync_64, cublasSetMatrixAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") ENTRY(cublasGetMatrixAsync_64, cublasGetMatrixAsync_64, true, NO_FLAG, P4, "DPCT1018/DPCT1020") -ENTRY(cublasNrm2Ex_64, cublasNrm2Ex_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasDotEx_64, cublasDotEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasDotcEx_64, cublasDotcEx_64, false, 
NO_FLAG, P4, "comment") -ENTRY(cublasScalEx_64, cublasScalEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasAxpyEx_64, cublasAxpyEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCopyEx_64, cublasCopyEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasSwapEx_64, cublasSwapEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasIamaxEx_64, cublasIamaxEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasIaminEx_64, cublasIaminEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasAsumEx_64, cublasAsumEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasRotEx_64, cublasRotEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasRotmEx_64, cublasRotmEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasNrm2Ex_64, cublasNrm2Ex_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasDotEx_64, cublasDotEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasDotcEx_64, cublasDotcEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasScalEx_64, cublasScalEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasAxpyEx_64, cublasAxpyEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCopyEx_64, cublasCopyEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasSwapEx_64, cublasSwapEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIamaxEx_64, cublasIamaxEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasIaminEx_64, cublasIaminEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasAsumEx_64, cublasAsumEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasRotEx_64, cublasRotEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasRotmEx_64, cublasRotmEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSgemvBatched_64, cublasSgemvBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasDgemvBatched_64, cublasDgemvBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemvBatched_64, cublasCgemvBatched_64, false, NO_FLAG, P4, "comment") @@ -774,16 +774,16 @@ ENTRY(cublasHSSgemvStridedBatched_64, cublasHSSgemvStridedBatched_64, false, NO_ ENTRY(cublasTSTgemvStridedBatched_64, cublasTSTgemvStridedBatched_64, false, NO_FLAG, P4, "comment") 
ENTRY(cublasTSSgemvStridedBatched_64, cublasTSSgemvStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemm3m_64, cublasCgemm3m_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasCgemm3mEx_64, cublasCgemm3mEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasCgemm3mEx_64, cublasCgemm3mEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasZgemm3m_64, cublasZgemm3m_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasHgemm_64, cublasHgemm_64, true, NO_FLAG, P4, "DPCT1020") -ENTRY(cublasSgemmEx_64, cublasSgemmEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmEx_64, cublasGemmEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCgemmEx_64, cublasCgemmEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCsyrkEx_64, cublasCsyrkEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCsyrk3mEx_64, cublasCsyrk3mEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherkEx_64, cublasCherkEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasCherk3mEx_64, cublasCherk3mEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasSgemmEx_64, cublasSgemmEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasGemmEx_64, cublasGemmEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCgemmEx_64, cublasCgemmEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCsyrkEx_64, cublasCsyrkEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCsyrk3mEx_64, cublasCsyrk3mEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherkEx_64, cublasCherkEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasCherk3mEx_64, cublasCherk3mEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSsyrkx_64, cublasSsyrkx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasDsyrkx_64, cublasDsyrkx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasCsyrkx_64, cublasCsyrkx_64, true, NO_FLAG, P4, "DPCT1020") @@ -802,8 +802,8 @@ ENTRY(cublasDgemmStridedBatched_64, cublasDgemmStridedBatched_64, false, NO_FLAG ENTRY(cublasCgemmStridedBatched_64, cublasCgemmStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasCgemm3mStridedBatched_64, 
cublasCgemm3mStridedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasZgemmStridedBatched_64, cublasZgemmStridedBatched_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmBatchedEx_64, cublasGemmBatchedEx_64, false, NO_FLAG, P4, "comment") -ENTRY(cublasGemmStridedBatchedEx_64, cublasGemmStridedBatchedEx_64, false, NO_FLAG, P4, "comment") +ENTRY(cublasGemmBatchedEx_64, cublasGemmBatchedEx_64, true, NO_FLAG, P4, "DPCT1020") +ENTRY(cublasGemmStridedBatchedEx_64, cublasGemmStridedBatchedEx_64, true, NO_FLAG, P4, "DPCT1020") ENTRY(cublasSgemmGroupedBatched, cublasSgemmGroupedBatched, false, NO_FLAG, P4, "comment") ENTRY(cublasSgemmGroupedBatched_64, cublasSgemmGroupedBatched_64, false, NO_FLAG, P4, "comment") ENTRY(cublasDgemmGroupedBatched, cublasDgemmGroupedBatched, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 6224f1183290..ed3ba6eea33d 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -4186,11 +4186,14 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasSgeqrfBatched", "cublasDgeqrfBatched", "cublasCgeqrfBatched", "cublasZgeqrfBatched", "cublasSgelsBatched", "cublasDgelsBatched", "cublasCgelsBatched", "cublasZgelsBatched", "cublasGemmEx", - "cublasSgemmEx", "cublasCgemmEx", "cublasNrm2Ex", "cublasDotEx", - "cublasDotcEx", "cublasScalEx", "cublasAxpyEx", "cublasRotEx", - "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", "cublasSdgmm", - "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", "cublasSgeam", - "cublasDgeam", "cublasCgeam", "cublasZgeam", + "cublasSgemmEx", "cublasCgemmEx", "cublasCgemm3mEx", "cublasNrm2Ex", + "cublasDotEx", "cublasDotcEx", "cublasScalEx", "cublasAxpyEx", + "cublasRotEx", "cublasGemmBatchedEx", "cublasGemmStridedBatchedEx", + "cublasSdgmm", "cublasDdgmm", "cublasCdgmm", "cublasZdgmm", + "cublasSgeam", "cublasDgeam", "cublasCgeam", "cublasZgeam", + "cublasCopyEx", "cublasSwapEx", "cublasIamaxEx", "cublasIaminEx", + 
"cublasAsumEx", "cublasRotmEx", "cublasCsyrkEx", "cublasCsyrk3mEx", + "cublasCherkEx", "cublasCherk3mEx", /*Legacy API*/ "cublasInit", "cublasShutdown", "cublasGetError", "cublasSetKernelStream", "cublasGetVersion", @@ -4299,6 +4302,15 @@ void BLASFunctionCallRule::registerMatcher(MatchFinder &MF) { "cublasCsyrkx_64", "cublasZsyrkx_64", "cublasCherkx_64", "cublasZherkx_64", "cublasHgemm_64", "cublasCgemm3m_64", "cublasZgemm3m_64", + /*extension*/ + "cublasNrm2Ex_64", "cublasDotEx_64", "cublasDotcEx_64", + "cublasScalEx_64", "cublasAxpyEx_64", "cublasRotEx_64", + "cublasGemmBatchedEx_64", "cublasGemmStridedBatchedEx_64", + "cublasCopyEx_64", "cublasSwapEx_64", "cublasIamaxEx_64", + "cublasIaminEx_64", "cublasAsumEx_64", "cublasRotmEx_64", + "cublasSgemmEx_64", "cublasCgemmEx_64", "cublasCgemm3mEx_64", + "cublasGemmEx_64", "cublasCsyrkEx_64", "cublasCsyrk3mEx_64", + "cublasCherkEx_64", "cublasCherk3mEx_64", /*cublasLt*/ "cublasLtCreate", "cublasLtDestroy", "cublasLtMatmulDescCreate", "cublasLtMatmulDescDestroy", "cublasLtMatmulDescSetAttribute", @@ -4467,7 +4479,8 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) { FuncName == "cublasDgemmBatched" || FuncName == "cublasCgemmBatched" || FuncName == "cublasZgemmBatched" || FuncName == "cublasStrsmBatched" || FuncName == "cublasDtrsmBatched" || FuncName == "cublasCtrsmBatched" || - FuncName == "cublasZtrsmBatched" || FuncName == "cublasGemmBatchedEx")) { + FuncName == "cublasZtrsmBatched" || FuncName == "cublasGemmBatchedEx" || + FuncName == "cublasGemmBatchedEx_64")) { report(FuncNameBegin, Diagnostics::API_NOT_MIGRATED, false, FuncName); return; } diff --git a/clang/lib/DPCT/MapNames.cpp b/clang/lib/DPCT/MapNames.cpp index dd307b7d2f36..c4f89eba62a0 100644 --- a/clang/lib/DPCT/MapNames.cpp +++ b/clang/lib/DPCT/MapNames.cpp @@ -2154,18 +2154,29 @@ void MapNames::setExplicitNamespaceMap( "oneapi::mkl::blas::column_major::gemm_batch"}, {"cublasZgemmStridedBatched", 
"oneapi::mkl::blas::column_major::gemm_batch"}, - {"cublasNrm2Ex", getLibraryHelperNamespace() + "nrm2_ex"}, - {"cublasDotEx", getLibraryHelperNamespace() + "dot_ex"}, - {"cublasDotcEx", getLibraryHelperNamespace() + "dotc_ex"}, - {"cublasScalEx", getLibraryHelperNamespace() + "scal_ex"}, - {"cublasAxpyEx", getLibraryHelperNamespace() + "axpy_ex"}, - {"cublasRotEx", getLibraryHelperNamespace() + "rot_ex"}, + {"cublasNrm2Ex", getLibraryHelperNamespace() + "blas::nrm2"}, + {"cublasNrm2Ex_64", getLibraryHelperNamespace() + "blas::nrm2"}, + {"cublasDotEx", getLibraryHelperNamespace() + "blas::dot"}, + {"cublasDotEx_64", getLibraryHelperNamespace() + "blas::dot"}, + {"cublasDotcEx", getLibraryHelperNamespace() + "blas::dotc"}, + {"cublasDotcEx_64", getLibraryHelperNamespace() + "blas::dotc"}, + {"cublasScalEx", getLibraryHelperNamespace() + "blas::scal"}, + {"cublasScalEx_64", getLibraryHelperNamespace() + "blas::scal"}, + {"cublasAxpyEx", getLibraryHelperNamespace() + "blas::axpy"}, + {"cublasAxpyEx_64", getLibraryHelperNamespace() + "blas::axpy"}, + {"cublasRotEx", getLibraryHelperNamespace() + "blas::rot"}, + {"cublasRotEx_64", getLibraryHelperNamespace() + "blas::rot"}, {"cublasGemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasSgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasCgemmEx", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCgemm3mEx", getLibraryHelperNamespace() + "blas::gemm"}, {"cublasGemmBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmBatchedEx_64", + getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasGemmStridedBatchedEx", getLibraryHelperNamespace() + "blas::gemm_batch"}, + {"cublasGemmStridedBatchedEx_64", + getLibraryHelperNamespace() + "blas::gemm_batch"}, {"cublasSsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasDsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, {"cublasCsyrkx", getLibraryHelperNamespace() + "blas::syrk"}, @@ -2529,6 +2540,32 @@ void 
MapNames::setExplicitNamespaceMap( {"cublasDtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasCtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, {"cublasZtrmm_v2_64", getLibraryHelperNamespace() + "blas::trmm"}, + {"cublasCopyEx", getLibraryHelperNamespace() + "blas::copy"}, + {"cublasSwapEx", getLibraryHelperNamespace() + "blas::swap"}, + {"cublasIamaxEx", getLibraryHelperNamespace() + "blas::iamax"}, + {"cublasIaminEx", getLibraryHelperNamespace() + "blas::iamin"}, + {"cublasAsumEx", getLibraryHelperNamespace() + "blas::asum"}, + {"cublasRotmEx", getLibraryHelperNamespace() + "blas::rotm"}, + {"cublasCopyEx_64", getLibraryHelperNamespace() + "blas::copy"}, + {"cublasSwapEx_64", getLibraryHelperNamespace() + "blas::swap"}, + {"cublasIamaxEx_64", getLibraryHelperNamespace() + "blas::iamax"}, + {"cublasIaminEx_64", getLibraryHelperNamespace() + "blas::iamin"}, + {"cublasAsumEx_64", getLibraryHelperNamespace() + "blas::asum"}, + {"cublasRotmEx_64", getLibraryHelperNamespace() + "blas::rotm"}, + {"cublasSgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCgemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCgemm3mEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasGemmEx_64", getLibraryHelperNamespace() + "blas::gemm"}, + {"cublasCsyrkEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrk3mEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherkEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherk3mEx", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrkEx_64", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCsyrk3mEx_64", + getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherkEx_64", getLibraryHelperNamespace() + "blas::syherk"}, + {"cublasCherk3mEx_64", + getLibraryHelperNamespace() + "blas::syherk"}, // cublasLt {"cublasLtCreate", "new " + getLibraryHelperNamespace() + "blas_gemm::experimental::descriptor"}, diff --git 
a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp index 7c34a579229b..5fe6c533d1a1 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_utils.hpp @@ -384,12 +384,8 @@ class working_memory { #endif template -inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx, - void *result) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else +inline void nrm2_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *result) { #ifdef DPCT_USM_LEVEL_NONE auto x_buffer = dpct::get_buffer(x); auto r_buffer = @@ -402,16 +398,12 @@ inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx, oneapi::mkl::blas::column_major::nrm2(q, n, reinterpret_cast(x), incx, res_mem.get_ptr()); #endif -#endif } template -inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, - const Txy *y, int incy, Tr *result) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else +inline void dotuc_impl(sycl::queue &q, std::int64_t n, const Txy *x, + std::int64_t incx, const Txy *y, std::int64_t incy, + Tr *result) { #ifdef DPCT_USM_LEVEL_NONE auto x_buffer = dpct::get_buffer(x); auto y_buffer = dpct::get_buffer(y); @@ -434,41 +426,44 @@ inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, if constexpr (std::is_same_v> || std::is_same_v>) { if constexpr (is_conjugate) - oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, res_mem.get_ptr()); + oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, + res_mem.get_ptr()); else - oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, res_mem.get_ptr()); + oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, + res_mem.get_ptr()); } else - 
oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr()); -#endif + oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, + res_mem.get_ptr()); #endif } template -inline void dotuc(sycl::queue &q, int n, const void *x, - library_data_t x_type, int incx, const void *y, - library_data_t y_type, int incy, void *result, - library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, y_type, result_type); +inline void dotuc(sycl::queue &q, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, const void *y, + library_data_t y_type, std::int64_t incy, void *result, + library_data_t result_type) { + std::uint64_t key = + detail::get_type_combination_id(x_type, y_type, result_type); switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::dotuc_impl(q, n, reinterpret_cast(x), + incx, reinterpret_cast(y), + incy, reinterpret_cast(result)); break; } - case detail::get_type_combination_id(library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::dotuc_impl(q, n, reinterpret_cast(x), + incx, reinterpret_cast(y), + incy, reinterpret_cast(result)); break; } case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float, - library_data_t::complex_float): { + library_data_t::complex_float, + library_data_t::complex_float): { detail::dotuc_impl( 
q, n, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy, @@ -476,71 +471,119 @@ inline void dotuc(sycl::queue &q, int n, const void *x, break; } case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double, - library_data_t::complex_double): { + library_data_t::complex_double, + library_data_t::complex_double): { detail::dotuc_impl( q, n, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy, reinterpret_cast *>(result)); break; } - case detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half): { +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half): { detail::dotuc_impl( q, n, reinterpret_cast(x), incx, reinterpret_cast(y), incy, reinterpret_cast(result)); break; } +#endif default: throw std::runtime_error("the combination of data type is unsupported"); } } template -inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x, - int incx) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else +inline void scal_impl(sycl::queue &q, std::int64_t n, const void *alpha, + void *x, std::int64_t incx) { Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); auto data_x = get_memory(x); - oneapi::mkl::blas::column_major::scal(q, n, alpha_val, - data_x, incx); -#endif + oneapi::mkl::blas::column_major::scal(q, n, alpha_val, data_x, incx); } template -inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x, - int incx, void *y, int incy) { +inline void axpy_impl(sycl::queue &q, std::int64_t n, const void *alpha, + const void *x, std::int64_t incx, void *y, + std::int64_t incy) { + Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); + auto data_x = get_memory(x); + auto data_y = get_memory(y); + 
oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, data_x, incx, data_y, + incy); +} + +template +inline void rot_impl(sycl::queue &q, std::int64_t n, void *x, std::int64_t incx, + void *y, std::int64_t incy, const void *c, const void *s) { + Tc c_value = dpct::get_value(reinterpret_cast(c), q); + Ts s_value = dpct::get_value(reinterpret_cast(s), q); + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, data_y, incy, + c_value, s_value); +} + +template +inline void rotm_impl(sycl::queue &q, std::int64_t n, void *x, int64_t incx, + void *y, int64_t incy, const void *param) { + auto data_x = get_memory(x); + auto data_y = get_memory(y); + auto data_param = get_memory(const_cast(param)); + oneapi::mkl::blas::column_major::rotm(q, n, data_x, incx, data_y, incy, + data_param); +} + +template +inline void copy_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *y, std::int64_t incy) { + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::copy(q, n, data_x, incx, data_y, incy); +} + +template +inline void swap_impl(sycl::queue &q, std::int64_t n, void *x, + std::int64_t incx, void *y, std::int64_t incy) { + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::swap(q, n, data_x, incx, data_y, incy); +} + +template +inline void asum_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, void *res) { + auto data_x = get_memory(x); + auto data_res = get_memory(res); + oneapi::mkl::blas::column_major::asum(q, n, data_x, incx, data_res); +} + +template +inline void iamax_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, std::int64_t *res) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); #else - Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); - auto data_x 
= get_memory(x); - auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, - data_x, incx, - data_y, incy); + auto data_x = get_memory(x); + auto data_res = get_memory(res); + oneapi::mkl::blas::column_major::iamax(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); #endif } -template -inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y, - int incy, const void *c, const void *s) { +template +inline void iamin_impl(sycl::queue &q, std::int64_t n, const void *x, + std::int64_t incx, std::int64_t *res) { #ifndef __INTEL_MKL__ throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " "Project does not support this API."); #else - Tc c_value = dpct::get_value(reinterpret_cast(c), q); - Ts s_value = dpct::get_value(reinterpret_cast(s), q); - auto data_x = get_memory(x); - auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, - data_y, incy, c_value, - s_value); + auto data_x = get_memory(x); + auto data_res = get_memory(res); + oneapi::mkl::blas::column_major::iamin(q, n, data_x, incx, data_res, + oneapi::mkl::index_base::one); #endif } @@ -636,6 +679,31 @@ inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, batch_size DPCT_COMPUTE_MODE_ARG); } +template +inline void syherk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, int n, int k, + const void *alpha, const void *a, int lda, + const void *beta, void *c, + int ldc DPCT_COMPUTE_MODE_PARAM) { + auto data_a = get_memory(a); + auto data_c = get_memory(c); + if constexpr (is_hermitian) { + auto alpha_value = dpct::get_value( + reinterpret_cast(alpha), q); + auto beta_value = dpct::get_value( + reinterpret_cast(beta), q); + oneapi::mkl::blas::column_major::herk(q, uplo, trans, n, k, alpha_value, + data_a, lda, beta_value, data_c, + ldc DPCT_COMPUTE_MODE_ARG); + } else { + T alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + T beta_value = 
dpct::get_value(reinterpret_cast(beta), q); + oneapi::mkl::blas::column_major::syrk(q, uplo, trans, n, k, alpha_value, + data_a, lda, beta_value, data_c, + ldc DPCT_COMPUTE_MODE_ARG); + } +} + template inline void rk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, oneapi::mkl::transpose trans, int n, int k, const T *alpha, @@ -1295,245 +1363,6 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[], #endif } -/// Computes the Euclidean norm of a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, void *result, library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, result_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::nrm2_impl, float>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::nrm2_impl, double>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::nrm2_impl( - q, n, x, incx, result); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes the dot product of 
two vectors. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dot(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the dot product of two vectors, conjugating the first vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the product of a vector by a scalar. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. 
-inline void scal(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, void *x, library_data_t x_type, - int incx) { - std::uint64_t key = detail::get_type_combination_id(x_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_double): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float): { - detail::scal_impl, std::complex>(q, n, alpha, - x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double): { - detail::scal_impl, std::complex>( - q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_half): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::scal_impl(q, n, &alaph_half, x, incx); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes a vector-scalar product and adds the result to a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. 
-inline void axpy(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, const void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy) { - std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::axpy_impl(q, n, &alaph_half, x, incx, y, incy); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Performs rotation of points in the plane. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [in] c Scaling factor. -/// \param [in] s Scaling factor. -/// \param [in] cs_type Data type of the scaling factors. 
-inline void rot(sycl::queue &q, int n, void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy, - const void *c, const void *s, library_data_t cs_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::rot_impl, float, float>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::rot_impl, double, double>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::rot_impl, float, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::rot_impl, double, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_bfloat16, - library_data_t::real_bfloat16): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - namespace blas { namespace detail { inline library_data_t compute_type_to_library_data_t(compute_type ct) { @@ -2106,6 +1935,72 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans, } } 
+/// Performs a symmetric/hermitian rank-k update. +/// \tparam is_hermitian True means current matrix is hermitian. +/// \param [in] desc_ptr Descriptor. +/// \param [in] uplo Specifies whether matrix c is upper or lower triangular. +/// \param [in] trans Specifies op(a), the transposition operation applied to +/// matrix a. +/// \param [in] n Number of rows and columns of matrix c. +/// \param [in] k Number of columns of matrix op(a). +/// \param [in] alpha Scaling factor for the rank-k update. +/// \param [in] a Input matrix a. +/// \param [in] a_type Data type of the matrix a. +/// \param [in] lda Leading dimension of the matrix a. +/// \param [in] beta Scaling factor for the rank-k update. +/// \param [in, out] c Input/Output matrix c. +/// \param [in] c_type Data type of the matrix c. +/// \param [in] ldc Leading dimension of the matrix c. +/// \param [in] ct Compute type. +template +inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + const void *alpha, const void *a, library_data_t a_type, + std::int64_t lda, const void *beta, void *c, + library_data_t c_type, std::int64_t ldc, + std::variant ct) { + sycl::queue q = desc_ptr->get_queue(); +#ifdef __INTEL_MKL__ + oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset; + if (auto ct_p = std::get_if(&ct)) { + cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(), + a_type == library_data_t::complex_float || + a_type == library_data_t::complex_double); + } else { + cm = deduce_compute_mode(std::nullopt, desc_ptr->get_math_mode(), + a_type == library_data_t::complex_float || + a_type == library_data_t::complex_double); + } +#endif + std::uint64_t key = dpct::detail::get_type_combination_id(a_type, c_type); + if (!is_hermitian && + dpct::detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float) == key) { + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, a, lda, + beta, 
c, ldc DPCT_COMPUTE_MODE_ARG); + } else if (!is_hermitian && dpct::detail::get_type_combination_id( + library_data_t::real_double, + library_data_t::real_double) == key) { + dpct::detail::syherk_impl(q, uplo, trans, n, k, alpha, a, + lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else if (dpct::detail::get_type_combination_id( + library_data_t::complex_float, + library_data_t::complex_float) == key) { + dpct::detail::syherk_impl>( + q, uplo, trans, n, k, alpha, a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else if (dpct::detail::get_type_combination_id( + library_data_t::complex_double, + library_data_t::complex_double) == key) { + dpct::detail::syherk_impl>( + q, uplo, trans, n, k, alpha, a, lda, beta, c, + ldc DPCT_COMPUTE_MODE_ARG); + } else { + throw std::runtime_error("the combination of data type is unsupported"); + } +} + /// This routines perform a special rank-k update of a symmetric matrix C by /// general matrices A and B. /// \param [in] desc_ptr Descriptor. @@ -2283,19 +2178,548 @@ inline void trmm(descriptor_ptr desc_ptr, oneapi::mkl::side left_right, data_c, ldc DPCT_COMPUTE_MODE_ARG); } -/// Finds the least squares solutions for a batch of overdetermined linear -/// systems. Uses the QR factorization to solve a grouped batch of linear -/// systems with full rank matrices. +/// Computes the Euclidean norm of a vector. /// \param [in] desc_ptr Descriptor. -/// \param [in] trans Operation applied to \p a. -/// \param [in] m The number of rows of \p a. -/// \param [in] n The number of columns of \p a. -/// \param [in] nrhs The number of columns of \p b. -/// \param [in, out] a Array of pointers to matrices. -/// \param [in] lda The leading dimension of \p a. -/// \param [in, out] b Array of pointers to matrices. -/// \param [in] ldb The leading dimension of \p b. -/// \param [out] info Set to 0 if no error. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. 
+/// \param [in] incx Stride of vector x. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. +inline void nrm2(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *result, + library_data_t result_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, result_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + ::dpct::detail::nrm2_impl, float>(q, n, x, incx, + result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::nrm2_impl, double>(q, n, x, incx, + result); + break; + } +#ifdef __INTEL_MKL__ + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + ::dpct::detail::nrm2_impl(q, n, x, incx, result); + break; + } +#endif + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes the dot product of two vectors. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in] y Input vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. 
+inline void dot(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, const void *y, + library_data_t y_type, std::int64_t incy, void *result, + library_data_t result_type) { + sycl::queue q = desc_ptr->get_queue(); + ::dpct::detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); +} + +/// Computes the dot product of two vectors, conjugating the first vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in] y Input vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. +inline void dotc(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, const void *y, + library_data_t y_type, std::int64_t incy, void *result, + library_data_t result_type) { + sycl::queue q = desc_ptr->get_queue(); + ::dpct::detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); +} + +/// Computes the product of a vector by a scalar. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] alpha The scale factor alpha. +/// \param [in] alpha_type The data type of alpha. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. 
+inline void scal(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, + library_data_t alpha_type, void *x, library_data_t x_type, + std::int64_t incx) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float): { + ::dpct::detail::scal_impl(q, n, alpha, x, incx); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double): { + ::dpct::detail::scal_impl(q, n, alpha, x, incx); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): { + ::dpct::detail::scal_impl, std::complex>( + q, n, alpha, x, incx); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double): { + ::dpct::detail::scal_impl, std::complex>( + q, n, alpha, x, incx); + break; + } +#ifdef __INTEL_MKL__ + case ::dpct::detail::get_type_combination_id(library_data_t::real_half): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + sycl::half alaph_half(alpha_value); + ::dpct::detail::scal_impl(q, n, &alaph_half, x, + incx); + break; + } +#endif + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes a vector-scalar product and adds the result to a vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] alpha The scale factor alpha. +/// \param [in] alpha_type The data type of alpha. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. 
+inline void axpy(descriptor_ptr desc_ptr, std::int64_t n, const void *alpha, + library_data_t alpha_type, const void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, alpha_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::axpy_impl(q, n, alpha, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::axpy_impl(q, n, alpha, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::axpy_impl, std::complex>( + q, n, alpha, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::axpy_impl, std::complex>( + q, n, alpha, x, incx, y, incy); + break; + } +#ifdef __INTEL_MKL__ + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + sycl::half alaph_half(alpha_value); + ::dpct::detail::axpy_impl(q, n, &alaph_half, x, + incx, y, incy); + break; + } +#endif + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Performs rotation of points in the plane. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. 
+/// \param [in] c Scaling factor. +/// \param [in] s Scaling factor. +/// \param [in] cs_type Data type of the scaling factors. +inline void rot(descriptor_ptr desc_ptr, std::int64_t n, void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy, const void *c, + const void *s, library_data_t cs_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, cs_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, s); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, + s); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + ::dpct::detail::rot_impl, float, float>(q, n, x, incx, + y, incy, c, s); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::rot_impl, double, double>( + q, n, x, incx, y, incy, c, s); + break; + } +#ifdef __INTEL_MKL__ + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::rot_impl, float, std::complex>( + q, n, x, incx, y, incy, c, s); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::rot_impl, double, + std::complex>(q, n, x, incx, y, incy, c, + s); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + ::dpct::detail::rot_impl(q, n, x, incx, + y, incy, c, s); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_bfloat16, + library_data_t::real_bfloat16): { + 
::dpct::detail::rot_impl(q, n, x, incx, y, incy, c, + s); + break; + } +#endif + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Performs modified Givens rotation of points in the plane. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [in] param Array of 5 parameters. +/// \param [in] param_type Data type of \p param. +inline void rotm(descriptor_ptr desc_ptr, std::int64_t n, void *x, + library_data_t x_type, int64_t incx, void *y, + library_data_t y_type, int64_t incy, const void *param, + library_data_t param_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, param_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::rotm_impl(q, n, x, incx, y, incy, + param); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::rotm_impl(q, n, x, incx, y, incy, + param); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Copies a vector to another vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] y Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. 
+inline void copy(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, y_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::copy_impl(q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::copy_impl(q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::copy_impl, std::complex>( + q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::copy_impl, std::complex>( + q, n, x, incx, y, incy); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Swaps a vector with another vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. 
+inline void swap(descriptor_ptr desc_ptr, std::int64_t n, void *x, + library_data_t x_type, std::int64_t incx, void *y, + library_data_t y_type, std::int64_t incy) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type, y_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::swap_impl(q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::swap_impl(q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + ::dpct::detail::swap_impl, std::complex>( + q, n, x, incx, y, incy); + break; + } + case ::dpct::detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double): { + ::dpct::detail::swap_impl, std::complex>( + q, n, x, incx, y, incy); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes the sum of magnitudes of the vector elements. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The scalar result. +/// \param [in] result_type Data type of \p result. 
+inline void asum(descriptor_ptr desc_ptr, std::int64_t n, const void *x, + library_data_t x_type, std::int64_t incx, void *result, + library_data_t result_type) { + sycl::queue q = desc_ptr->get_queue(); + std::uint64_t key = + ::dpct::detail::get_type_combination_id(x_type, result_type); + switch (key) { + case ::dpct::detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + ::dpct::detail::asum_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + ::dpct::detail::asum_impl(q, n, x, incx, result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + ::dpct::detail::asum_impl, float>(q, n, x, incx, + result); + break; + } + case ::dpct::detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + ::dpct::detail::asum_impl, double>(q, n, x, incx, + result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Finds the index of the element with the largest absolute value in a vector. +/// \param [in] desc_ptr Descriptor. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The index of the maximal element. 
+inline void iamax(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                  library_data_t x_type, std::int64_t incx,
+                  std::int64_t *result) {
+  sycl::queue q = desc_ptr->get_queue(); // queue supplied by the descriptor
+  std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type);
+  switch (key) { // one case per supported x_type
+  case ::dpct::detail::get_type_combination_id(library_data_t::real_float): {
+    ::dpct::detail::iamax_impl<float>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(library_data_t::real_double): {
+    ::dpct::detail::iamax_impl<double>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): {
+    ::dpct::detail::iamax_impl<std::complex<float>>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(
+      library_data_t::complex_double): {
+    ::dpct::detail::iamax_impl<std::complex<double>>(q, n, x, incx, result);
+    break;
+  }
+  default: // any other element type is rejected
+    throw std::runtime_error("the combination of data type is unsupported");
+  }
+}
+
+/// Finds the index of the element with the smallest absolute value in a vector.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [out] result The index of the minimum element.
+inline void iamin(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                  library_data_t x_type, std::int64_t incx,
+                  std::int64_t *result) {
+  sycl::queue q = desc_ptr->get_queue(); // queue supplied by the descriptor
+  std::uint64_t key = ::dpct::detail::get_type_combination_id(x_type);
+  switch (key) { // one case per supported x_type
+  case ::dpct::detail::get_type_combination_id(library_data_t::real_float): {
+    ::dpct::detail::iamin_impl<float>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(library_data_t::real_double): {
+    ::dpct::detail::iamin_impl<double>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(library_data_t::complex_float): {
+    ::dpct::detail::iamin_impl<std::complex<float>>(q, n, x, incx, result);
+    break;
+  }
+  case ::dpct::detail::get_type_combination_id(
+      library_data_t::complex_double): {
+    ::dpct::detail::iamin_impl<std::complex<double>>(q, n, x, incx, result);
+    break;
+  }
+  default: // any other element type is rejected
+    throw std::runtime_error("the combination of data type is unsupported");
+  }
+}
+
+/// Finds the index of the element with the largest absolute value in a vector.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [out] result The index of the maximal element.
+inline void iamax(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                  library_data_t x_type, std::int64_t incx, int *result) {
+  dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result);
+  iamax(desc_ptr, n, x, x_type, incx, wrapper.get_ptr());
+}
+
+/// Finds the index of the element with the smallest absolute value in a vector.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [out] result The index of the minimum element.
+inline void iamin(descriptor_ptr desc_ptr, std::int64_t n, const void *x,
+                  library_data_t x_type, std::int64_t incx, int *result) {
+  dpct::blas::wrapper_int_to_int64_out wrapper(desc_ptr->get_queue(), result);
+  iamin(desc_ptr, n, x, x_type, incx, wrapper.get_ptr());
+}
+
+/// Finds the least squares solutions for a batch of overdetermined linear
+/// systems. Uses the QR factorization to solve a grouped batch of linear
+/// systems with full rank matrices.
+/// \param [in] desc_ptr Descriptor.
+/// \param [in] trans Operation applied to \p a.
+/// \param [in] m The number of rows of \p a.
+/// \param [in] n The number of columns of \p a.
+/// \param [in] nrhs The number of columns of \p b.
+/// \param [in, out] a Array of pointers to matrices.
+/// \param [in] lda The leading dimension of \p a.
+/// \param [in, out] b Array of pointers to matrices.
+/// \param [in] ldb The leading dimension of \p b.
+/// \param [out] info Set to 0 if no error.
 /// \param [out] dev_info Optional. If it is not NULL : dev_info[i]==0 means the
 /// i-th problem is successful; dev_info[i]!=0 means dev_info[i] is the first
 /// zero diagonal element of the i-th \p a .
@@ -2583,6 +3007,119 @@ trmm(sycl::queue &q, oneapi::mkl::side left_right,
   blas::trmm(&desc, left_right, upper_lower, trans, unit_diag, m, n, alpha, a,
              lda, b, ldb, c, ldc);
 }
+
+/// Computes the Euclidean norm of a vector.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+[[deprecated("Please use dpct::blas::nrm2(...) instead.")]] inline void
+nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+     void *result, library_data_t result_type) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::nrm2(&desc, n, x, x_type, incx, result, result_type);
+}
+
+/// Computes the dot product of two vectors.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+[[deprecated("Please use dpct::blas::dot(...) instead.")]] inline void
+dot(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+    const void *y, library_data_t y_type, int incy, void *result,
+    library_data_t result_type) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::dot(&desc, n, x, x_type, incx, y, y_type, incy, result, result_type);
+}
+
+/// Computes the dot product of two vectors, conjugating the first vector.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in] y Input vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [out] result The result scalar.
+/// \param [in] result_type Data type of the result.
+[[deprecated("Please use dpct::blas::dotc(...) instead.")]] inline void
+dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, int incx,
+     const void *y, library_data_t y_type, int incy, void *result,
+     library_data_t result_type) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::dotc(&desc, n, x, x_type, incx, y, y_type, incy, result, result_type);
+}
+
+/// Computes the product of a vector by a scalar.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+[[deprecated("Please use dpct::blas::scal(...) instead.")]] inline void
+scal(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type,
+     void *x, library_data_t x_type, int incx) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::scal(&desc, n, alpha, alpha_type, x, x_type, incx);
+}
+
+/// Computes a vector-scalar product and adds the result to a vector.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in] alpha The scale factor alpha.
+/// \param [in] alpha_type The data type of alpha.
+/// \param [in] x Input vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+[[deprecated("Please use dpct::blas::axpy(...) instead.")]] inline void
+axpy(sycl::queue &q, int n, const void *alpha, library_data_t alpha_type,
+     const void *x, library_data_t x_type, int incx, void *y,
+     library_data_t y_type, int incy) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::axpy(&desc, n, alpha, alpha_type, x, x_type, incx, y, y_type, incy);
+}
+
+/// Performs rotation of points in the plane.
+/// \param [in] q The queue where the routine should be executed.
+/// \param [in] n Number of elements in vector x.
+/// \param [in, out] x Input/Output vector x.
+/// \param [in] x_type Data type of the vector x.
+/// \param [in] incx Stride of vector x.
+/// \param [in, out] y Input/Output vector y.
+/// \param [in] y_type Data type of the vector y.
+/// \param [in] incy Stride of vector y.
+/// \param [in] c Scaling factor.
+/// \param [in] s Scaling factor.
+/// \param [in] cs_type Data type of the scaling factors.
+[[deprecated("Please use dpct::blas::rot(...) instead.")]] inline void
+rot(sycl::queue &q, int n, void *x, library_data_t x_type, int incx, void *y,
+    library_data_t y_type, int incy, const void *c, const void *s,
+    library_data_t cs_type) {
+  blas::descriptor desc; // descriptor local to this call
+  desc.set_queue(&q); // give it the address of the caller's queue
+  blas::rot(&desc, n, x, x_type, incx, y, y_type, incy, c, s, cs_type);
+}
 } // namespace dpct
 #undef DPCT_COMPUTE_MODE_ARG
 #undef DPCT_COMPUTE_MODE_PARAM
diff --git a/clang/test/dpct/cublas-usm-11.cu b/clang/test/dpct/cublas-usm-11.cu
index 11522b38c4f3..c0273881e833 100644
--- a/clang/test/dpct/cublas-usm-11.cu
+++ b/clang/test/dpct/cublas-usm-11.cu
@@ -13,12 +13,12 @@ void foo1() {
   const void **a_array;
   const void **b_array;
   void **c_array;
-  //CHECK:dpct::nrm2(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float);
-  //CHECK-NEXT:dpct::dot(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float);
-  //CHECK-NEXT:dpct::dotc(handle->get_queue(), 4, x, 
dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); - //CHECK-NEXT:dpct::scal(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1); - //CHECK-NEXT:dpct::axpy(handle->get_queue(), 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1); - //CHECK-NEXT:dpct::rot(handle->get_queue(), 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, cos, sin, dpct::library_data_t::real_float); + //CHECK:dpct::blas::nrm2(handle, 4, x, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::dot(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::dotc(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, res, dpct::library_data_t::real_float); + //CHECK-NEXT:dpct::blas::scal(handle, 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1); + //CHECK-NEXT:dpct::blas::axpy(handle, 4, alpha, dpct::library_data_t::real_float, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1); + //CHECK-NEXT:dpct::blas::rot(handle, 4, x, dpct::library_data_t::real_float, 1, y, dpct::library_data_t::real_float, 1, cos, sin, dpct::library_data_t::real_float); //CHECK-NEXT:dpct::blas::gemm(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, b, dpct::library_data_t::real_half, 4, beta, c, dpct::library_data_t::real_half, 4, dpct::compute_type::f16); //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a_array, dpct::library_data_t::real_half, 4, b_array, dpct::library_data_t::real_half, 4, beta, c_array, 
dpct::library_data_t::real_half, 4, 2, dpct::compute_type::f16); //CHECK-NEXT:dpct::blas::gemm_batch(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, 4, 4, 4, alpha, a, dpct::library_data_t::real_half, 4, 16, b, dpct::library_data_t::real_half, 4, 16, beta, c, dpct::library_data_t::real_half, 4, 16, 2, dpct::compute_type::f16); @@ -53,3 +53,46 @@ void foo3() { cublasGetMathMode(handle, &Mathmode); cublasSetMathMode(handle, Mathmode); } + +void foo4() { + cublasHandle_t handle; + int m, n, k; + void *x, *y; + int incx, incy; + void *res; + int *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasSwapEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasIamaxEx(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx(handle, n, x, CUDA_R_32F, incx, idx); + cublasAsumEx(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); + + cublasOperation_t transa; + cublasOperation_t transb; + cuComplex *alpha, *beta; + void *A, *B, *C; + int lda, ldb, ldc; + cudaDataType_t a_type, b_type, c_type; + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, 
alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, dpct::library_data_t::complex_float); + cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc); + + cublasFillMode_t uplo; + cublasOperation_t trans; + float *alpha_s, *beta_s; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float); + cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); + cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc); + cublasCherkEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); + cublasCherk3mEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc); +} diff --git a/clang/test/dpct/cublas_64.cu b/clang/test/dpct/cublas_64.cu index 8050107e154b..0ed93e69a664 100644 --- a/clang/test/dpct/cublas_64.cu +++ b/clang/test/dpct/cublas_64.cu @@ -475,6 +475,67 @@ void foo() { status = cublasCherkx_64(handle, uplo, transa, n, k, alpha_c, A_c, lda, B_c, ldb, beta_s, C_c, ldc); status = cublasZherkx_64(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc); + cudaDataType type_x; + cudaDataType type_y; + cudaDataType type_res; + cudaDataType type_exec; + cudaDataType type_alpha; + cudaDataType type_cs; + void *res; + void *x; + void *y; + void *alpha; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::nrm2(handle, n, x, type_x, incx, res, type_res)); + // CHECK-NEXT: status = 
DPCT_CHECK_ERROR(dpct::blas::dot(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dotc(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::scal(handle, n, alpha, type_alpha, x, type_x, incx)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::axpy(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::rot(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs)); + status = cublasNrm2Ex_64(handle, n, x, type_x, incx, res, type_res, type_exec); + status = cublasDotEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasDotcEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec); + status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec); + status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec); + + void **a_array; + void **b_array; + void **c_array; + void *aa; + void *bb; + void *cc; + cudaDataType type_a; + cudaDataType type_b; + cudaDataType type_c; + void *beta; + cublasGemmAlgo_t algo; + int64_t batch; + int64_t stride_a; + int64_t stride_b; + int64_t stride_c; + cublasComputeType_t type_compute; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute)); + status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute, algo); + + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, 
dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); + cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); + cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasCgemm3mEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); + + cublasOperation_t trans; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); 
+ cublasCherk3mEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); + // CHECK: status = DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer(A_s)), lda, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer(y_s)), incx)); // CHECK-NEXT: status = DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer(A_d)), lda, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer(y_d)), incx)); // CHECK-NEXT: status = DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer>(A_c)), lda, dpct::rvalue_ref_to_lvalue_ref(dpct::get_buffer>(y_c)), incx)); @@ -579,3 +640,25 @@ void foo() { status = cublasChpr2_64(handle, uplo, n, alpha_c, x_c, incx, y_c, incy, C_c); status = cublasZhpr2_64(handle, uplo, n, alpha_z, x_z, incx, y_z, incy, C_z); } + +void foo2() { + cublasHandle_t handle; + int n; + void *x, *y; + int incx, incy; + void *res; + int64_t *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + 
cublasSwapEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasIamaxEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasAsumEx_64(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); +} diff --git a/clang/test/dpct/cublas_64_usm.cu b/clang/test/dpct/cublas_64_usm.cu index 19a30481b87a..e90621a3847f 100644 --- a/clang/test/dpct/cublas_64_usm.cu +++ b/clang/test/dpct/cublas_64_usm.cu @@ -475,6 +475,69 @@ void foo() { status = cublasCherkx_64(handle, uplo, transa, n, k, alpha_c, A_c, lda, B_c, ldb, beta_s, C_c, ldc); status = cublasZherkx_64(handle, uplo, transa, n, k, alpha_z, A_z, lda, B_z, ldb, beta_d, C_z, ldc); + cudaDataType type_x; + cudaDataType type_y; + cudaDataType type_res; + cudaDataType type_exec; + cudaDataType type_alpha; + cudaDataType type_cs; + void *res; + void *x; + void *y; + void *alpha; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::nrm2(handle, n, x, type_x, incx, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dot(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::dotc(handle, n, x, type_x, incx, y, type_y, incy, res, type_res)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::scal(handle, n, alpha, type_alpha, x, type_x, incx)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::axpy(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::rot(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs)); + status = cublasNrm2Ex_64(handle, n, x, type_x, incx, res, type_res, type_exec); + status = cublasDotEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = cublasDotcEx_64(handle, n, x, type_x, incx, y, type_y, incy, res, type_res, type_exec); + status = 
cublasScalEx_64(handle, n, alpha, type_alpha, x, type_x, incx, type_exec); + status = cublasAxpyEx_64(handle, n, alpha, type_alpha, x, type_x, incx, y, type_y, incy, type_exec); + status = cublasRotEx_64(handle, n, x, type_x, incx, y, type_y, incy, c, s, type_cs, type_exec); + + void **a_array; + void **b_array; + void **c_array; + void *aa; + void *bb; + void *cc; + cudaDataType type_a; + cudaDataType type_b; + cudaDataType type_c; + void *beta; + cublasGemmAlgo_t algo; + int64_t batch; + int64_t stride_a; + int64_t stride_b; + int64_t stride_c; + cublasComputeType_t type_compute; + // CHECK: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, const_cast(a_array), type_a, lda, const_cast(b_array), type_b, ldb, beta, c_array, type_c, ldc, batch, type_compute)); + // CHECK-NEXT: status = DPCT_CHECK_ERROR(dpct::blas::gemm_batch(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute)); + status = cublasGemmBatchedEx_64(handle, transa, transb, m, n, k, alpha, a_array, type_a, lda, b_array, type_b, ldb, beta, c_array, type_c, ldc, batch, type_compute, algo); + status = cublasGemmStridedBatchedEx_64(handle, transa, transb, m, n, k, alpha, aa, type_a, lda, stride_a, bb, type_b, ldb, stride_b, beta, cc, type_c, ldc, stride_c, batch, type_compute, algo); + + // CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, 
type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute); + cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc); + cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasCgemm3mEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc); + cublasGemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute, algo); + + cublasOperation_t trans; + // CHECK: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + // CHECK-NEXT: dpct::blas::syherk(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float); + cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc); + cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); + cublasCherk3mEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc); + // CHECK: status = DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, A_s, lda, y_s, incx)); // CHECK-NEXT: status = DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, A_d, lda, y_d, incx)); // CHECK-NEXT: status = 
DPCT_CHECK_ERROR(oneapi::mkl::blas::column_major::tbsv(handle->get_queue(), uplo, transa, diag, n, k, (std::complex*)A_c, lda, (std::complex*)y_c, incx)); @@ -579,3 +642,25 @@ void foo() { status = cublasChpr2_64(handle, uplo, n, alpha_c, x_c, incx, y_c, incy, C_c); status = cublasZhpr2_64(handle, uplo, n, alpha_z, x_z, incx, y_z, incy, C_z); } + +void foo2() { + cublasHandle_t handle; + int n; + void *x, *y; + int incx, incy; + void *res; + int64_t *idx; + void *param; + // CHECK: dpct::blas::copy(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::swap(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy); + // CHECK-NEXT: dpct::blas::iamax(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::iamin(handle, n, x, dpct::library_data_t::real_float, incx, idx); + // CHECK-NEXT: dpct::blas::asum(handle, n, x, dpct::library_data_t::real_float, incx, res, dpct::library_data_t::real_float); + // CHECK-NEXT: dpct::blas::rotm(handle, n, x, dpct::library_data_t::real_float, incx, y, dpct::library_data_t::real_float, incy, param, dpct::library_data_t::real_float); + cublasCopyEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasSwapEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy); + cublasIamaxEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasIaminEx_64(handle, n, x, CUDA_R_32F, incx, idx); + cublasAsumEx_64(handle, n, x, CUDA_R_32F, incx, res, CUDA_R_32F, CUDA_R_32F); + cublasRotmEx_64(handle, n, x, CUDA_R_32F, incx, y, CUDA_R_32F, incy, param, CUDA_R_32F, CUDA_R_32F); +} diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu index 487daa3ebb1d..9591816f9b9d 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_10_1.cu @@ -9,4 +9,4 @@ // cublasRotEx-NEXT: s /*const 
void **/, cstype /*cudaDataType*/, // cublasRotEx-NEXT: computetype /*cudaDataType*/); // cublasRotEx-NEXT: Is migrated to: -// cublasRotEx-NEXT: dpct::rot(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, c, s, cstype); +// cublasRotEx-NEXT: dpct::blas::rot(handle, n, x, xtype, incx, y, ytype, incy, c, s, cstype); diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu index 6d7b54ae38e4..1054027cc400 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part1.cu @@ -86,7 +86,7 @@ // cublasAxpyEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, // cublasAxpyEx-NEXT: computetype /*cudaDataType*/); // cublasAxpyEx-NEXT: Is migrated to: -// cublasAxpyEx-NEXT: dpct::axpy(handle->get_queue(), n, alpha, alphatype, x, xtype, incx, y, ytype, incy); +// cublasAxpyEx-NEXT: dpct::blas::axpy(handle, n, alpha, alphatype, x, xtype, incx, y, ytype, incy); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDtpmv | FileCheck %s -check-prefix=cublasDtpmv // cublasDtpmv: CUDA API: diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu index 85bbb574d3ab..b646ba45df10 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part2.cu @@ -5,7 +5,7 @@ // cublasDotEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, res /*void **/, // cublasDotEx-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasDotEx-NEXT: Is migrated to: -// cublasDotEx-NEXT: dpct::dot(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, res, restype); +// cublasDotEx-NEXT: dpct::blas::dot(handle, n, x, xtype, incx, y, ytype, incy, res, restype); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDtbmv | FileCheck %s -check-prefix=cublasDtbmv // cublasDtbmv: CUDA API: diff --git 
a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu index bbe98f8ce5f2..21f1ee6ac93a 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part3.cu @@ -15,7 +15,7 @@ // cublasDotcEx-NEXT: ytype /*cudaDataType*/, incy /*int*/, res /*void **/, // cublasDotcEx-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasDotcEx-NEXT: Is migrated to: -// cublasDotcEx-NEXT: dpct::dotc(handle->get_queue(), n, x, xtype, incx, y, ytype, incy, res, restype); +// cublasDotcEx-NEXT: dpct::blas::dotc(handle, n, x, xtype, incx, y, ytype, incy, res, restype); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasSsyr2 | FileCheck %s -check-prefix=cublasSsyr2 // cublasSsyr2: CUDA API: diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu index 6bad86f59749..5fdd590d8d79 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part4.cu @@ -236,4 +236,4 @@ // cublasNrm2Ex-NEXT: xtype /*cudaDataType*/, incx /*int*/, res /*void **/, // cublasNrm2Ex-NEXT: restype /*cudaDataType*/, computetype /*cudaDataType*/); // cublasNrm2Ex-NEXT: Is migrated to: -// cublasNrm2Ex-NEXT: dpct::nrm2(handle->get_queue(), n, x, xtype, incx, res, restype); +// cublasNrm2Ex-NEXT: dpct::blas::nrm2(handle, n, x, xtype, incx, res, restype); diff --git a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu index 04027172cb09..ddf0daaffbc6 100644 --- a/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu +++ b/clang/test/dpct/query_api_mapping/cuBLAS/blas_part8.cu @@ -16,7 +16,7 @@ // cublasScalEx-NEXT: alphatype /*cudaDataType*/, x /*void **/, xtype /*cudaDataType*/, // cublasScalEx-NEXT: incx /*int*/, computetype /*cudaDataType*/); // cublasScalEx-NEXT: Is 
migrated to: -// cublasScalEx-NEXT: dpct::scal(handle->get_queue(), n, alpha, alphatype, x, xtype, incx); +// cublasScalEx-NEXT: dpct::blas::scal(handle, n, alpha, alphatype, x, xtype, incx); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cublasDger | FileCheck %s -check-prefix=cublasDger // cublasDger: CUDA API: