Skip to content

Commit

Permalink
[oneMKL] Interface getrs_batched!
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Apr 3, 2024
1 parent b6a393f commit 5ee8dd9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 16 deletions.
32 changes: 32 additions & 0 deletions lib/mkl/wrappers_lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,38 @@ for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf
end
end

# getrs_batch
for (bname, fname, elty) in ((:onemklSgetrs_batch_scratchpad_size, :onemklSgetrs_batch, :Float32),
(:onemklDgetrs_batch_scratchpad_size, :onemklDgetrs_batch, :Float64),
(:onemklCgetrs_batch_scratchpad_size, :onemklCgetrs_batch, :ComplexF32),
(:onemklZgetrs_batch_scratchpad_size, :onemklZgetrs_batch, :ComplexF64))
@eval begin
function getrs_batched!(A::Vector{<:oneMatrix{$elty}}, ipiv::Vector{<:oneVector{Int64}}, B::Vector{<:oneMatrix{$elty}})
group_count = length(A)
group_sizes = ones(Int64, group_count)
trans = [ONEMKL_TRANSPOSE_NONTRANS for i=1:group_count]
n = [checksquare(A[i]) for i=1:group_count]
nrhs = [size(B[i], 2) for i=1:group_count]
lda = [max(1, stride(A[i], 2)) for i=1:group_count]
ldb = [max(1, stride(B[i], 2)) for i=1:group_count]
Aptrs = unsafe_batch(A)
Bptrs = unsafe_batch(B)
ipivptrs = unsafe_batch(ipiv)

queue = global_queue(context(A[1]), device(A[1]))
scratchpad_size = $bname(sycl_queue(queue), trans, n, nrhs, lda, ldb, group_count, group_sizes)
scratchpad = oneVector{$elty}(undef, scratchpad_size)
$fname(sycl_queue(queue), trans, n, nrhs, Aptrs, lda, ipivptrs, Bptrs, ldb, group_count, group_sizes, scratchpad, scratchpad_size)

unsafe_free!(Aptrs)
unsafe_free!(Bptrs)
unsafe_free!(ipivptrs)

return B
end
end
end

# getri_batch
for (bname, fname, elty) in ((:onemklSgetri_batch_scratchpad_size, :onemklSgetri_batch, :Float32),
(:onemklDgetri_batch_scratchpad_size, :onemklDgetri_batch, :Float64),
Expand Down
32 changes: 16 additions & 16 deletions lib/support/liboneapi_support.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4567,47 +4567,47 @@ function onemklSgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb,
group_sizes, scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklSgetrs_batch(device_queue::syclQueue_t,
trans::Ptr{onemklTranspose}, n::Ptr{Int64},
nrhs::Ptr{Int64}, a::Ptr{Ptr{Cfloat}},
lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}},
b::Ptr{Ptr{Cfloat}}, ldb::Ptr{Int64},
nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cfloat}},
lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
b::ZePtr{Ptr{Cfloat}}, ldb::Ptr{Int64},
group_count::Int64, group_sizes::Ptr{Int64},
scratchpad::Ptr{Cfloat},
scratchpad::ZePtr{Cfloat},
scratchpad_size::Int64)::Cint
end

function onemklDgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
group_sizes, scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklDgetrs_batch(device_queue::syclQueue_t,
trans::Ptr{onemklTranspose}, n::Ptr{Int64},
nrhs::Ptr{Int64}, a::Ptr{Ptr{Cdouble}},
lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}},
b::Ptr{Ptr{Cdouble}}, ldb::Ptr{Int64},
nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cdouble}},
lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
b::ZePtr{Ptr{Cdouble}}, ldb::Ptr{Int64},
group_count::Int64, group_sizes::Ptr{Int64},
scratchpad::Ptr{Cdouble},
scratchpad::ZePtr{Cdouble},
scratchpad_size::Int64)::Cint
end

function onemklCgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
group_sizes, scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklCgetrs_batch(device_queue::syclQueue_t,
trans::Ptr{onemklTranspose}, n::Ptr{Int64},
nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}},
lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}},
b::Ptr{Ptr{ComplexF32}}, ldb::Ptr{Int64},
nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF32}},
lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
b::ZePtr{Ptr{ComplexF32}}, ldb::Ptr{Int64},
group_count::Int64, group_sizes::Ptr{Int64},
scratchpad::Ptr{ComplexF32},
scratchpad::ZePtr{ComplexF32},
scratchpad_size::Int64)::Cint
end

function onemklZgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
group_sizes, scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklZgetrs_batch(device_queue::syclQueue_t,
trans::Ptr{onemklTranspose}, n::Ptr{Int64},
nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}},
lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}},
b::Ptr{Ptr{ComplexF32}}, ldb::Ptr{Int64},
nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF64}},
lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
b::ZePtr{Ptr{ComplexF64}}, ldb::Ptr{Int64},
group_count::Int64, group_sizes::Ptr{Int64},
scratchpad::Ptr{ComplexF32},
scratchpad::ZePtr{ComplexF64},
scratchpad_size::Int64)::Cint
end

Expand Down
6 changes: 6 additions & 0 deletions res/support.toml
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,12 @@ use_ccall_macro = true
6 = "ZePtr{Ptr{Int64}}"
9 = "ZePtr{T}"

[api.onemklXgetrs_batch.argtypes]
5 = "ZePtr{Ptr{T}}"
7 = "ZePtr{Ptr{Int64}}"
8 = "ZePtr{Ptr{T}}"
12 = "ZePtr{T}"

[api.onemklXgetri_batch.argtypes]
3 = "ZePtr{Ptr{T}}"
5 = "ZePtr{Ptr{Int64}}"
Expand Down
18 changes: 18 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,24 @@ end
end
end

@testset "getrs_batched!" begin
bA = [rand(elty, m, m) for i in 1:p]
bB = [rand(elty, m, n) for i in 1:p]
d_bA = oneMatrix{elty}[]
d_bB = oneMatrix{elty}[]
for i in 1:p
push!(d_bA, oneMatrix(bA[i]))
push!(d_bB, oneMatrix(bB[i]))
end

d_ipiv, d_bA = oneMKL.getrf_batched!(d_bA)
d_bX = oneMKL.getrs_batched!(d_bA, d_ipiv, d_bB)
h_bX = [collect(d_bX[i]) for i in 1:p]
for i = 1:p
@test bA[i] * hbX[i] bB[i]
end
end

@testset "gebrd!" begin
A = rand(elty,m,n)
d_A = oneArray(A)
Expand Down

0 comments on commit 5ee8dd9

Please sign in to comment.