Skip to content

Commit

Permalink
Rework matrix-matrix and matrix-vector multiplication.
Browse files Browse the repository at this point in the history
- wrap cusparseSpMM
- make the cusparseSpMV wrapper support CSC
- switch to 5-arg `mul!`
  • Loading branch information
maleadt committed Sep 4, 2020
1 parent 158bdf6 commit e39398f
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 205 deletions.
103 changes: 79 additions & 24 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,39 +69,94 @@ end
Base.unsafe_convert(::Type{cusparseSpMatDescr_t}, desc::CuSparseMatrixDescriptor) = desc.handle


## SpMV

function mv!(
transa::SparseChar,
alpha::T,
A::CuSparseMatrixCSR{T},
X::CuVector{T},
beta::T,
Y::CuVector{T},
index::SparseChar
) where {T}
## API functions

"""
    mv!(transa, alpha, A, X, beta, Y, index)

Sparse matrix-vector product for BSR/CSR matrices through `cusparseSpMV`; writes into
`Y` and returns it.

`transa` is forwarded unchanged to the library. `alpha` and `beta` may be any `Number`
and are converted to the element type `T` by wrapping them in a one-element `T[...]`
array. `index` is accepted but not used in this method.
"""
function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{T},CuSparseMatrixCSR{T}}, X::CuVector{T},
             beta::Number, Y::CuVector{T}, index::SparseChar) where {T}
    nrows, ncols = size(A)

    # Expected lengths handed to chkmvdims depend on whether A is used as-is or
    # transposed; any other `transa` value skips the check.
    if transa == 'N'
        chkmvdims(X, ncols, Y, nrows)
    elseif transa == 'T' || transa == 'C'
        chkmvdims(X, nrows, Y, ncols)
    end

    cusparseSpMV(
        handle(), transa, T[alpha], CuSparseMatrixDescriptor(A),
        CuDenseVectorDescriptor(X), T[beta], CuDenseVectorDescriptor(Y), T,
        CUSPARSE_MV_ALG_DEFAULT, CU_NULL,
    )

    return Y
end

# mv! for CSC input. The operation handed to cuSPARSE (`ctransa`) is flipped relative to
# the requested `transa`: 'N' becomes 'T', everything else becomes 'N'.
# NOTE(review): this span appears to interleave removed and added diff lines (an
# unterminated `if transa == 'T' || transa == 'C'` and two cusparseSpMV calls, one using
# `transa`/`[alpha]` and one using `ctransa`/`T[alpha]`) — confirm against the real file.
function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, X::CuVector{T},
beta::Number, Y::CuVector{T}, index::SparseChar) where {T}
# default to the non-transposed operation; only a requested 'N' flips it to 'T'
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
end
if transa == 'T' || transa == 'C'
# TODO: conjugate transpose?

# size components are bound in the opposite order from the CSR method (`m,n = size(A)`)
n,m = size(A)

# dimension check is driven by the flipped operation, not by `transa`
if ctransa == 'N'
chkmvdims(X,n,Y,m)
elseif ctransa == 'T' || ctransa == 'C'
chkmvdims(X,m,Y,n)
end

cusparseSpMV(
handle(),
transa,
[alpha],
CuSparseMatrixDescriptor(A),
CuDenseVectorDescriptor(X),
[beta],
CuDenseVectorDescriptor(Y),
T,
CUSPARSE_MV_ALG_DEFAULT,
CU_NULL
)
# scalars are converted to the element type T via a one-element T[...] array
cusparseSpMV(handle(), ctransa, T[alpha], CuSparseMatrixDescriptor(A),
CuDenseVectorDescriptor(X), T[beta], CuDenseVectorDescriptor(Y), T,
CUSPARSE_MV_ALG_DEFAULT, CU_NULL)

Y
end

"""
    mm!(transa, transb, alpha, A::CuSparseMatrixCSR, B, beta, C, index)

Sparse-dense matrix product through `cusparseSpMM`; writes into `C` and returns it.

`transa` and `transb` are forwarded unchanged to the library; for the dimension check any
value other than `'N'` is treated as a transposed operand. `alpha` and `beta` are
converted to the element type `T`. `index` is accepted but not used in this method.
"""
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
             B::CuMatrix{T}, beta::Number, C::CuMatrix{T}, index::SparseChar) where {T}
    m, k = size(A)
    n = size(C, 2)

    # Pick the expected shapes of B and C for the four transa/transb combinations.
    if transa == 'N'
        if transb == 'N'
            chkmmdims(B, C, k, n, m, n)
        else
            chkmmdims(B, C, n, k, m, n)
        end
    else
        if transb == 'N'
            chkmmdims(B, C, m, n, k, n)
        else
            chkmmdims(B, C, n, m, k, n)
        end
    end

    cusparseSpMM(
        handle(), transa, transb, T[alpha], CuSparseMatrixDescriptor(A),
        CuDenseMatrixDescriptor(B), T[beta], CuDenseMatrixDescriptor(C), T,
        CUSPARSE_MM_ALG_DEFAULT, CU_NULL,
    )

    return C
