Skip to content

Commit

Permalink
Remove high-level dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Sep 18, 2024
1 parent 12ab014 commit bef7991
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
27 changes: 18 additions & 9 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -574,16 +574,28 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
return C
end

"""
y = gemv(transa, alpha, A, x, index, algo)
Perform a product between a `CuSparseMatrix` and a `CuSparseVector`, returning a `CuSparseVector`.
This function should only be used for highly sparse matrices and vectors, as the result is expected
to have many non-zeros in practice.
For this reason, high-level functions like `mul!` and `*` internally convert the sparse vector into a
dense vector to use a more efficient CUSPARSE routine.
Supported formats for the sparse matrix are `CuSparseMatrixCSC` and `CuSparseMatrixCSR`.
"""
function gemv end

function gemv(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
x::CuSparseVector{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
m, n = size(A)
p = length(x)
p == n || throw(DimensionMismatch("dimensions must match: x has length $p, A has length $m × $n"))
# we model x as a CuSparseMatrixCSC with one column.
rowPtrB = CuVector{Int32}([1; nnz(x)+1])
B = CuSparseMatrixCSC(rowPtrB, x.iPtr, nonzeros(x), (n,1))
B = CuSparseMatrixCSC(x)
C = gemm(transa, 'N', alpha, A, B, index, algo)
y = CuSparseVector(C.rowVal, C.nzVal, m)
y = CuSparseVector(C)
return y
end

Expand All @@ -593,12 +605,9 @@ function gemv(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
p = length(x)
p == n || throw(DimensionMismatch("dimensions must match: x has length $p, A has length $m × $n"))
# we model x as a CuSparseMatrixCSR with one column.
rowPtrB = CuVector{Int32}([1; nnz(x)+1])
Btmp = CuSparseMatrixCSC(rowPtrB, x.iPtr, nonzeros(x), (n,1))
B = CuSparseMatrixCSR(Btmp)
Ctmp = gemm(transa, 'N', alpha, A, B, index, algo)
C = CuSparseMatrixCSC(Ctmp)
y = CuSparseVector(C.rowVal, C.nzVal, m)
B = CuSparseMatrixCSR(x)
C = gemm(transa, 'N', alpha, A, B, index, algo)
y = CuSparseVector(C)
return y
end

Expand Down
6 changes: 0 additions & 6 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,6 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
end
end

for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
@eval function LinearAlgebra.:(*)(A::$SparseMatrixType{T}, b::CuSparseVector{T}) where {T <: BlasFloat}
gemv('N', one(T), A, b, 'O')
end
end

function LinearAlgebra.:(*)(A::CuSparseMatrixCOO{T}, B::CuSparseMatrixCOO{T}) where {T <: BlasFloat}
A_csr = CuSparseMatrixCSR(A)
B_csr = CuSparseMatrixCSR(B)
Expand Down
13 changes: 13 additions & 0 deletions test/libraries/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,19 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
end
end
end

@testset "gemv $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
for (transa, opa) in [('N', identity)]
A = sprand(T,25,10,0.2)
b = sprand(T,10,0.3)
dA = SparseMatrixType(A)
db = CuSparseVector(b)
alpha = rand(T)
y = alpha * opa(A) * b
dy = gemv(transa, alpha, dA, db, 'O', algo)
@test collect(dy) y
end
end
end

if CUSPARSE.version() >= v"11.4.1"
Expand Down

0 comments on commit bef7991

Please sign in to comment.