Skip to content

Commit

Permalink
[oneMKL] Interface batched version of lapack routines
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Apr 3, 2024
1 parent 5ee8dd9 commit 77c51eb
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 45 deletions.
116 changes: 115 additions & 1 deletion lib/mkl/wrappers_lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ for (bname, fname, elty) in ((:onemklSpotri_scratchpad_size, :onemklSpotri, :Flo
end
end

#sytrf
# sytrf
for (bname, fname, elty) in ((:onemklSsytrf_scratchpad_size, :onemklSsytrf, :Float32),
(:onemklDsytrf_scratchpad_size, :onemklDsytrf, :Float64),
(:onemklCsytrf_scratchpad_size, :onemklCsytrf, :ComplexF32),
Expand Down Expand Up @@ -402,6 +402,62 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :onemklSsygvd_scratchpad_si
end
end

# potrf_batch
for (bname, fname, elty) in ((:onemklSpotrf_batch_scratchpad_size, :onemklSpotrf_batch, :Float32),
                             (:onemklDpotrf_batch_scratchpad_size, :onemklDpotrf_batch, :Float64),
                             (:onemklCpotrf_batch_scratchpad_size, :onemklCpotrf_batch, :ComplexF32),
                             (:onemklZpotrf_batch_scratchpad_size, :onemklZpotrf_batch, :ComplexF64))
    @eval begin
        # potrf_batched!(A)
        #
        # Run the oneMKL `potrf_batch` routine on every (square) matrix in `A`, in place,
        # always requesting the lower triangle (`ONEMKL_UPLO_LOWER`). Each matrix is
        # submitted as its own group of size one, so the matrices may differ in size.
        # Returns `A`; an empty batch is returned unchanged.
        function potrf_batched!(A::Vector{<:oneMatrix{$elty}})
            group_count = length(A)
            # Nothing to do, and `A[1]` below would throw on an empty batch.
            group_count == 0 && return A

            group_sizes = ones(Int64, group_count)
            uplo = fill(ONEMKL_UPLO_LOWER, group_count)
            n = [checksquare(a) for a in A]
            lda = [max(1, stride(a, 2)) for a in A]
            Aptrs = unsafe_batch(A)

            try
                # The queue is derived from the first matrix of the batch.
                queue = global_queue(context(A[1]), device(A[1]))
                scratchpad_size = $bname(sycl_queue(queue), uplo, n, lda, group_count, group_sizes)
                scratchpad = oneVector{$elty}(undef, scratchpad_size)
                $fname(sycl_queue(queue), uplo, n, Aptrs, lda, group_count, group_sizes, scratchpad, scratchpad_size)
            finally
                # Release the pointer table even if the query or the call throws.
                unsafe_free!(Aptrs)
            end

            return A
        end
    end
end

# potrs_batch
for (bname, fname, elty) in ((:onemklSpotrs_batch_scratchpad_size, :onemklSpotrs_batch, :Float32),
                             (:onemklDpotrs_batch_scratchpad_size, :onemklDpotrs_batch, :Float64),
                             (:onemklCpotrs_batch_scratchpad_size, :onemklCpotrs_batch, :ComplexF32),
                             (:onemklZpotrs_batch_scratchpad_size, :onemklZpotrs_batch, :ComplexF64))
    @eval begin
        # potrs_batched!(A, B)
        #
        # Run the oneMKL `potrs_batch` routine on the matrix pairs `(A[i], B[i])`, always
        # with `ONEMKL_UPLO_LOWER`. Each pair is submitted as its own group of size one.
        # Throws `DimensionMismatch` when `A` and `B` do not hold the same number of
        # matrices. Returns `A`; an empty batch is returned unchanged.
        function potrs_batched!(A::Vector{<:oneMatrix{$elty}}, B::Vector{<:oneMatrix{$elty}})
            group_count = length(A)
            length(B) == group_count ||
                throw(DimensionMismatch("A holds $group_count matrices but B holds $(length(B))"))
            # Nothing to do, and `A[1]` below would throw on an empty batch.
            group_count == 0 && return A

            group_sizes = ones(Int64, group_count)
            uplo = fill(ONEMKL_UPLO_LOWER, group_count)
            n = [checksquare(a) for a in A]
            nrhs = [size(b, 2) for b in B]
            lda = [max(1, stride(a, 2)) for a in A]
            ldb = [max(1, stride(b, 2)) for b in B]
            Aptrs = unsafe_batch(A)
            Bptrs = unsafe_batch(B)

            try
                # The queue is derived from the first matrix of the batch.
                queue = global_queue(context(A[1]), device(A[1]))
                scratchpad_size = $bname(sycl_queue(queue), uplo, n, nrhs, lda, ldb, group_count, group_sizes)
                scratchpad = oneVector{$elty}(undef, scratchpad_size)
                $fname(sycl_queue(queue), uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, group_count, group_sizes, scratchpad, scratchpad_size)
            finally
                # Release the pointer tables even if the query or the call throws.
                unsafe_free!(Aptrs)
                unsafe_free!(Bptrs)
            end

            # NOTE(review): by LAPACK convention the solution presumably lands in `B`,
            # yet `A` is what gets returned. Kept as-is for caller compatibility; confirm
            # whether returning `B` was intended.
            return A
        end
    end
end

# getrf_batch
for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf_batch, :Float32),
(:onemklDgetrf_batch_scratchpad_size, :onemklDgetrf_batch, :Float64),
Expand Down Expand Up @@ -490,6 +546,64 @@ for (bname, fname, elty) in ((:onemklSgetri_batch_scratchpad_size, :onemklSgetri
end
end

# geqrf_batch
for (bname, fname, elty) in ((:onemklSgeqrf_batch_scratchpad_size, :onemklSgeqrf_batch, :Float32),
                             (:onemklDgeqrf_batch_scratchpad_size, :onemklDgeqrf_batch, :Float64),
                             (:onemklCgeqrf_batch_scratchpad_size, :onemklCgeqrf_batch, :ComplexF32),
                             (:onemklZgeqrf_batch_scratchpad_size, :onemklZgeqrf_batch, :ComplexF64))
    @eval begin
        # geqrf_batched!(A)
        #
        # Run the oneMKL `geqrf_batch` routine on every matrix in `A`, in place. A fresh
        # device vector of length `min(m, n)` is allocated per matrix to receive `tau`.
        # Each matrix is submitted as its own group of size one.
        # Returns `(tau, A)`; for an empty batch both are empty.
        function geqrf_batched!(A::Vector{<:oneMatrix{$elty}})
            group_count = length(A)
            m = [size(a, 1) for a in A]
            n = [size(a, 2) for a in A]
            tau = [oneVector{$elty}(undef, min(mi, ni)) for (mi, ni) in zip(m, n)]
            # Nothing to do, and `A[1]` below would throw on an empty batch.
            group_count == 0 && return tau, A

            group_sizes = ones(Int64, group_count)
            lda = [max(1, stride(a, 2)) for a in A]
            Aptrs = unsafe_batch(A)
            tauptrs = unsafe_batch(tau)

            try
                # The queue is derived from the first matrix of the batch.
                queue = global_queue(context(A[1]), device(A[1]))
                scratchpad_size = $bname(sycl_queue(queue), m, n, lda, group_count, group_sizes)
                scratchpad = oneVector{$elty}(undef, scratchpad_size)
                $fname(sycl_queue(queue), m, n, Aptrs, lda, tauptrs, group_count, group_sizes, scratchpad, scratchpad_size)
            finally
                # Release the pointer tables even if the query or the call throws.
                unsafe_free!(Aptrs)
                unsafe_free!(tauptrs)
            end

            return tau, A
        end
    end
end

# orgqr_batch and ungqr_batch
for (bname, fname, elty) in ((:onemklSorgqr_batch_scratchpad_size, :onemklSorgqr_batch, :Float32),
                             (:onemklDorgqr_batch_scratchpad_size, :onemklDorgqr_batch, :Float64),
                             (:onemklCungqr_batch_scratchpad_size, :onemklCungqr_batch, :ComplexF32),
                             (:onemklZungqr_batch_scratchpad_size, :onemklZungqr_batch, :ComplexF64))
    @eval begin
        # orgqr_batched!(A, tau)
        #
        # Run the oneMKL `orgqr_batch` (real) / `ungqr_batch` (complex) routine on every
        # pair `(A[i], tau[i])`, in place, with `k = min(m, n)` for each matrix. Each pair
        # is submitted as its own group of size one.
        # Throws `DimensionMismatch` when `tau` does not match the batch length, or when
        # some `tau[i]` holds fewer than `k` entries.
        # Returns `A`; an empty batch is returned unchanged.
        function orgqr_batched!(A::Vector{<:oneMatrix{$elty}}, tau::Vector{<:oneVector{$elty}})
            group_count = length(A)
            length(tau) == group_count ||
                throw(DimensionMismatch("A holds $group_count matrices but tau holds $(length(tau)) vectors"))
            # Nothing to do, and `A[1]` below would throw on an empty batch.
            group_count == 0 && return A

            group_sizes = ones(Int64, group_count)
            m = [size(a, 1) for a in A]
            n = [size(a, 2) for a in A]
            k = [min(mi, ni) for (mi, ni) in zip(m, n)]
            # The call is told to use `k[i]` entries of `tau[i]`; refuse shorter vectors.
            for (i, t) in enumerate(tau)
                length(t) >= k[i] ||
                    throw(DimensionMismatch("tau[$i] has length $(length(t)), but needs at least $(k[i])"))
            end
            lda = [max(1, stride(a, 2)) for a in A]
            Aptrs = unsafe_batch(A)
            tauptrs = unsafe_batch(tau)

            try
                # The queue is derived from the first matrix of the batch.
                queue = global_queue(context(A[1]), device(A[1]))
                scratchpad_size = $bname(sycl_queue(queue), m, n, k, lda, group_count, group_sizes)
                scratchpad = oneVector{$elty}(undef, scratchpad_size)
                $fname(sycl_queue(queue), m, n, k, Aptrs, lda, tauptrs, group_count, group_sizes, scratchpad, scratchpad_size)
            finally
                # Release the pointer tables even if the query or the call throws.
                unsafe_free!(Aptrs)
                unsafe_free!(tauptrs)
            end

            return A
        end
    end
end

# LAPACK
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
@eval begin
Expand Down
66 changes: 33 additions & 33 deletions lib/support/liboneapi_support.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4443,87 +4443,87 @@ function onemklSpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_si
scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklSpotrf_batch(device_queue::syclQueue_t,
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
a::Ptr{Ptr{Cfloat}}, lda::Ptr{Int64},
a::ZePtr{Ptr{Cfloat}}, lda::Ptr{Int64},
group_count::Int64, group_sizes::Ptr{Int64},
scratchpad::Ptr{Cfloat},
scratchpad::ZePtr{Cfloat},
scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklDpotrf_batch`. The pointer table `a` and the
# `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklDpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklDpotrf_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                a::ZePtr{Ptr{Cdouble}}, lda::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{Cdouble},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklCpotrf_batch`. The pointer table `a` and the
# `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklCpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklCpotrf_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                a::ZePtr{Ptr{ComplexF32}}, lda::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF32},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklZpotrf_batch` (double-precision complex, hence
# `ComplexF64`). The pointer table `a` and the `scratchpad` are declared as `ZePtr`
# arguments; the other arrays are plain `Ptr`s.
function onemklZpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklZpotrf_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                a::ZePtr{Ptr{ComplexF64}}, lda::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF64},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklSpotrs_batch`. The pointer tables `a` and `b` and
# the `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklSpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklSpotrs_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cfloat}},
                                                lda::Ptr{Int64}, b::ZePtr{Ptr{Cfloat}},
                                                ldb::Ptr{Int64}, group_count::Int64,
                                                group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{Cfloat},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklDpotrs_batch`. The pointer tables `a` and `b` and
# the `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklDpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklDpotrs_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cdouble}},
                                                lda::Ptr{Int64}, b::ZePtr{Ptr{Cdouble}},
                                                ldb::Ptr{Int64}, group_count::Int64,
                                                group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{Cdouble},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklCpotrs_batch`. The pointer tables `a` and `b` and
# the `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklCpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklCpotrs_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF32}},
                                                lda::Ptr{Int64}, b::ZePtr{Ptr{ComplexF32}},
                                                ldb::Ptr{Int64}, group_count::Int64,
                                                group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF32},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklZpotrs_batch` (double-precision complex, hence
# `ComplexF64`). The pointer tables `a` and `b` and the `scratchpad` are declared as
# `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklZpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklZpotrs_batch(device_queue::syclQueue_t,
                                                uplo::Ptr{onemklUplo}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF64}},
                                                lda::Ptr{Int64}, b::ZePtr{Ptr{ComplexF64}},
                                                ldb::Ptr{Int64}, group_count::Int64,
                                                group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF64},
                                                scratchpad_size::Int64)::Cint