end

"""
    mm!(transa, transb, alpha, A::CuSparseMatrixCSC, B, beta, C, index)

Sparse-dense matrix product for CSC input through `cusparseSpMM`; writes into `C` and
returns it.

The operation passed to the library for `A` is flipped relative to `transa`: `'N'`
becomes `'T'`, and every other value becomes `'N'`. `alpha` and `beta` are converted to
the element type `T`. `index` is accepted but not used in this method.
"""
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
             B::CuMatrix{T}, beta::Number, C::CuMatrix{T}, index::SparseChar) where {T}
    # TODO: conjugate transpose?
    ctransa = transa == 'N' ? 'T' : 'N'

    # Size components are bound in the opposite order from the CSR method.
    k, m = size(A)
    n = size(C, 2)

    # The dimension check follows the flipped operation, not the requested one.
    if ctransa == 'N'
        if transb == 'N'
            chkmmdims(B, C, k, n, m, n)
        else
            chkmmdims(B, C, n, k, m, n)
        end
    else
        if transb == 'N'
            chkmmdims(B, C, m, n, k, n)
        else
            chkmmdims(B, C, n, m, k, n)
        end
    end

    cusparseSpMM(
        handle(), ctransa, transb, T[alpha], CuSparseMatrixDescriptor(A),
        CuDenseMatrixDescriptor(B), T[beta], CuDenseMatrixDescriptor(C), T,
        CUSPARSE_MM_ALG_DEFAULT, CU_NULL,
    )

    return C
end
50 changes: 32 additions & 18 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,43 @@
using LinearAlgebra
using LinearAlgebra: BlasFloat

Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('N',A,B,'O')
"""
    mv_wrapper(transa, alpha, A, X, beta, Y)

Helper used by the sparse matrix-vector `mul!` methods: forwards every argument to `mv!`
and supplies `'O'` as the trailing index argument.
"""
function mv_wrapper(transa::SparseChar, alpha::Number, A::CuSparseMatrix{T}, X::CuVector{T},
                    beta::Number, Y::CuVector{T}) where {T}
    return mv!(transa, alpha, A, X, beta, Y, 'O')
end

# 5-argument mul! for sparse matrix * dense vector. Each method maps the wrapper type of
# the first operand to a transpose flag ('N' plain, 'T' Transpose, 'C' Adjoint), unwraps
# Transpose/Adjoint with `parent`, and forwards to mv_wrapper.
# NOTE(review): the HermOrSym methods pass a HermOrSym-wrapped (or `parent` of a
# Transpose/Adjoint of HermOrSym) argument to mv_wrapper, whose signature in this file
# takes A::CuSparseMatrix{T} — confirm that a matching method exists.
LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T},alpha::Number,beta::Number) where T = mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C)

"""
    mm_wrapper(transa, transb, alpha, A, B, beta, C)

Helper used by the sparse-dense `mul!` methods. Forwards to `mm2!` when `version()` is
older than 10.3.1 and `A` is a `CuSparseMatrixCSR`, and to the generic `mm!` otherwise.
Both calls receive `'O'` as the trailing index argument.
"""
function mm_wrapper(transa::SparseChar, transb::SparseChar, alpha::Number,
                    A::CuSparseMatrix{T}, B::CuMatrix{T}, beta::Number, C::CuMatrix{T}) where {T}
    # NOTE(review): an earlier comment here said the generic mm! does not work on
    # CUDA 10.1 with CSC matrices, yet the condition selects CSR — confirm which storage
    # format is meant.
    use_legacy = version() < v"10.3.1" && A isa CuSparseMatrixCSR
    if use_legacy
        return mm2!(transa, transb, alpha, A, B, beta, C, 'O')
    end
    return mm!(transa, transb, alpha, A, B, beta, C, 'O')
end

# 5-argument mul! for sparse matrix * dense matrix. The wrapper type of each operand
# selects its transpose flag ('N' plain, 'T' Transpose, 'C' Adjoint); wrapped operands are
# unwrapped with `parent` before being forwarded to mm_wrapper.
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('N','N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:CuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('N','T',alpha,A,parent(transB),beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('T','N',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:CuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('T','T',alpha,parent(transA),parent(transB),beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('C','N',alpha,parent(adjA),B,beta,C)

# 5-argument mul! for Hermitian/Symmetric-wrapped sparse matrices times a dense matrix.
# NOTE(review): these call mm_wrapper with six arguments (a single transpose flag), while
# the mm_wrapper method defined in this file takes seven (transa, transb, ...) and
# A::CuSparseMatrix{T} — confirm that a matching method exists, otherwise these calls
# would not dispatch.
LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('C',alpha,parent(adjA),B,beta,C)

# Left division by a triangular-wrapped sparse matrix with a dense matrix right-hand side.
# Plain, Transpose and Adjoint wrappers map to the flags 'N', 'T' and 'C'; Transpose and
# Adjoint are unwrapped with `parent`, and every method forwards to `sm` with index 'O'.
Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}}, B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O')
Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}}, B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O')
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')

