Skip to content

Commit

Permalink
Fix helper
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
  • Loading branch information
zhiweij1 committed Sep 19, 2024
1 parent 30a1773 commit 5c8ff1d
Showing 1 changed file with 35 additions and 25 deletions.
60 changes: 35 additions & 25 deletions clang/runtime/dpct-rt/include/dpct/blas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,18 +718,23 @@ inline void syherk_impl(sycl::queue &q, oneapi::mkl::uplo uplo,
const void *alpha, const void *a, int lda,
const void *beta, void *c,
int ldc DPCT_COMPUTE_MODE_PARAM) {
T alpha_value = dpct::get_value(reinterpret_cast<const T *>(alpha), q);
T beta_value = dpct::get_value(reinterpret_cast<const T *>(beta), q);
auto data_a = get_memory<const T>(a);
auto data_c = get_memory<T>(c);
if constexpr (is_hermitian)
if constexpr (is_hermitian) {
auto alpha_value = dpct::get_value(
reinterpret_cast<const typename T::value_type *>(alpha), q);
auto beta_value = dpct::get_value(
reinterpret_cast<const typename T::value_type *>(beta), q);
oneapi::mkl::blas::column_major::herk(q, uplo, trans, n, k, alpha_value,
data_a, lda, beta_value, data_c,
ldc DPCT_COMPUTE_MODE_ARG);
else
} else {
T alpha_value = dpct::get_value(reinterpret_cast<const T *>(alpha), q);
T beta_value = dpct::get_value(reinterpret_cast<const T *>(beta), q);
oneapi::mkl::blas::column_major::syrk(q, uplo, trans, n, k, alpha_value,
data_a, lda, beta_value, data_c,
ldc DPCT_COMPUTE_MODE_ARG);
}
}

template <bool is_hermitian, class T, class Tbeta>
Expand Down Expand Up @@ -1472,21 +1477,24 @@ deduce_compute_mode(std::optional<compute_type> ct, math_mode mm,
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] ct Compute type.
inline void gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n,
std::int64_t k, const void *alpha, const void *a,
library_data_t a_type, std::int64_t lda, const void *b,
library_data_t b_type, std::int64_t ldb, const void *beta,
void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t> ct) {
inline void
gemm(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
oneapi::mkl::transpose b_trans, std::int64_t m, std::int64_t n,
std::int64_t k, const void *alpha, const void *a, library_data_t a_type,
std::int64_t lda, const void *b, library_data_t b_type, std::int64_t ldb,
const void *beta, void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t, oneapi::mkl::blas::compute_mode>
ct) {
#ifndef __INTEL_MKL__
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
"Project does not support this API.");
#else
sycl::queue q = desc_ptr->get_queue();
oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset;
library_data_t scaling_type;
if (auto ct_p = std::get_if<compute_type>(&ct)) {
if (auto ct_p = std::get_if<oneapi::mkl::blas::compute_mode>(&ct)) {
cm = *ct_p;
} else if (auto ct_p = std::get_if<compute_type>(&ct)) {
cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(),
a_type == library_data_t::complex_float ||
a_type == library_data_t::complex_double);
Expand Down Expand Up @@ -1981,16 +1989,19 @@ inline void gemm_batch(descriptor_ptr desc_ptr, oneapi::mkl::transpose a_trans,
/// \param [in] ldc Leading dimension of the matrix c.
/// \param [in] ct Compute type.
template <bool is_hermitian>
inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo,
oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k,
const void *alpha, const void *a, library_data_t a_type,
std::int64_t lda, const void *beta, void *c,
library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t> ct) {
inline void syherk(
descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo,
oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k,
const void *alpha, const void *a, library_data_t a_type, std::int64_t lda,
const void *beta, void *c, library_data_t c_type, std::int64_t ldc,
std::variant<compute_type, library_data_t, oneapi::mkl::blas::compute_mode>
ct) {
sycl::queue q = desc_ptr->get_queue();
#ifdef __INTEL_MKL__
oneapi::mkl::blas::compute_mode cm = oneapi::mkl::blas::compute_mode::unset;
if (auto ct_p = std::get_if<compute_type>(&ct)) {
if (auto ct_p = std::get_if<oneapi::mkl::blas::compute_mode>(&ct)) {
cm = *ct_p;
} else if (auto ct_p = std::get_if<compute_type>(&ct)) {
cm = deduce_compute_mode(*ct_p, desc_ptr->get_math_mode(),
a_type == library_data_t::complex_float ||
a_type == library_data_t::complex_double);
Expand All @@ -2004,15 +2015,14 @@ inline void syherk(descriptor_ptr desc_ptr, oneapi::mkl::uplo uplo,
if (!is_hermitian &&
dpct::detail::get_type_combination_id(
library_data_t::real_float, library_data_t::real_float) == key) {
dpct::detail::syherk_impl<is_hermitian, float>(q, uplo, trans, n, k, alpha,
a, lda, beta, c,
ldc DPCT_COMPUTE_MODE_ARG);
dpct::detail::syherk_impl<false, float>(q, uplo, trans, n, k, alpha, a, lda,
beta, c, ldc DPCT_COMPUTE_MODE_ARG);
} else if (!is_hermitian && dpct::detail::get_type_combination_id(
library_data_t::real_double,
library_data_t::real_double) == key) {
dpct::detail::syherk_impl<is_hermitian, double>(q, uplo, trans, n, k, alpha,
a, lda, beta, c,
ldc DPCT_COMPUTE_MODE_ARG);
dpct::detail::syherk_impl<false, double>(q, uplo, trans, n, k, alpha, a,
lda, beta, c,
ldc DPCT_COMPUTE_MODE_ARG);
} else if (dpct::detail::get_type_combination_id(
library_data_t::complex_float,
library_data_t::complex_float) == key) {
Expand Down

0 comments on commit 5c8ff1d

Please sign in to comment.