Skip to content

Commit

Permalink
Interface more LAPACK routines
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 18, 2023
1 parent db9359e commit a1db608
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 10 deletions.
184 changes: 179 additions & 5 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,180 @@ class trsmBatchInfo {
}
};

// C entry point for oneapi::mkl::lapack::ormqr in single precision.
// `side` and `trans` are translated with convert(); every other argument is
// forwarded unchanged (see the oneMKL ormqr documentation for their meaning).
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklSormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
                             int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *tau,
                             float *c, int64_t ldc) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size<float>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc);
    auto scratchpad = static_cast<float *>(malloc_device(scratchpad_size * sizeof(float), device, context));
    auto status = oneapi::mkl::lapack::ormqr(device_queue->val, convert(side), convert(trans), m, n, k, a, lda, tau, c, ldc, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::ormqr in double precision.
// `side` and `trans` are translated with convert(); every other argument is
// forwarded unchanged (see the oneMKL ormqr documentation for their meaning).
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklDormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
                             int64_t m, int64_t n, int64_t k, double *a, int64_t lda, double *tau,
                             double *c, int64_t ldc) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size<double>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc);
    auto scratchpad = static_cast<double *>(malloc_device(scratchpad_size * sizeof(double), device, context));
    auto status = oneapi::mkl::lapack::ormqr(device_queue->val, convert(side), convert(trans), m, n, k, a, lda, tau, c, ldc, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::unmqr on single-precision complex data.
// The C `float _Complex` pointers are reinterpreted as std::complex<float>
// pointers; `side` and `trans` are translated with convert(); the remaining
// arguments are forwarded unchanged.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklCunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
                             int64_t m, int64_t n, int64_t k, float _Complex *a, int64_t lda, float _Complex *tau,
                             float _Complex *c, int64_t ldc) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size<std::complex<float>>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc);
    auto scratchpad = static_cast<std::complex<float> *>(malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context));
    auto status = oneapi::mkl::lapack::unmqr(device_queue->val, convert(side), convert(trans), m, n, k, reinterpret_cast<std::complex<float> *>(a), lda,
                                             reinterpret_cast<std::complex<float> *>(tau), reinterpret_cast<std::complex<float> *>(c), ldc,
                                             scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::unmqr on double-precision complex data.
// The C `double _Complex` pointers are reinterpreted as std::complex<double>
// pointers; `side` and `trans` are translated with convert(); the remaining
// arguments are forwarded unchanged.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklZunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
                             int64_t m, int64_t n, int64_t k, double _Complex *a, int64_t lda, double _Complex *tau,
                             double _Complex *c, int64_t ldc) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size<std::complex<double>>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc);
    auto scratchpad = static_cast<std::complex<double> *>(malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context));
    auto status = oneapi::mkl::lapack::unmqr(device_queue->val, convert(side), convert(trans), m, n, k, reinterpret_cast<std::complex<double> *>(a), lda,
                                             reinterpret_cast<std::complex<double> *>(tau), reinterpret_cast<std::complex<double> *>(c), ldc,
                                             scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::orgqr in single precision.
// m, n, k, a, lda and tau are forwarded unchanged (see the oneMKL orgqr
// documentation for their meaning).
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklSorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float *a,
                             int64_t lda, float *tau) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size<float>(device_queue->val, m, n, k, lda);
    auto scratchpad = static_cast<float *>(malloc_device(scratchpad_size * sizeof(float), device, context));
    auto status = oneapi::mkl::lapack::orgqr(device_queue->val, m, n, k, a, lda, tau, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::orgqr in double precision.
// m, n, k, a, lda and tau are forwarded unchanged (see the oneMKL orgqr
// documentation for their meaning).
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklDorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double *a,
                             int64_t lda, double *tau) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size<double>(device_queue->val, m, n, k, lda);
    auto scratchpad = static_cast<double *>(malloc_device(scratchpad_size * sizeof(double), device, context));
    auto status = oneapi::mkl::lapack::orgqr(device_queue->val, m, n, k, a, lda, tau, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::ungqr on single-precision complex data.
// The C `float _Complex` pointers are reinterpreted as std::complex<float>
// pointers; m, n, k and lda are forwarded unchanged.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float _Complex *a,
                             int64_t lda, float _Complex *tau) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size<std::complex<float>>(device_queue->val, m, n, k, lda);
    auto scratchpad = static_cast<std::complex<float> *>(malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context));
    auto status = oneapi::mkl::lapack::ungqr(device_queue->val, m, n, k, reinterpret_cast<std::complex<float> *>(a), lda,
                                             reinterpret_cast<std::complex<float> *>(tau), scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for oneapi::mkl::lapack::ungqr on double-precision complex data.
// The C `double _Complex` pointers are reinterpreted as std::complex<double>
// pointers; m, n, k and lda are forwarded unchanged.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a,
                             int64_t lda, double _Complex *tau) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size<std::complex<double>>(device_queue->val, m, n, k, lda);
    auto scratchpad = static_cast<std::complex<double> *>(malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context));
    auto status = oneapi::mkl::lapack::ungqr(device_queue->val, m, n, k, reinterpret_cast<std::complex<double> *>(a), lda,
                                             reinterpret_cast<std::complex<double> *>(tau), scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getrf_batch in
// single precision. m, n, lda and group_sizes are passed as arrays and a / ipiv
// as arrays of pointers, all forwarded unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<float>(device_queue->val, m, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<float *>(malloc_device(scratchpad_size * sizeof(float), device, context));
    auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getrf_batch in
// double precision. m, n, lda and group_sizes are passed as arrays and a / ipiv
// as arrays of pointers, all forwarded unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<double>(device_queue->val, m, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<double *>(malloc_device(scratchpad_size * sizeof(double), device, context));
    auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getrf_batch on
// single-precision complex data. The `float _Complex **` matrix array is
// reinterpreted as `std::complex<float> **`; the other arguments are forwarded
// unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<std::complex<float>>(device_queue->val, m, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<std::complex<float> *>(malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context));
    auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast<std::complex<float> **>(&a[0]), lda,
                                                   &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getrf_batch on
// double-precision complex data. The `double _Complex **` matrix array is
// reinterpreted as `std::complex<double> **`; the other arguments are forwarded
// unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<std::complex<double>>(device_queue->val, m, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<std::complex<double> *>(malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context));
    auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast<std::complex<double> **>(&a[0]), lda,
                                                   &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getri_batch in
// single precision. n, lda and group_sizes are passed as arrays and a / ipiv as
// arrays of pointers, all forwarded unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<float>(device_queue->val, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<float *>(malloc_device(scratchpad_size * sizeof(float), device, context));
    auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getri_batch in
// double precision. n, lda and group_sizes are passed as arrays and a / ipiv as
// arrays of pointers, all forwarded unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<double>(device_queue->val, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<double *>(malloc_device(scratchpad_size * sizeof(double), device, context));
    auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getri_batch on
// single-precision complex data. The `float _Complex **` matrix array is
// reinterpreted as `std::complex<float> **`; the other arguments are forwarded
// unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<std::complex<float>>(device_queue->val, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<std::complex<float> *>(malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context));
    auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast<std::complex<float> **>(&a[0]), lda,
                                                   &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

// C entry point for the group variant of oneapi::mkl::lapack::getri_batch on
// double-precision complex data. The `double _Complex **` matrix array is
// reinterpreted as `std::complex<double> **`; the other arguments are forwarded
// unchanged.
// NOTE(review): array lengths follow oneMKL's group-batch convention — confirm with callers.
// The scratchpad oneMKL requires is allocated here in device memory and released
// before returning, so this call blocks until the routine has completed.
extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto device = device_queue->val.get_device();
    auto context = device_queue->val.get_context();
    auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<std::complex<double>>(device_queue->val, n, lda, group_count, group_sizes);
    auto scratchpad = static_cast<std::complex<double> *>(malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context));
    auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast<std::complex<double> **>(&a[0]), lda,
                                                   &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
    __FORCE_MKL_FLUSH__(status);
    // The submitted routine may still be using the scratchpad: wait for it,
    // then release the allocation so it does not leak on every call.
    status.wait();
    sycl::free(scratchpad, context);
}

