Skip to content

Commit

Permalink
[CUSOLVER] Interface XsyevBatched (#2577)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Dec 13, 2024
1 parent 860eb88 commit ca8f6cf
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
40 changes: 40 additions & 0 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ end

# Xlarft!
function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
CUSOLVER.version() < v"11.6.0" && throw(ErrorException("This operation is not supported by the current CUDA version."))
n, k = size(v)
ktau = length(tau)
mt, nt = size(t)
Expand Down Expand Up @@ -449,6 +450,7 @@ end

# Xgeev
function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
n = checksquare(A)
VL = if jobvl == 'V'
CuMatrix{T}(undef, n, n)
Expand Down Expand Up @@ -492,6 +494,44 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
return W, VL, VR
end

# XsyevBatched
"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}

Run cuSOLVER's `cusolverDnXsyevBatched` on a batch of `n × n` matrices stored side by
side in the columns of `A`, i.e. `A` has size `n × (n * batch_size)` and the `i`-th
matrix occupies columns `(i-1)*n+1 : i*n`.

# Arguments
- `jobz`: `'N'` to return only the values `W`, `'V'` to additionally return `A`.
- `uplo`: `'L'` or `'U'`, validated with `chkuplo` and forwarded to cuSOLVER.
- `A`: the batched input; it is passed to cuSOLVER as both input and output, so it is
  overwritten in place.

# Returns
- `W` when `jobz == 'N'`, a `CuVector{real(T)}` of length `n * batch_size`.
- `(W, A)` when `jobz == 'V'`.

# Throws
- `ErrorException` if the cuSOLVER version is older than 11.7.1.
- `ArgumentError` if `jobz` is neither `'N'` nor `'V'`.
- `DimensionMismatch` if `A` has no rows or its number of columns is not a multiple of
  its number of rows.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
    # Reject unsupported modes before touching any shared handle state; previously an
    # unknown `jobz` could fall through the final branch and return `nothing`.
    jobz in ('N', 'V') ||
        throw(ArgumentError("jobz must be 'N' or 'V', got '$jobz'"))
    chkuplo(uplo)

    n, ncols = size(A)
    # With zero rows the batch size is undefined (the original `÷` raised a DivideError).
    n > 0 || throw(DimensionMismatch("A must have at least one row"))
    # Trailing columns that do not form a full n × n matrix used to be silently ignored.
    ncols % n == 0 || throw(DimensionMismatch(
        "the number of columns of A ($ncols) must be a multiple of its number of rows ($n)"))
    batch_size = ncols ÷ n

    R = real(T)
    lda = max(1, stride(A, 2))
    W = CuVector{R}(undef, n * batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix in the batch.
    resize!(dh.info, batch_size)

    # Query the device and host workspace sizes (in the units reported by cuSOLVER).
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n,
                                          T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A,
                               lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
                               buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size)
    end

    # Copy the per-matrix status codes to the host and check each of them.
    # NOTE(review): only `chkargsok` is applied here; whether positive (non-convergence)
    # codes should also raise is not decided in this block — confirm against cuSOLVER docs.
    info = @allowscalar collect(dh.info)
    for i in 1:batch_size
        chkargsok(info[i] |> BlasInt)
    end

    return jobz == 'N' ? W : (W, A)
end

# LAPACK
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
@eval begin
Expand Down
30 changes: 30 additions & 0 deletions test/libraries/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ p = 5
end
end
end

@testset "syevBatched!" begin
    batch_size = 5
    for uplo in ('L', 'U')
        # NOTE(review): this combination is skipped; the reason is not visible here.
        (uplo == 'L') && (elty == ComplexF32) && continue

        # `B` keeps the full Hermitian matrices as a host reference, while `A` holds only
        # the `uplo` triangle of each matrix and is what gets sent to the GPU.
        A = rand(elty, n, n * batch_size)
        B = rand(elty, n, n * batch_size)
        for i = 1:batch_size
            cols = (i-1)*n+1:i*n
            S = rand(elty, n, n)
            S = S * S' + I
            B[:, cols] .= S
            A[:, cols] .= uplo == 'L' ? tril(S) : triu(S)
        end

        # jobz = 'V': values and vectors; check B * V ≈ V * W for every matrix.
        d_A = CuMatrix(A)
        d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
        W = collect(d_W)
        V = collect(d_V)
        for i = 1:batch_size
            cols = (i-1)*n+1:i*n
            Bᵢ = B[:, cols]
            Wᵢ = Diagonal(W[cols])
            Vᵢ = V[:, cols]
            @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
        end

        # jobz = 'N': values only; they must agree with the ones from the 'V' run.
        d_A = CuMatrix(A)
        d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
        @test collect(d_W) ≈ W
    end
end
end

if CUSOLVER.version() >= v"11.6.0"
Expand Down

0 comments on commit ca8f6cf

Please sign in to comment.