Skip to content

Commit

Permalink
Rework matrix-matrix and matrix-vector multiplication.
Browse files Browse the repository at this point in the history
- wrap cusparseSpMM
- make the cusparseSpMV wrapper support CSC
- switch to 5-arg mul
  • Loading branch information
maleadt committed Sep 3, 2020
1 parent 7c1e847 commit e113713
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 276 deletions.
165 changes: 125 additions & 40 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
mutable struct CuDenseVectorDescriptor
handle::cusparseDnVecDescr_t

function CuDenseVectorDescriptor(v::CuVector{T}) where {T}
vec_ref = Ref{cusparseDnVecDescr_t}()
cusparseCreateDnVec(vec_ref, length(v), v, T)
obj = new(vec_ref[])
function CuDenseVectorDescriptor(x::CuVector)
desc_ref = Ref{cusparseDnVecDescr_t}()
cusparseCreateDnVec(desc_ref, length(x), x, eltype(x))
obj = new(desc_ref[])
finalizer(cusparseDestroyDnVec, obj)
obj
end
Expand All @@ -17,61 +17,146 @@ end
Base.unsafe_convert(::Type{cusparseDnVecDescr_t}, desc::CuDenseVectorDescriptor) = desc.handle


## dense matrix descriptor

mutable struct CuDenseMatrixDescriptor
handle::cusparseDnMatDescr_t

function CuDenseMatrixDescriptor(x::CuMatrix)
desc_ref = Ref{cusparseDnMatDescr_t}()
cusparseCreateDnMat(desc_ref, size(x)..., stride(x,2), x, eltype(x), CUSPARSE_ORDER_COL)
obj = new(desc_ref[])
finalizer(cusparseDestroyDnMat, obj)
obj
end
end

Base.unsafe_convert(::Type{cusparseDnMatDescr_t}, desc::CuDenseMatrixDescriptor) = desc.handle


## sparse matrix descriptor

mutable struct CuSparseMatrixDescriptor
handle::cusparseSpMatDescr_t
end

function CuSparseMatrixDescriptor(A::CuSparseMatrixCSR{T}) where {T}
desc_ref = Ref{cusparseSpMatDescr_t}()
cusparseCreateCsr(
desc_ref,
A.dims..., length(A.nzVal),
A.rowPtr, A.colVal, A.nzVal,
eltype(A.rowPtr), eltype(A.colVal), 'O', eltype(A.nzVal)
)
obj = CuSparseMatrixDescriptor(desc_ref[])
finalizer(cusparseDestroySpMat, obj)
return obj
function CuSparseMatrixDescriptor(A::CuSparseMatrixCSR)
desc_ref = Ref{cusparseSpMatDescr_t}()
cusparseCreateCsr(
desc_ref,
A.dims..., length(A.nzVal),
A.rowPtr, A.colVal, A.nzVal,
eltype(A.rowPtr), eltype(A.colVal), 'O', eltype(A.nzVal)
)
obj = new(desc_ref[])
finalizer(cusparseDestroySpMat, obj)
return obj
end

function CuSparseMatrixDescriptor(A::CuSparseMatrixCSC)
desc_ref = Ref{cusparseSpMatDescr_t}()
cusparseCreateCsr(
desc_ref,
reverse(A.dims)..., length(A.nzVal),
A.colPtr, A.rowVal, A.nzVal,
eltype(A.colPtr), eltype(A.rowVal), 'O', eltype(A.nzVal)
)
obj = new(desc_ref[])
finalizer(cusparseDestroySpMat, obj)
return obj
end
end

Base.unsafe_convert(::Type{cusparseSpMatDescr_t}, desc::CuSparseMatrixDescriptor) = desc.handle


## SpMV

function mv!(
transa::SparseChar,
alpha::T,
A::CuSparseMatrixCSR{T},
X::CuVector{T},
beta::T,
Y::CuVector{T},
index::SparseChar
) where {T}
## API functions

function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{T},CuSparseMatrixCSR{T}}, X::CuVector{T},
beta::Number, Y::CuVector{T}, index::SparseChar) where {T}
m,n = size(A)

if transa == 'N'
chkmvdims(X,n,Y,m)
elseif transa == 'T' || transa == 'C'
chkmvdims(X,m,Y,n)
end

cusparseSpMV(handle(), transa, T[alpha], CuSparseMatrixDescriptor(A),
CuDenseVectorDescriptor(X), T[beta], CuDenseVectorDescriptor(Y), T,
CUSPARSE_MV_ALG_DEFAULT, CU_NULL)

Y
end

function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, X::CuVector{T},
beta::Number, Y::CuVector{T}, index::SparseChar) where {T}
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
end
if transa == 'T' || transa == 'C'
# TODO: conjugate transpose?

n,m = size(A)

if ctransa == 'N'
chkmvdims(X,n,Y,m)
elseif ctransa == 'T' || ctransa == 'C'
chkmvdims(X,m,Y,n)
end

