diff --git a/lib/mkl/wrappers_lapack.jl b/lib/mkl/wrappers_lapack.jl index 42b9062b..9bfd3449 100644 --- a/lib/mkl/wrappers_lapack.jl +++ b/lib/mkl/wrappers_lapack.jl @@ -431,6 +431,38 @@ for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf end end +# getrs_batch +for (bname, fname, elty) in ((:onemklSgetrs_batch_scratchpad_size, :onemklSgetrs_batch, :Float32), + (:onemklDgetrs_batch_scratchpad_size, :onemklDgetrs_batch, :Float64), + (:onemklCgetrs_batch_scratchpad_size, :onemklCgetrs_batch, :ComplexF32), + (:onemklZgetrs_batch_scratchpad_size, :onemklZgetrs_batch, :ComplexF64)) + @eval begin + function getrs_batched!(A::Vector{<:oneMatrix{$elty}}, ipiv::Vector{<:oneVector{Int64}}, B::Vector{<:oneMatrix{$elty}}) + group_count = length(A) + group_sizes = ones(Int64, group_count) + trans = [ONEMKL_TRANSPOSE_NONTRANS for i=1:group_count] + n = [checksquare(A[i]) for i=1:group_count] + nrhs = [size(B[i], 2) for i=1:group_count] + lda = [max(1, stride(A[i], 2)) for i=1:group_count] + ldb = [max(1, stride(B[i], 2)) for i=1:group_count] + Aptrs = unsafe_batch(A) + Bptrs = unsafe_batch(B) + ipivptrs = unsafe_batch(ipiv) + + queue = global_queue(context(A[1]), device(A[1])) + scratchpad_size = $bname(sycl_queue(queue), trans, n, nrhs, lda, ldb, group_count, group_sizes) + scratchpad = oneVector{$elty}(undef, scratchpad_size) + $fname(sycl_queue(queue), trans, n, nrhs, Aptrs, lda, ipivptrs, Bptrs, ldb, group_count, group_sizes, scratchpad, scratchpad_size) + + unsafe_free!(Aptrs) + unsafe_free!(Bptrs) + unsafe_free!(ipivptrs) + + return B + end + end +end + # getri_batch for (bname, fname, elty) in ((:onemklSgetri_batch_scratchpad_size, :onemklSgetri_batch, :Float32), (:onemklDgetri_batch_scratchpad_size, :onemklDgetri_batch, :Float64), diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 9ffc388b..17fc606f 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -4567,11 +4567,11 @@ 
function onemklSgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_sizes, scratchpad, scratchpad_size) @ccall liboneapi_support.onemklSgetrs_batch(device_queue::syclQueue_t, trans::Ptr{onemklTranspose}, n::Ptr{Int64}, - nrhs::Ptr{Int64}, a::Ptr{Ptr{Cfloat}}, - lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}}, - b::Ptr{Ptr{Cfloat}}, ldb::Ptr{Int64}, + nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cfloat}}, + lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}}, + b::ZePtr{Ptr{Cfloat}}, ldb::Ptr{Int64}, group_count::Int64, group_sizes::Ptr{Int64}, - scratchpad::Ptr{Cfloat}, + scratchpad::ZePtr{Cfloat}, scratchpad_size::Int64)::Cint end @@ -4579,11 +4579,11 @@ function onemklDgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_sizes, scratchpad, scratchpad_size) @ccall liboneapi_support.onemklDgetrs_batch(device_queue::syclQueue_t, trans::Ptr{onemklTranspose}, n::Ptr{Int64}, - nrhs::Ptr{Int64}, a::Ptr{Ptr{Cdouble}}, - lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}}, - b::Ptr{Ptr{Cdouble}}, ldb::Ptr{Int64}, + nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cdouble}}, + lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}}, + b::ZePtr{Ptr{Cdouble}}, ldb::Ptr{Int64}, group_count::Int64, group_sizes::Ptr{Int64}, - scratchpad::Ptr{Cdouble}, + scratchpad::ZePtr{Cdouble}, scratchpad_size::Int64)::Cint end @@ -4591,11 +4591,11 @@ function onemklCgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_sizes, scratchpad, scratchpad_size) @ccall liboneapi_support.onemklCgetrs_batch(device_queue::syclQueue_t, trans::Ptr{onemklTranspose}, n::Ptr{Int64}, - nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}}, - lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}}, - b::Ptr{Ptr{ComplexF32}}, ldb::Ptr{Int64}, + nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF32}}, + lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}}, + b::ZePtr{Ptr{ComplexF32}}, ldb::Ptr{Int64}, group_count::Int64, group_sizes::Ptr{Int64}, - scratchpad::Ptr{ComplexF32}, + scratchpad::ZePtr{ComplexF32}, scratchpad_size::Int64)::Cint end @@ -4603,11 +4603,11 @@ function 
onemklZgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_sizes, scratchpad, scratchpad_size) @ccall liboneapi_support.onemklZgetrs_batch(device_queue::syclQueue_t, trans::Ptr{onemklTranspose}, n::Ptr{Int64}, - nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}}, - lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}}, - b::Ptr{Ptr{ComplexF32}}, ldb::Ptr{Int64}, + nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF64}}, + lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}}, + b::ZePtr{Ptr{ComplexF64}}, ldb::Ptr{Int64}, group_count::Int64, group_sizes::Ptr{Int64}, - scratchpad::Ptr{ComplexF32}, + scratchpad::ZePtr{ComplexF64}, scratchpad_size::Int64)::Cint end diff --git a/res/support.toml b/res/support.toml index c84e69c2..b72862b2 100644 --- a/res/support.toml +++ b/res/support.toml @@ -550,6 +550,12 @@ use_ccall_macro = true 6 = "ZePtr{Ptr{Int64}}" 9 = "ZePtr{T}" +[api.onemklXgetrs_batch.argtypes] +5 = "ZePtr{Ptr{T}}" +7 = "ZePtr{Ptr{Int64}}" +8 = "ZePtr{Ptr{T}}" +12 = "ZePtr{T}" + [api.onemklXgetri_batch.argtypes] 3 = "ZePtr{Ptr{T}}" 5 = "ZePtr{Ptr{Int64}}" diff --git a/test/onemkl.jl b/test/onemkl.jl index 51b17c59..76459a93 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1300,6 +1300,24 @@ end end end + @testset "getrs_batched!" begin + bA = [rand(elty, m, m) for i in 1:p] + bB = [rand(elty, m, n) for i in 1:p] + d_bA = oneMatrix{elty}[] + d_bB = oneMatrix{elty}[] + for i in 1:p + push!(d_bA, oneMatrix(bA[i])) + push!(d_bB, oneMatrix(bB[i])) + end + + d_ipiv, d_bA = oneMKL.getrf_batched!(d_bA) + d_bX = oneMKL.getrs_batched!(d_bA, d_ipiv, d_bB) + h_bX = [collect(d_bX[i]) for i in 1:p] + for i = 1:p + @test bA[i] * h_bX[i] ≈ bB[i] + end + end + @testset "gebrd!" begin A = rand(elty,m,n) d_A = oneArray(A)