# 3-argument mul! methods: each one fixes alpha = one(T) and beta = zero(T) and calls the
# low-level routine directly with index 'O'.

# sparse matrix * dense vector -> mv!
LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')

# sparse matrix * dense matrix -> mm2!
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::CuMatrix{T}) where {T} = mm2!('N','N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('N','T',one(T),A,parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('T','N',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('T','T',one(T),parent(transA),parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('C','N',one(T),parent(adjA),B,zero(T),C,'O')

# Hermitian/Symmetric-wrapped sparse matrix * dense matrix -> mm! (single transpose flag)
LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::CuMatrix) where {T} = mm!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix) where {T} = mm!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix) where {T} = mm!('C',one(T),parent(adjA),B,zero(T),C,'O')

# Left division by a triangular-wrapped sparse matrix with a dense vector right-hand side;
# forwards to `sv2` with the matching transpose flag and index 'O'.
Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::CuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}},B::CuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}},B::CuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Expand Down
8 changes: 4 additions & 4 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
(:cusparseZbsrmv, :ComplexF64))
@eval begin
function mv!(transa::SparseChar,
alpha::$elty,
alpha::Number,
A::CuSparseMatrixBSR{$elty},
X::CuVector{$elty},
beta::$elty,
beta::Number,
Y::CuVector{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
Expand All @@ -36,8 +36,8 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
chkmvdims(X,m,Y,n)
end
$fname(handle(), A.dir, transa, mb, nb,
A.nnz, [alpha], desc, A.nzVal, A.rowPtr,
A.colVal, A.blockDim, X, [beta], Y)
A.nnz, $elty[alpha], desc, A.nzVal, A.rowPtr,
A.colVal, A.blockDim, X, $elty[beta], Y)
Y
end
end
Expand Down
99 changes: 19 additions & 80 deletions lib/cusparse/level3.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
# sparse linear algebra functions that perform operations between sparse and (usually tall)
# dense matrices

export mm2!, mm2

"""
mm2!(transa::SparseChar, transb::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, B::CuMatrix, beta::BlasFloat, C::CuMatrix, index::SparseChar)
Multiply the sparse matrix `A` by the dense matrix `B`, filling in dense matrix `C`.
`C = alpha*op(A)*op(B) + beta*C`. `op(A)` can be nothing (`transa = N`), transpose
(`transa = T`), or conjugate transpose (`transa = C`), and similarly for `op(B)` and
`transb`.
"""
mm2!(transa::SparseChar, transb::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, B::CuMatrix, beta::BlasFloat, C::CuMatrix, index::SparseChar)
# bsrmm
for (fname,elty) in ((:cusparseSbsrmm, :Float32),
(:cusparseDbsrmm, :Float64),
(:cusparseCbsrmm, :ComplexF32),
(:cusparseZbsrmm, :ComplexF64))
@eval begin
function mm2!(transa::SparseChar,
transb::SparseChar,
alpha::$elty,
A::CuSparseMatrixBSR{$elty},
B::CuMatrix{$elty},
beta::$elty,
C::CuMatrix{$elty},
index::SparseChar)
function mm!(transa::SparseChar,
transb::SparseChar,
alpha::Number,
A::CuSparseMatrixBSR{$elty},
B::CuMatrix{$elty},
beta::Number,
C::CuMatrix{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
m,k = A.dims
mb = div(m,A.blockDim)
Expand All @@ -43,8 +33,8 @@ for (fname,elty) in ((:cusparseSbsrmm, :Float32),
ldc = max(1,stride(C,2))
$fname(handle(), A.dir,
transa, transb, mb, n, kb, A.nnz,
[alpha], desc, A.nzVal,A.rowPtr, A.colVal,
A.blockDim, B, ldb, [beta], C, ldc)
$elty[alpha], desc, A.nzVal,A.rowPtr, A.colVal,
A.blockDim, B, ldb, $elty[beta], C, ldc)
C
end
end
Expand All @@ -57,10 +47,10 @@ for (fname,elty) in ((:cusparseScsrmm2, :Float32),
@eval begin
function mm2!(transa::SparseChar,
transb::SparseChar,
alpha::$elty,
alpha::Number,
A::CuSparseMatrixCSR{$elty},
B::CuMatrix{$elty},
beta::$elty,
beta::Number,
C::CuMatrix{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
Expand All @@ -78,16 +68,16 @@ for (fname,elty) in ((:cusparseScsrmm2, :Float32),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
$fname(handle(),
transa, transb, m, n, k, A.nnz, [alpha], desc,
A.nzVal, A.rowPtr, A.colVal, B, ldb, [beta], C, ldc)
transa, transb, m, n, k, A.nnz, $elty[alpha], desc,
A.nzVal, A.rowPtr, A.colVal, B, ldb, $elty[beta], C, ldc)
C
end
function mm2!(transa::SparseChar,
transb::SparseChar,
alpha::$elty,
alpha::Number,
A::CuSparseMatrixCSC{$elty},
B::CuMatrix{$elty},
beta::$elty,
beta::Number,
C::CuMatrix{$elty},
index::SparseChar)
ctransa = 'N'
Expand All @@ -109,64 +99,13 @@ for (fname,elty) in ((:cusparseScsrmm2, :Float32),
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
$fname(handle(),
ctransa, transb, m, n, k, A.nnz, [alpha], desc,
A.nzVal, A.colPtr, A.rowVal, B, ldb, [beta], C, ldc)
ctransa, transb, m, n, k, A.nnz, $elty[alpha], desc,
A.nzVal, A.colPtr, A.rowVal, B, ldb, $elty[beta], C, ldc)
C
end
end
end

# Non-mutating `mm2` convenience methods, generated once per supported element type.
# Each method accepts a CSR, CSC or BSR sparse matrix and ends up in `mm2!`.
for elty in (:Float32,:Float64,:ComplexF32,:ComplexF64)
@eval begin
# Full form: runs mm2! on a copy of C, so the caller's C is not passed to mm2!.
function mm2(transa::SparseChar,
transb::SparseChar,
alpha::$elty,
A::Union{CuSparseMatrixCSR{$elty},CuSparseMatrixCSC{$elty},CuSparseMatrixBSR{$elty}},
B::CuMatrix{$elty},
beta::$elty,
C::CuMatrix{$elty},
index::SparseChar)
mm2!(transa,transb,alpha,A,B,beta,copy(C),index)
end
# alpha omitted: defaults to one.
function mm2(transa::SparseChar,
transb::SparseChar,
A::Union{CuSparseMatrixCSR{$elty},CuSparseMatrixCSC{$elty},CuSparseMatrixBSR{$elty}},
B::CuMatrix{$elty},
beta::$elty,
C::CuMatrix{$elty},
index::SparseChar)
mm2(transa,transb,one($elty),A,B,beta,C,index)
end
# alpha and beta omitted: both default to one.
function mm2(transa::SparseChar,
transb::SparseChar,
A::Union{CuSparseMatrixCSR{$elty},CuSparseMatrixCSC{$elty},CuSparseMatrixBSR{$elty}},
B::CuMatrix{$elty},
C::CuMatrix{$elty},
index::SparseChar)
mm2(transa,transb,one($elty),A,B,one($elty),C,index)
end
# beta and C omitted: beta is zero and C is a freshly allocated zero matrix whose
# shape (m,n) is taken from A and B according to the transpose flags.
function mm2(transa::SparseChar,
transb::SparseChar,
alpha::$elty,
A::Union{CuSparseMatrixCSR{$elty},CuSparseMatrixCSC{$elty},CuSparseMatrixBSR{$elty}},
B::CuMatrix{$elty},
index::SparseChar)
m = transa == 'N' ? size(A)[1] : size(A)[2]
n = transb == 'N' ? size(B)[2] : size(B)[1]
mm2(transa,transb,alpha,A,B,zero($elty),CUDA.zeros($elty,(m,n)),index)
end
# alpha, beta and C omitted: alpha is one, beta is zero, C is a fresh zero matrix.
function mm2(transa::SparseChar,
transb::SparseChar,
A::Union{CuSparseMatrixCSR{$elty},CuSparseMatrixCSC{$elty},CuSparseMatrixBSR{$elty}},
B::CuMatrix{$elty},
index::SparseChar)
m = transa == 'N' ? size(A)[1] : size(A)[2]
n = transb == 'N' ? size(B)[2] : size(B)[1]
mm2(transa,transb,one($elty),A,B,zero($elty),CUDA.zeros($elty,(m,n)),index)
end
end
end

# bsrsm2
for (bname,aname,sname,elty) in ((:cusparseSbsrsm2_bufferSize, :cusparseSbsrsm2_analysis, :cusparseSbsrsm2_solve, :Float32),
(:cusparseDbsrsm2_bufferSize, :cusparseDbsrsm2_analysis, :cusparseDbsrsm2_solve, :Float64),
Expand Down
Loading

0 comments on commit e39398f

Please sign in to comment.