Skip to content

Commit

Permalink
Add getrf_batch and getri_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 18, 2023
1 parent 6342601 commit b84d61c
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 91 deletions.
166 changes: 83 additions & 83 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,89 +283,89 @@ extern "C" void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int
__FORCE_MKL_FLUSH__(status);
}

// extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<float>(device_queue->val, m, n, lda, group_count, group_sizes);
// auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context);
// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<double>(device_queue->val, m, n, lda, group_count, group_sizes);
// auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context);
// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<std::complex<float>>(device_queue->val, m, n, lda, group_count, group_sizes);
// auto scratchpad = (std::complex<float> *) malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context);
// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast<std::complex<float> **>(&a[0]), lda,
// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<std::complex<double>>(device_queue->val, m, n, lda, group_count, group_sizes);
// auto scratchpad = (std::complex<double> *) malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context);
// auto status = oneapi::mkl::lapack::getrf_batch(device_queue->val, m, n, reinterpret_cast<std::complex<double> **>(&a[0]), lda,
// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t n, float **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<float>(device_queue->val, n, lda, group_count, group_sizes);
// auto scratchpad = (float *) malloc_device(scratchpad_size * sizeof(float), device, context);
// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t n, double **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<double>(device_queue->val, n, lda, group_count, group_sizes);
// auto scratchpad = (double *) malloc_device(scratchpad_size * sizeof(double), device, context);
// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, &a[0], lda, &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t n, float _Complex **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<std::complex<float>>(device_queue->val, n, lda, group_count, group_sizes);
// auto scratchpad = (std::complex<float> *) malloc_device(scratchpad_size * sizeof(std::complex<float>), device, context);
// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast<std::complex<float> **>(&a[0]), lda,
// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }

// extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t n, double _Complex **a, int64_t *lda,
// int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
// auto device = device_queue->val.get_device();
// auto context = device_queue->val.get_context();
// auto scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<std::complex<double>>(device_queue->val, n, lda, group_count, group_sizes);
// auto scratchpad = (std::complex<double> *) malloc_device(scratchpad_size * sizeof(std::complex<double>), device, context);
// auto status = oneapi::mkl::lapack::getri_batch(device_queue->val, n, reinterpret_cast<std::complex<double> **>(&a[0]), lda,
// &ipiv[0], group_count, group_sizes, scratchpad, scratchpad_size);
// __FORCE_MKL_FLUSH__(status);
// }
// C entry point wrapping oneapi::mkl::lapack::getrf_batch (LU factorization,
// group API) for float matrices.
//
// Steps: ask oneMKL how large a scratchpad the call needs, allocate that many
// floats with malloc_device for the queue's device/context, call getrf_batch,
// then hand the returned object to __FORCE_MKL_FLUSH__.
//
//   m, n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                           (presumably one entry per group -- TODO confirm)
//   a, ipiv                 arrays of matrix / pivot pointers, forwarded unchanged
//   group_count             forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getrf_batch_scratchpad_size<float>(
        queue, m, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<float *>(
        malloc_device(workspace_len * sizeof(float), queue.get_device(), queue.get_context()));
    auto event = oneapi::mkl::lapack::getrf_batch(queue, m, n, a, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getrf_batch (LU factorization,
// group API) for double matrices.
//
// Steps: ask oneMKL how large a scratchpad the call needs, allocate that many
// doubles with malloc_device for the queue's device/context, call getrf_batch,
// then hand the returned object to __FORCE_MKL_FLUSH__.
//
//   m, n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                           (presumably one entry per group -- TODO confirm)
//   a, ipiv                 arrays of matrix / pivot pointers, forwarded unchanged
//   group_count             forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getrf_batch_scratchpad_size<double>(
        queue, m, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<double *>(
        malloc_device(workspace_len * sizeof(double), queue.get_device(), queue.get_context()));
    auto event = oneapi::mkl::lapack::getrf_batch(queue, m, n, a, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getrf_batch (LU factorization,
// group API) for single-precision complex matrices.
//
// The C-side `float _Complex **` array is reinterpreted as
// `std::complex<float> **` before being passed to oneMKL. The scratchpad is
// sized via getrf_batch_scratchpad_size and allocated with malloc_device for
// the queue's device/context; the object returned by getrf_batch is handed to
// __FORCE_MKL_FLUSH__.
//
//   m, n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                           (presumably one entry per group -- TODO confirm)
//   a, ipiv                 arrays of matrix / pivot pointers
//   group_count             forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    using elem_t = std::complex<float>;
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getrf_batch_scratchpad_size<elem_t>(
        queue, m, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<elem_t *>(
        malloc_device(workspace_len * sizeof(elem_t), queue.get_device(), queue.get_context()));
    auto **matrices = reinterpret_cast<elem_t **>(a);
    auto event = oneapi::mkl::lapack::getrf_batch(queue, m, n, matrices, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getrf_batch (LU factorization,
// group API) for double-precision complex matrices.
//
// The C-side `double _Complex **` array is reinterpreted as
// `std::complex<double> **` before being passed to oneMKL. The scratchpad is
// sized via getrf_batch_scratchpad_size and allocated with malloc_device for
// the queue's device/context; the object returned by getrf_batch is handed to
// __FORCE_MKL_FLUSH__.
//
//   m, n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                           (presumably one entry per group -- TODO confirm)
//   a, ipiv                 arrays of matrix / pivot pointers
//   group_count             forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    using elem_t = std::complex<double>;
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getrf_batch_scratchpad_size<elem_t>(
        queue, m, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<elem_t *>(
        malloc_device(workspace_len * sizeof(elem_t), queue.get_device(), queue.get_context()));
    auto **matrices = reinterpret_cast<elem_t **>(a);
    auto event = oneapi::mkl::lapack::getrf_batch(queue, m, n, matrices, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getri_batch (inverse from LU
// factors, group API) for float matrices.
//
// Steps: ask oneMKL how large a scratchpad the call needs, allocate that many
// floats with malloc_device for the queue's device/context, call getri_batch,
// then hand the returned object to __FORCE_MKL_FLUSH__.
//
//   n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                        (presumably one entry per group -- TODO confirm)
//   a, ipiv              arrays of matrix / pivot pointers, forwarded unchanged
//   group_count          forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getri_batch_scratchpad_size<float>(
        queue, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<float *>(
        malloc_device(workspace_len * sizeof(float), queue.get_device(), queue.get_context()));
    auto event = oneapi::mkl::lapack::getri_batch(queue, n, a, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getri_batch (inverse from LU
// factors, group API) for double matrices.
//
// Steps: ask oneMKL how large a scratchpad the call needs, allocate that many
// doubles with malloc_device for the queue's device/context, call getri_batch,
// then hand the returned object to __FORCE_MKL_FLUSH__.
//
//   n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                        (presumably one entry per group -- TODO confirm)
//   a, ipiv              arrays of matrix / pivot pointers, forwarded unchanged
//   group_count          forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getri_batch_scratchpad_size<double>(
        queue, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<double *>(
        malloc_device(workspace_len * sizeof(double), queue.get_device(), queue.get_context()));
    auto event = oneapi::mkl::lapack::getri_batch(queue, n, a, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getri_batch (inverse from LU
// factors, group API) for single-precision complex matrices.
//
// The C-side `float _Complex **` array is reinterpreted as
// `std::complex<float> **` before being passed to oneMKL. The scratchpad is
// sized via getri_batch_scratchpad_size and allocated with malloc_device for
// the queue's device/context; the object returned by getri_batch is handed to
// __FORCE_MKL_FLUSH__.
//
//   n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                        (presumably one entry per group -- TODO confirm)
//   a, ipiv              arrays of matrix / pivot pointers
//   group_count          forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    using elem_t = std::complex<float>;
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getri_batch_scratchpad_size<elem_t>(
        queue, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<elem_t *>(
        malloc_device(workspace_len * sizeof(elem_t), queue.get_device(), queue.get_context()));
    auto **matrices = reinterpret_cast<elem_t **>(a);
    auto event = oneapi::mkl::lapack::getri_batch(queue, n, matrices, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

// C entry point wrapping oneapi::mkl::lapack::getri_batch (inverse from LU
// factors, group API) for double-precision complex matrices.
//
// The C-side `double _Complex **` array is reinterpreted as
// `std::complex<double> **` before being passed to oneMKL. The scratchpad is
// sized via getri_batch_scratchpad_size and allocated with malloc_device for
// the queue's device/context; the object returned by getri_batch is handed to
// __FORCE_MKL_FLUSH__.
//
//   n, lda, group_sizes  int64_t arrays forwarded unchanged to oneMKL
//                        (presumably one entry per group -- TODO confirm)
//   a, ipiv              arrays of matrix / pivot pointers
//   group_count          forwarded unchanged
//
// NOTE(review): the scratchpad is never released inside this function; confirm
// who is expected to free it and when that is safe.
extern "C" void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda,
                                    int64_t **ipiv, int64_t group_count, int64_t *group_sizes) {
    using elem_t = std::complex<double>;
    auto &queue = device_queue->val;
    const auto workspace_len = oneapi::mkl::lapack::getri_batch_scratchpad_size<elem_t>(
        queue, n, lda, group_count, group_sizes);
    auto *workspace = static_cast<elem_t *>(
        malloc_device(workspace_len * sizeof(elem_t), queue.get_device(), queue.get_context()));
    auto **matrices = reinterpret_cast<elem_t **>(a);
    auto event = oneapi::mkl::lapack::getri_batch(queue, n, matrices, lda, ipiv, group_count, group_sizes,
                                                  workspace, workspace_len);
    __FORCE_MKL_FLUSH__(event);
}

extern "C" void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,
float *a, int64_t lda, float *tau) {
Expand Down
16 changes: 8 additions & 8 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,22 @@ void onemklCungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, flo
void onemklZungqr(syclQueue_t device_queue, int64_t m, int64_t n, int64_t k, double _Complex *a,
int64_t lda, double _Complex *tau);

void onemklSgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float **a, int64_t *lda,
/* Batched getrf wrapper for float matrices. m, n, lda and group_sizes are
 * int64_t arrays; a and ipiv are arrays of pointers.
 * NOTE(review): presumably one m/n/lda entry per group -- confirm against
 * the oneMKL getrf_batch group API. */
void onemklSgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklDgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double **a, int64_t *lda,
/* Batched getrf wrapper for double matrices. m, n, lda and group_sizes are
 * int64_t arrays; a and ipiv are arrays of pointers.
 * NOTE(review): presumably one m/n/lda entry per group -- confirm against
 * the oneMKL getrf_batch group API. */
void onemklDgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklCgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex **a, int64_t *lda,
/* Batched getrf wrapper for float _Complex matrices. m, n, lda and
 * group_sizes are int64_t arrays; a and ipiv are arrays of pointers.
 * NOTE(review): presumably one m/n/lda entry per group -- confirm against
 * the oneMKL getrf_batch group API. */
void onemklCgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, float _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklZgetrfBatched(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex **a, int64_t *lda,
/* Batched getrf wrapper for double _Complex matrices. m, n, lda and
 * group_sizes are int64_t arrays; a and ipiv are arrays of pointers.
 * NOTE(review): presumably one m/n/lda entry per group -- confirm against
 * the oneMKL getrf_batch group API. */
void onemklZgetrfBatched(syclQueue_t device_queue, int64_t *m, int64_t *n, double _Complex **a, int64_t *lda,
int64_t **ipiv, int64_t group_count, int64_t *group_sizes);

void onemklSgetriBatched(syclQueue_t device_queue, int64_t n, float *a, int64_t *lda,
/* Batched getri wrapper for float matrices. n, lda and group_sizes are
 * int64_t arrays; a and ipiv are arrays of pointers.
 * `a` is declared as `float **` so the prototype agrees with the definition
 * in onemkl.cpp, which forwards it to oneapi::mkl::lapack::getri_batch. */
void onemklSgetriBatched(syclQueue_t device_queue, int64_t *n, float **a, int64_t *lda,
                         int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklDgetriBatched(syclQueue_t device_queue, int64_t n, double *a, int64_t *lda,
/* Batched getri wrapper for double matrices. n, lda and group_sizes are
 * int64_t arrays; a and ipiv are arrays of pointers.
 * `a` is declared as `double **` so the prototype agrees with the definition
 * in onemkl.cpp, which forwards it to oneapi::mkl::lapack::getri_batch. */
void onemklDgetriBatched(syclQueue_t device_queue, int64_t *n, double **a, int64_t *lda,
                         int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklCgetriBatched(syclQueue_t device_queue, int64_t n, float _Complex *a, int64_t *lda,
/* Batched getri wrapper for float _Complex matrices. n, lda and group_sizes
 * are int64_t arrays; a and ipiv are arrays of pointers.
 * `a` is declared as `float _Complex **` so the prototype agrees with the
 * definition in onemkl.cpp, which reinterprets it as std::complex<float> **
 * and forwards it to oneapi::mkl::lapack::getri_batch. */
void onemklCgetriBatched(syclQueue_t device_queue, int64_t *n, float _Complex **a, int64_t *lda,
                         int64_t **ipiv, int64_t group_count, int64_t *group_sizes);
void onemklZgetriBatched(syclQueue_t device_queue, int64_t n, double _Complex *a, int64_t *lda,
/* Batched getri wrapper for double _Complex matrices. n, lda and group_sizes
 * are int64_t arrays; a and ipiv are arrays of pointers.
 * `a` is declared as `double _Complex **` so the prototype agrees with the
 * definition in onemkl.cpp, which reinterprets it as std::complex<double> **
 * and forwards it to oneapi::mkl::lapack::getri_batch. */
void onemklZgetriBatched(syclQueue_t device_queue, int64_t *n, double _Complex **a, int64_t *lda,
                         int64_t **ipiv, int64_t group_count, int64_t *group_sizes);

void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n,
Expand Down

0 comments on commit b84d61c

Please sign in to comment.