end

Expand Down Expand Up @@ -4697,43 +4697,43 @@ function onemklSorgqr_batch(device_queue, m, n, k, a, lda, tau, group_count, gro
scratchpad, scratchpad_size)
@ccall liboneapi_support.onemklSorgqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
n::Ptr{Int64}, k::Ptr{Int64},
a::Ptr{Ptr{Cfloat}}, lda::Ptr{Int64},
tau::Ptr{Ptr{Cfloat}}, group_count::Int64,
a::ZePtr{Ptr{Cfloat}}, lda::Ptr{Int64},
tau::ZePtr{Ptr{Cfloat}}, group_count::Int64,
group_sizes::Ptr{Int64},
scratchpad::Ptr{Cfloat},
scratchpad::ZePtr{Cfloat},
scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklDorgqr_batch`. The pointer tables `a` and `tau` and
# the `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklDorgqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklDorgqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
                                                n::Ptr{Int64}, k::Ptr{Int64},
                                                a::ZePtr{Ptr{Cdouble}}, lda::Ptr{Int64},
                                                tau::ZePtr{Ptr{Cdouble}},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{Cdouble},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklCungqr_batch`. The pointer tables `a` and `tau` and
# the `scratchpad` are declared as `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklCungqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklCungqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
                                                n::Ptr{Int64}, k::Ptr{Int64},
                                                a::ZePtr{Ptr{ComplexF32}}, lda::Ptr{Int64},
                                                tau::ZePtr{Ptr{ComplexF32}},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF32},
                                                scratchpad_size::Int64)::Cint