extern "C" void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,
float *a, int64_t lda, float *tau) {
auto device = device_queue->val.get_device();
Expand Down Expand Up @@ -1256,7 +1430,7 @@ extern "C" void onemklZasum(syclQueue_t device_queue, int64_t n,
}

extern "C" void onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha,
const short *x, std::int64_t incx, short *y, int64_t incy) {
const short *x, int64_t incx, short *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n,
sycl::bit_cast<sycl::half>(alpha),
reinterpret_cast<const sycl::half *>(x),
Expand All @@ -1265,29 +1439,29 @@ extern "C" void onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha,
}

// C entry point for single-precision axpy: forwards n, alpha, x, incx, y and incy
// to oneapi::mkl::blas::column_major::axpy on the wrapped queue, then hands the
// returned status to __FORCE_MKL_FLUSH__.
// NOTE(review): the two consecutive parameter lines below are the removed and
// added sides of this diff hunk (`std::int64_t incx` became `int64_t incx`);
// only the second one belongs in the resulting file.
extern "C" void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha,
const float *x, std::int64_t incx, float *y, int64_t incy) {
const float *x, int64_t incx, float *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x,
incx, y, incy);
__FORCE_MKL_FLUSH__(status);
}

// C entry point for double-precision axpy: forwards n, alpha, x, incx, y and incy
// to oneapi::mkl::blas::column_major::axpy on the wrapped queue, then hands the
// returned status to __FORCE_MKL_FLUSH__.
// NOTE(review): the two consecutive parameter lines below are the removed and
// added sides of this diff hunk (`std::int64_t incx` became `int64_t incx`);
// only the second one belongs in the resulting file.
extern "C" void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha,
const double *x, std::int64_t incx, double *y, int64_t incy) {
const double *x, int64_t incx, double *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x,
incx, y, incy);
__FORCE_MKL_FLUSH__(status);
}

