From b84d61c0226ed092dfe3a20c070d2af82263f244 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 18 Oct 2023 15:37:31 -0500 Subject: [PATCH] Add getrf_batch and getri_batch --- deps/src/onemkl.cpp | 166 ++++++++++++++++++++++---------------------- deps/src/onemkl.h | 16 ++--- 2 files changed, 91 insertions(+), 91 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 5e650dfc..71b76eea 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -283,89 +283,89 @@ extern "C" void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int __FORCE_MKL_FLUSH__(status); } -// extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); -// auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); -// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); -// auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); -// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[{0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); -// auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); -// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, -// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); -// auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); -// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, -// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t n, float **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); -// auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); -// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t n, double **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); -// auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); -// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t n, float _Complex **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); -// auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); -// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, -// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } - -// extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t n, double _Complex **a, int64_t *lda, -// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { -// auto device = device_queue->val.get_device(); -// auto context = device_queue->val.get_context(); -// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); -// auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); -// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, -// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); -// __FORCE_MKL_FLUSH__(status); -// } +extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size>(device_queue->val, m, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} + +extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda, + int64_t **ipiv, int64_t group_count, int64_t *group_sizes) { + auto device = device_queue->val.get_device(); + auto context = device_queue->val.get_context(); + auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size>(device_queue->val, n, lda, group_count, group_sizes); + auto scratchpad = (std::complex *) malloc_device(scratchpad_size * sizeof(std::complex), device, context); + auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast **>(&a[0]), lda, + &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size); + __FORCE_MKL_FLUSH__(status); +} extern "C" void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau) { diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 0bf37d70..a500782c 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -52,22 +52,22 @@ void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, flo void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a, int64_t lda, double _Complex *tau); -void onemklSgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float **a, int64_t *lda, +void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklDgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double **a, int64_t *lda, +void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklCgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex **a, int64_t *lda, +void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklZgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex **a, int64_t *lda, +void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklSgetriBatched(syclQueue_t device_queue, int64_t n, float *a, int64_t *lda, +void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float *a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklDgetriBatched(syclQueue_t device_queue, int64_t n, double *a, int64_t *lda, +void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double *a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklCgetriBatched(syclQueue_t device_queue, int64_t n, float _Complex *a, int64_t *lda, +void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex *a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); -void onemklZgetriBatched(syclQueue_t device_queue, int64_t n, double _Complex *a, int64_t *lda, +void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex *a, int64_t *lda, int64_t **ipiv, int64_t group_count, int64_t *group_sizes); void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,