cusparseSpMV(
handle(),
transa,
[alpha],
CuSparseMatrixDescriptor(A),
CuDenseVectorDescriptor(X),
[beta],
CuDenseVectorDescriptor(Y),
T,
CUSPARSE_MV_ALG_DEFAULT,
CU_NULL
)
cusparseSpMV(handle(), ctransa, T[alpha], CuSparseMatrixDescriptor(A),
CuDenseVectorDescriptor(X), T[beta], CuDenseVectorDescriptor(Y), T,
CUSPARSE_MV_ALG_DEFAULT, CU_NULL)

Y
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
B::CuMatrix{T}, beta::Number, C::CuMatrix{T}, index::SparseChar) where {T}
m,k = size(A)
n = size(C)[2]

if transa == 'N' && transb == 'N'
chkmmdims(B,C,k,n,m,n)
elseif transa == 'N' && transb != 'N'
chkmmdims(B,C,n,k,m,n)
elseif transa != 'N' && transb == 'N'
chkmmdims(B,C,m,n,k,n)
elseif transa != 'N' && transb != 'N'
chkmmdims(B,C,n,m,k,n)
end

cusparseSpMM(handle(), transa, transb, T[alpha], CuSparseMatrixDescriptor(A),
CuDenseMatrixDescriptor(B), T[beta], CuDenseMatrixDescriptor(C), T,
CUSPARSE_MM_ALG_DEFAULT, CU_NULL)

C
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
B::CuMatrix{T}, beta::Number, C::CuMatrix{T}, index::SparseChar) where {T}
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
end
# TODO: conjugate transpose?

k,m = size(A)
n = size(C)[2]

if ctransa == 'N' && transb == 'N'
chkmmdims(B,C,k,n,m,n)
elseif ctransa == 'N' && transb != 'N'
chkmmdims(B,C,n,k,m,n)
elseif ctransa != 'N' && transb == 'N'
chkmmdims(B,C,m,n,k,n)
elseif ctransa != 'N' && transb != 'N'
chkmmdims(B,C,n,m,k,n)
end

cusparseSpMM(handle(), ctransa, transb, T[alpha], CuSparseMatrixDescriptor(A),
CuDenseMatrixDescriptor(B), T[beta], CuDenseMatrixDescriptor(C), T,
CUSPARSE_MM_ALG_DEFAULT, CU_NULL)

C
end
28 changes: 14 additions & 14 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}}, B::CuMatrix{T}) where {T<
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::CuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')

LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector) where {T} = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T}) where T = mv!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T}) where {T} = mv!('C',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::CuVector,alpha::Number,beta::Number) where {T} = mv!('N',alpha,A,B,beta,C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::CuVector,alpha::Number,beta::Number) where {T} = mv!('T',alpha,parent(transA),B,beta,C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::CuVector,alpha::Number,beta::Number) where {T} = mv!('C',alpha,parent(adjA),B,beta,C,'O')
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::CuVector{T},alpha::Number,beta::Number) where T = mv!('N',alpha,A,B,beta,C,'O')
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T},alpha::Number,beta::Number) where {T} = mv!('T',alpha,parent(transA),B,beta,C,'O')
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::CuVector{T},alpha::Number,beta::Number) where {T} = mv!('C',alpha,parent(adjA),B,beta,C,'O')

LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::CuMatrix{T}) where {T} = mm2!('N','N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('N','T',one(T),A,parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('T','N',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:CuMatrix{T}}) where {T} = mm2!('T','T',one(T),parent(transA),parent(transB),zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T}) where {T} = mm2!('C','N',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm!('N','N',alpha,A,B,beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:CuMatrix{T}},alpha::Number,beta::Number) where {T} = mm!('N','T',alpha,A,parent(transB),beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm!('T','N',alpha,parent(transA),B,beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:CuMatrix{T}},alpha::Number,beta::Number) where {T} = mm!('T','T',alpha,parent(transA),parent(transB),beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::CuMatrix{T},alpha::Number,beta::Number) where {T} = mm!('C','N',alpha,parent(adjA),B,beta,C,'O')

LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::CuMatrix) where {T} = mm!('N',one(T),A,B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix) where {T} = mm!('T',one(T),parent(transA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix) where {T} = mm!('C',one(T),parent(adjA),B,zero(T),C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm!('N',alpha,A,B,beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm!('T',alpha,parent(transA),B,beta,C,'O')
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::CuMatrix,alpha::Number,beta::Number) where {T} = mm!('C',alpha,parent(adjA),B,beta,C,'O')

Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::CuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}},B::CuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Expand Down
8 changes: 4 additions & 4 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
(:cusparseZbsrmv, :ComplexF64))
@eval begin
function mv!(transa::SparseChar,
alpha::$elty,
alpha::Number,
A::CuSparseMatrixBSR{$elty},
X::CuVector{$elty},
beta::$elty,
beta::Number,
Y::CuVector{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
Expand All @@ -36,8 +36,8 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
chkmvdims(X,m,Y,n)
end
$fname(handle(), A.dir, transa, mb, nb,
A.nnz, [alpha], desc, A.nzVal, A.rowPtr,
A.colVal, A.blockDim, X, [beta], Y)
A.nnz, $elty[alpha], desc, A.nzVal, A.rowPtr,
A.colVal, A.blockDim, X, $elty[beta], Y)
Y
end
end
Expand Down
Loading

0 comments on commit e113713

Please sign in to comment.