end

# Binding for `liboneapi_support.onemklZungqr_batch` (double-precision complex, hence
# `ComplexF64`). The pointer tables `a` and `tau` and the `scratchpad` are declared as
# `ZePtr` arguments; the other arrays are plain `Ptr`s.
function onemklZungqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
                            scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklZungqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
                                                n::Ptr{Int64}, k::Ptr{Int64},
                                                a::ZePtr{Ptr{ComplexF64}}, lda::Ptr{Int64},
                                                tau::ZePtr{Ptr{ComplexF64}},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF64},
                                                scratchpad_size::Int64)::Cint
end

Expand Down
19 changes: 19 additions & 0 deletions res/support.toml
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,25 @@ use_ccall_macro = true
6 = "ZePtr{Ptr{T}}"
9 = "ZePtr{T}"

[api.onemklXorgqr_batch.argtypes]
5 = "ZePtr{Ptr{T}}"
7 = "ZePtr{Ptr{T}}"
10 = "ZePtr{T}"

[api.onemklXungqr_batch.argtypes]
5 = "ZePtr{Ptr{T}}"
7 = "ZePtr{Ptr{T}}"
10 = "ZePtr{T}"

[api.onemklXpotrf_batch.argtypes]
4 = "ZePtr{Ptr{T}}"
8 = "ZePtr{T}"

[api.onemklXpotrs_batch.argtypes]
5 = "ZePtr{Ptr{T}}"
7 = "ZePtr{Ptr{T}}"
11 = "ZePtr{T}"

[api.onemklXsyevd.argtypes]
5 = "ZePtr{T}"
7 = "ZePtr{T}"
Expand Down
Loading

0 comments on commit 77c51eb

Please sign in to comment.