diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 3f73c3f8..71b76eea 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -193,6 +193,180 @@ class trsmBatchInfo { } }; +extern "C" void onemklSormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *tau, + float *c, int64_t ldc) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::ormqr(device_queue->val, convert(side), convert(trans), m, n, k, a, lda, tau, c, ldc, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, double *a, int64_t lda, double *tau, + double *c, int64_t ldc) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::ormqr(device_queue->val, convert(side), convert(trans), m, n, k, a, lda, tau, c, ldc, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, float _Complex *a, int64_t lda, float _Complex *tau, + float _Complex *c, int64_t ldc) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::unmqr(device_queue->val, convert(side), convert(trans), m, n, k, reinterpret_cast *>(a), lda, + reinterpret_cast *>(tau), reinterpret_cast *>(c), ldc, + scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, double _Complex *a, int64_t lda, double _Complex *tau, + double _Complex *c, int64_t ldc) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size>(device_queue->val, convert(side), convert(trans), m, n, k, lda, ldc); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::unmqr(device_queue->val, convert(side), convert(trans), m, n, k, reinterpret_cast *>(a), lda, + reinterpret_cast *>(tau), reinterpret_cast *>(c), ldc, + scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklSorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float *a, + int64_t lda, float *tau) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size(device_queue->val, m, n, k, lda); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::orgqr(device_queue->val, m, n, k, a, lda, tau, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double *a, + int64_t lda, double *tau) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size(device_queue->val, m, n, k, lda); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::orgqr(device_queue->val, m, n, k, a, lda, tau, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float _Complex *a, + int64_t lda, float _Complex *tau) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size>(device_queue->val, m, n, k, lda); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::ungqr(device_queue->val, m, n, k, reinterpret_cast *>(a), lda, + reinterpret_cast *>(tau), scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a, + int64_t lda, double _Complex *tau) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size>(device_queue->val, m, n, k, lda); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::ungqr(device_queue->val, m, n, k, reinterpret_cast *>(a), lda, + reinterpret_cast *>(tau), scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + extern "C" void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau) { auto device = device_queue->val.get_device(); @@ -1256,7 +1430,7 @@ extern "C" void onemklZasum(syclQueue_t device_queue, int64_t n, } extern "C" void onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha, - const short *x, std::int64_t incx, short *y, int64_t incy) { + const short *x, int64_t incx, short *y, int64_t incy) { auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, sycl::bit_cast(alpha), reinterpret_cast(x), @@ -1265,21 +1439,21 @@ extern "C" void onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha, } extern "C" void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, - const float *x, std::int64_t incx, float *y, int64_t incy) { + const float *x, int64_t incx, float *y, int64_t incy) { auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy); __FORCE_MKL_FLUSH__(status); } extern "C" void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, - const double *x, std::int64_t incx, double *y, int64_t incy) { + const double *x, int64_t incx, double *y, int64_t incy) { auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy); __FORCE_MKL_FLUSH__(status); } extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, - const float _Complex *x, std::int64_t incx, float _Complex *y, int64_t incy) { + const float _Complex *x, int64_t incx, float _Complex *y, int64_t incy) { auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, static_cast >(alpha), reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy); @@ -1287,7 +1461,7 @@ extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex } extern "C" void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, - const double _Complex *x, std::int64_t incx, double _Complex *y, int64_t incy) { + const double _Complex *x, int64_t incx, double _Complex *y, int64_t incy) { auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n, static_cast >(alpha), reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy); diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index b6cc64bc..1d310fe6 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -30,6 +30,46 @@ typedef enum { ONEMKL_SIDE_RIGHT } onemklSide; +void onemklSormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, float *a, int64_t lda, float *tau, + float *c, int64_t ldc); +void onemklDormqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, double *a, int64_t lda, double *tau, + double *c, int64_t ldc); +void onemklCunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, float _Complex *a, int64_t lda, float _Complex *tau, + float _Complex *c, int64_t ldc); +void onemklZunmqr(syclQueue_t device_queue, onemklSide side, onemklTranspose trans, + int64_t m, int64_t n, int64_t k, double _Complex *a, int64_t lda, double _Complex *tau, + double _Complex *c, int64_t ldc); + +void onemklSorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float *a, + int64_t lda, float *tau); +void onemklDorgqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double *a, + int64_t lda, double *tau); +void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, float _Complex *a, + int64_t lda, float _Complex *tau); +void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a, + int64_t lda, double _Complex *tau); + +void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); + +void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); +void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes); + void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau); void onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, @@ -91,11 +131,6 @@ void onemklCgetri(syclQueue_t device_queue, int64_t n, void onemklZgetri(syclQueue_t device_queue, int64_t n, double _Complex *a, int64_t lda, int64_t *ipiv); -// XXX: how to expose half in C? -// int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, -// onemklTranspose transB, int64_t m, int64_t n, int64_t k, -// half alpha, const half *A, int64_t lda, const half *B, -// int64_t ldb, half beta, half *C, int64_t ldc); int onemklSgemm(syclQueue_t device_queue, onemklTranspose transA, onemklTranspose transB, int64_t m, int64_t n, int64_t k, float alpha, const float *A, int64_t lda, const float *B,