Skip to content

Commit

Permalink
use sparse multiplication for sparse arrays time BitMatrix and BitVec…
Browse files Browse the repository at this point in the history
…tor and wrappers (JuliaLang#39557)
  • Loading branch information
abraunst authored and ElOceanografo committed May 4, 2021
1 parent 54dedc1 commit 36c38ec
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import LinearAlgebra: checksquare, sym_uplo
using Random: rand!

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
const StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}}
const AdjOrTransStridedOrTriangularMatrix{T} = Union{StridedOrTriangularMatrix{T},Adjoint{<:Any,<:StridedOrTriangularMatrix{T}},Transpose{<:Any,<:StridedOrTriangularMatrix{T}}}
const DenseMatrixUnion = Union{StridedMatrix, LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular, BitMatrix}
const AdjOrTransDenseMatrix = Union{DenseMatrixUnion,Adjoint{<:Any,<:DenseMatrixUnion},Transpose{<:Any,<:DenseMatrixUnion}}
const DenseInputVector = Union{StridedVector, BitVector}
const DenseInputVecOrMat = Union{AdjOrTransDenseMatrix, DenseInputVector}

for op (:+, :-), Wrapper (:Hermitian, :Symmetric)
@eval begin
Expand All @@ -25,7 +27,7 @@ for op ∈ (:+, :-)
end
end

function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::DenseInputVecOrMat, α::Number, β::Number)
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
Expand All @@ -44,13 +46,14 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe
end
C
end
*(A::SparseMatrixCSCUnion{TA}, x::StridedVector{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(A, 1)), A, x, true, false))
*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransStridedOrTriangularMatrix{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))

*(A::SparseMatrixCSCUnion{TA}, x::DenseInputVector) where {TA} =
(T = promote_op(matprod, TA, eltype(x)); mul!(similar(x, T, size(A, 1)), A, x, true, false))
*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransDenseMatrix) where {TA} =
(T = promote_op(matprod, TA, eltype(B)); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))

for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::DenseInputVecOrMat, α::Number, β::Number)
A = xA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
Expand All @@ -72,16 +75,16 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
C
end
end
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::DenseInputVector) =
(T = promote_op(matprod, eltype(adjA), eltype(x)); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransDenseMatrix) =
(T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::DenseInputVector) =
(T = promote_op(matprod, eltype(transA), eltype(x)); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransDenseMatrix) =
(T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
Expand All @@ -91,7 +94,7 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
if X isa StridedOrTriangularMatrix
if X isa DenseMatrixUnion
@inbounds for col in 1:size(A, 2), k in nzrange(A, col)
Aiα = nzv[k] * α
rvk = rv[k]
Expand All @@ -108,11 +111,11 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs
end
C
end
*(X::AdjOrTransStridedOrTriangularMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} =
*(X::AdjOrTransDenseMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} =
(T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false))

for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
@eval function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
@eval function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = xA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
Expand All @@ -133,9 +136,9 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
C
end
end
*(X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
*(X::AdjOrTransDenseMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false))
*(X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
*(X::AdjOrTransDenseMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false))

function (*)(D::Diagonal, A::AbstractSparseMatrixCSC)
Expand Down

0 comments on commit 36c38ec

Please sign in to comment.