Skip to content

Commit

Permalink
Implement cublasSgetrsBatched.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Feb 18, 2024
1 parent 8f3c129 commit ad970a7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
14 changes: 13 additions & 1 deletion zluda_blas/src/cublas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4337,7 +4337,19 @@ pub unsafe extern "system" fn cublasSgetrsBatched(
info: *mut ::std::os::raw::c_int,
batchSize: ::std::os::raw::c_int,
) -> cublasStatus_t {
crate::unsupported()
crate::sgetrs_batched(
handle,
trans,
n,
nrhs,
Aarray,
lda,
devIpiv,
Barray,
ldb,
info,
batchSize,
)
}

#[no_mangle]
Expand Down
42 changes: 40 additions & 2 deletions zluda_blas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ pub use cublas::*;
use cuda_types::*;
use rocblas_sys::*;
use rocsolver_sys::{
rocsolver_cgetrf_batched, rocsolver_cgetri_batched, rocsolver_cgetri_outofplace_batched,
rocsolver_zgetrf_batched, rocsolver_zgetri_batched, rocsolver_zgetri_outofplace_batched,
rocsolver_cgetrf_batched, rocsolver_cgetri_batched, rocsolver_cgetri_outofplace_batched, rocsolver_sgetrs_batched, rocsolver_zgetrf_batched, rocsolver_zgetri_batched, rocsolver_zgetri_outofplace_batched
};
use std::{mem, ptr};

Expand Down Expand Up @@ -83,6 +82,15 @@ fn op_from_cuda(trans: cublasOperation_t) -> rocblas_operation {
}
}

fn op_from_cuda_for_solver(trans: cublasOperation_t) -> rocsolver_sys::rocblas_operation {
match trans {
cublasOperation_t::CUBLAS_OP_N => rocsolver_sys::rocblas_operation::rocblas_operation_none,
cublasOperation_t::CUBLAS_OP_T => rocsolver_sys::rocblas_operation::rocblas_operation_transpose,
cublasOperation_t::CUBLAS_OP_C => rocsolver_sys::rocblas_operation::rocblas_operation_conjugate_transpose,
_ => panic!(),
}
}

/// Destroys the rocBLAS handle behind `handle` and reports the outcome as a
/// cuBLAS status.
///
/// # Safety
///
/// `handle` is forwarded unchecked to `rocblas_destroy_handle`; the caller is
/// responsible for passing a handle that call accepts.
unsafe fn destroy(handle: cublasHandle_t) -> cublasStatus_t {
    let rocblas_status = rocblas_destroy_handle(handle as _);
    to_cuda(rocblas_status)
}
Expand Down Expand Up @@ -656,6 +664,36 @@ unsafe fn cgetri_batched(
))
}

unsafe fn sgetrs_batched(
handle: *mut cublasContext,
trans: cublasOperation_t,
n: i32,
nrhs: i32,
a: *const *const f32,
lda: i32,
dev_ipiv: *const i32,
b: *const *mut f32,
ldb: i32,
info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
let stride = n * nrhs;
to_cuda_solver(rocsolver_sgetrs_batched(
handle.cast(),
trans,
n,
nrhs,
a.cast(),
lda,
dev_ipiv,
stride as _,
b.cast(),
ldb,
batch_size,
))
}

unsafe fn dtrmm_v2(
handle: *mut cublasContext,
side: cublasSideMode_t,
Expand Down
2 changes: 1 addition & 1 deletion zluda_dnn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use types::*;

use hip_runtime_sys::*;
use miopen_sys::*;
use std::{mem, ptr::{self, null, null_mut}, alloc::{self, Layout}};
use std::{mem, ptr, alloc::{self, Layout}};

macro_rules! call {
($expr:expr) => {{
Expand Down

0 comments on commit ad970a7

Please sign in to comment.