// C entry point for single-precision complex axpy. alpha is converted to
// std::complex<float> and the `float _Complex` pointers are reinterpreted as
// std::complex<float> pointers before calling
// oneapi::mkl::blas::column_major::axpy on the wrapped queue; the returned
// status is then handed to __FORCE_MKL_FLUSH__.
// NOTE(review): the two consecutive parameter lines below are the removed and
// added sides of this diff hunk (`std::int64_t incx` became `int64_t incx`);
// only the second one belongs in the resulting file.
extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha,
const float _Complex *x, std::int64_t incx, float _Complex *y, int64_t incy) {
const float _Complex *x, int64_t incx, float _Complex *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, static_cast<std::complex<float> >(alpha),
reinterpret_cast<const std::complex<float> *>(x), incx,
reinterpret_cast<std::complex<float> *>(y), incy);
__FORCE_MKL_FLUSH__(status);
}

extern "C" void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha,
const double _Complex *x, std::int64_t incx, double _Complex *y, int64_t incy) {
const double _Complex *x, int64_t incx, double _Complex *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, static_cast<std::complex<double> >(alpha),
reinterpret_cast<const std::complex<double> *>(x), incx,
reinterpret_cast<std::complex<double> *>(y), incy);
Expand Down
45 changes: 40 additions & 5 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,46 @@ typedef enum {
ONEMKL_SIDE_RIGHT
} onemklSide;

// ormqr / unmqr entry points. The S, D, C and Z variants take float, double,
// float _Complex and double _Complex data respectively; the real variants are
// named ormqr and the complex ones unmqr.
void onemklSormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *tau,
float *c, int64_t ldc);
void onemklDormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
int64_t m, int64_t n, int64_t k, double *a, int64_t lda, double *tau,
double *c, int64_t ldc);
void onemklCunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
int64_t m, int64_t n, int64_t k, float _Complex *a, int64_t lda, float _Complex *tau,
float _Complex *c, int64_t ldc);
void onemklZunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans,
int64_t m, int64_t n, int64_t k, double _Complex *a, int64_t lda, double _Complex *tau,
double _Complex *c, int64_t ldc);

// orgqr / ungqr entry points. The S, D, C and Z variants take float, double,
// float _Complex and double _Complex data respectively; the real variants are
// named orgqr and the complex ones ungqr.
void onemklSorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float *a,
int64_t lda, float *tau);
void onemklDorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double *a,
int64_t lda, double *tau);
void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float _Complex *a,
int64_t lda, float _Complex *tau);
void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a,
int64_t lda, double _Complex *tau);

// Batched getrf entry points (group interface): sizes and leading dimensions
// are passed as arrays, matrices and pivots as arrays of pointers, together
// with group_count and group_sizes. S/D/C/Z select float, double,
// float _Complex and double _Complex data.
void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);

// Batched getri entry points (group interface): sizes and leading dimensions
// are passed as arrays, matrices and pivots as arrays of pointers, together
// with group_count and group_sizes. S/D/C/Z select float, double,
// float _Complex and double _Complex data.
void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);

void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,
float *a, int64_t lda, float *tau);
void onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,
Expand Down Expand Up @@ -91,11 +131,6 @@ void onemklCgetri(syclQueue_t device_queue, int64_t n,
void onemklZgetri(syclQueue_t device_queue, int64_t n,
double _Complex *a, int64_t lda, int64_t *ipiv);

// XXX: how to expose half in C?
// int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
// onemklTranspose transB, int64_t m, int64_t n, int64_t k,
// half alpha, const half *A, int64_t lda, const half *B,
// int64_t ldb, half beta, half *C, int64_t ldc);
int onemklSgemm(syclQueue_t device_queue, onemklTranspose transA,
onemklTranspose transB, int64_t m, int64_t n, int64_t k,
float alpha, const float *A, int64_t lda, const float *B,
Expand Down

0 comments on commit a1db608

Please sign in to comment.