diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index 25ddb3222f6b2..1ed2c680059ce 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -4,8 +4,10 @@ import LinearAlgebra: checksquare, sym_uplo using Random: rand! # In matrix-vector multiplication, the correct orientation of the vector is assumed. -const StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}} -const AdjOrTransStridedOrTriangularMatrix{T} = Union{StridedOrTriangularMatrix{T},Adjoint{<:Any,<:StridedOrTriangularMatrix{T}},Transpose{<:Any,<:StridedOrTriangularMatrix{T}}} +const DenseMatrixUnion = Union{StridedMatrix, LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular, BitMatrix} +const AdjOrTransDenseMatrix = Union{DenseMatrixUnion,Adjoint{<:Any,<:DenseMatrixUnion},Transpose{<:Any,<:DenseMatrixUnion}} +const DenseInputVector = Union{StridedVector, BitVector} +const DenseInputVecOrMat = Union{AdjOrTransDenseMatrix, DenseInputVector} for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric) @eval begin @@ -25,7 +27,7 @@ for op ∈ (:+, :-) end end -function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number) +function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::DenseInputVecOrMat, α::Number, β::Number) size(A, 2) == size(B, 1) || throw(DimensionMismatch()) size(A, 1) == size(C, 1) || throw(DimensionMismatch()) size(B, 2) == size(C, 2) || throw(DimensionMismatch()) @@ -44,13 +46,14 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe end C end -*(A::SparseMatrixCSCUnion{TA}, x::StridedVector{Tx}) where {TA,Tx} = - (T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(A, 1)), A, x, true, false)) -*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransStridedOrTriangularMatrix{Tx}) where {TA,Tx} = - (T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false)) + +*(A::SparseMatrixCSCUnion{TA}, x::DenseInputVector) where {TA} = + (T = promote_op(matprod, TA, eltype(x)); mul!(similar(x, T, size(A, 1)), A, x, true, false)) +*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransDenseMatrix) where {TA} = + (T = promote_op(matprod, TA, eltype(B)); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false)) for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) - @eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number) + @eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::DenseInputVecOrMat, α::Number, β::Number) A = xA.parent size(A, 2) == size(C, 1) || throw(DimensionMismatch()) size(A, 1) == size(B, 1) || throw(DimensionMismatch()) @@ -72,16 +75,16 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) C end end -*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} = - (T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false)) -*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) = +*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::DenseInputVector) = + (T = promote_op(matprod, eltype(adjA), eltype(x)); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false)) +*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransDenseMatrix) = (T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false)) -*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} = - (T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false)) -*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) = +*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::DenseInputVector) = + (T = promote_op(matprod, eltype(transA), eltype(x)); mul!(similar(x, T, size(transA, 1)), transA, x, true, false)) +*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransDenseMatrix) = (T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false)) -function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number) +function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number) mX, nX = size(X) nX == size(A, 1) || throw(DimensionMismatch()) mX == size(C, 1) || throw(DimensionMismatch()) @@ -91,7 +94,7 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs if β != 1 β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) end - if X isa StridedOrTriangularMatrix + if X isa DenseMatrixUnion @inbounds for col in 1:size(A, 2), k in nzrange(A, col) Aiα = nzv[k] * α rvk = rv[k] @@ -108,11 +111,11 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs end C end -*(X::AdjOrTransStridedOrTriangularMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} = +*(X::AdjOrTransDenseMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} = (T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false)) for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) - @eval function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number) + @eval function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number) A = xA.parent mX, nX = size(X) nX == size(A, 2) || throw(DimensionMismatch()) @@ -133,9 +136,9 @@ for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) C end end -*(X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = +*(X::AdjOrTransDenseMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = (T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false)) -*(X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = +*(X::AdjOrTransDenseMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = (T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false)) function (*)(D::Diagonal, A::AbstractSparseMatrixCSC)