Skip to content

Commit

Permalink
Update1
Browse files (browse the repository at this point in the history)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
  • Loading branch information
zhiweij1 committed Sep 20, 2024
1 parent 5c8ff1d commit f441cd8
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 30 deletions.
18 changes: 12 additions & 6 deletions clang/lib/DPCT/APINamesCUBLAS.inc
Original file line number Diff line number Diff line change
Expand Up @@ -734,13 +734,15 @@ GEMM_EX(cublasSgemmEx,
MapNames::getLibraryHelperNamespace() + "library_data_t::real_float")
GEMM_EX(cublasCgemmEx,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
GEMM_EX(cublasCgemm3mEx, "oneapi::mkl::blas::compute_mode::complex_3m")
GEMM_EX(cublasCgemm3mEx,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
GEMM_EX(cublasGemmEx, 17)
GEMM_EX(cublasSgemmEx_64,
MapNames::getLibraryHelperNamespace() + "library_data_t::real_float")
GEMM_EX(cublasCgemmEx_64,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
GEMM_EX(cublasCgemm3mEx_64, "oneapi::mkl::blas::compute_mode::complex_3m")
GEMM_EX(cublasCgemm3mEx_64,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
GEMM_EX(cublasGemmEx_64, 17)
#undef GEMM_EX

Expand All @@ -759,16 +761,20 @@ GEMM_EX(cublasGemmEx_64, 17)
ARG(10), ARG(11), ARG(12), ARG(COMPUTE_TYPE)))))
SYHERK(cublasCsyrkEx, false,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCsyrk3mEx, false, "oneapi::mkl::blas::compute_mode::complex_3m")
SYHERK(cublasCsyrk3mEx, false,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCherkEx, true,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCherk3mEx, true, "oneapi::mkl::blas::compute_mode::complex_3m")
SYHERK(cublasCherk3mEx, true,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCsyrkEx_64, false,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCsyrk3mEx_64, false, "oneapi::mkl::blas::compute_mode::complex_3m")
SYHERK(cublasCsyrk3mEx_64, false,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCherkEx_64, true,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
SYHERK(cublasCherk3mEx_64, true, "oneapi::mkl::blas::compute_mode::complex_3m")
SYHERK(cublasCherk3mEx_64, true,
MapNames::getLibraryHelperNamespace() + "library_data_t::complex_float")
#undef SYHERK

#define SYRK(NAME, TYPE, IS_COMPLEX) \
Expand Down
28 changes: 13 additions & 15 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1477,14 +1477,13 @@ deduce_compute_mode(std::optional<compute_type> ct, math_mode mm,
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] ct Compute type.
inline void
gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n,
std::int64_t k, const void *alpha, const void *a, library_data_t a_type,
std::int64_t lda, const void *b, library_data_t b_type, std::int64_t ldb,
const void *beta, void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t, oneapi::mkl::blas::compute_mode>
ct) {
inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n,
std::int64_t k, const void *alpha, const void *a,
library_data_t a_type, std::int64_t lda, const void *b,
library_data_t b_type, std::int64_t ldb, const void *beta,
void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t> ct) {
#ifndef __INTEL_MKL__
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
"Project does not support this API.");
Expand Down Expand Up @@ -1989,13 +1988,12 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
/// \param [in] ldc Leading dimension of the matrix c.
/// \param [in] ct Compute type.
template <bool is_hermitian>
inline void syherk(
descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo,
oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k,
const void *alpha, const void *a, library_data_t a_type, std::int64_t lda,
const void *beta, void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t, oneapi::mkl::blas::compute_mode>
ct) {
inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo,
oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k,
const void *alpha, const void *a, library_data_t a_type,
std::int64_t lda, const void *beta, void *c,
library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t> ct) {
sycl::queue q = desc_ptr->get_queue();
#ifdef __INTEL_MKL__
oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset;
Expand Down
6 changes: 3 additions & 3 deletions clang/test/dpct/cublas-usm-11.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ void foo4() {
void *A, *B, *C;
int lda, ldb, ldc;
cudaDataType_t a_type, b_type, c_type;
// CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc, dpct::library_data_t::complex_float);
cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, a_type, lda, B, b_type, ldb, beta, C, c_type, ldc);

cublasFillMode_t uplo;
cublasOperation_t trans;
float *alpha_s, *beta_s;
// CHECK: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc);
cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, a_type, lda, beta, C, c_type, ldc);
cublasCherkEx(handle, uplo, trans, n, k, alpha_s, A, a_type, lda, beta_s, C, c_type, ldc);
Expand Down
6 changes: 3 additions & 3 deletions clang/test/dpct/cublas_64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ void foo() {

// CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute);
cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc);
cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc);
Expand All @@ -528,9 +528,9 @@ void foo() {

cublasOperation_t trans;
// CHECK: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc);
cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc);
cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc);
Expand Down
6 changes: 3 additions & 3 deletions clang/test/dpct/cublas_64_usm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ void foo() {

// CHECK: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc, dpct::library_data_t::real_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::gemm(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc, type_compute);
cublasSgemmEx_64(handle, transa, transb, m, n, k, alpha_s, A_s, type_a, lda, B_s, type_b, ldb, beta_s, C_s, type_c, ldc);
cublasCgemmEx_64(handle, transa, transb, m, n, k, alpha_c, A_c, type_a, lda, B_c, type_b, ldb, beta_c, C_c, type_c, ldc);
Expand All @@ -530,9 +530,9 @@ void foo() {

cublasOperation_t trans;
// CHECK: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
// CHECK-NEXT: dpct::blas::syherk<false>(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, dpct::library_data_t::complex_float);
// CHECK-NEXT: dpct::blas::syherk<true>(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc, oneapi::mkl::blas::compute_mode::complex_3m);
cublasCsyrkEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc);
cublasCsyrk3mEx_64(handle, uplo, trans, n, k, alpha_c, A_c, type_a, lda, beta_c, C_c, type_c, ldc);
cublasCherkEx_64(handle, uplo, trans, n, k, alpha_s, A_c, type_a, lda, beta_s, C_c, type_c, ldc);
Expand Down

0 comments on commit f441cd8

Please sign in to comment.