Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractSparseMatrixCSR interface baseline #546

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blockdiag, droptol!, dropzeros!, dropzeros,
issparse, nonzeros, nzrange, rowvals, sparse, sparsevec, spdiagm,
sprand, sprandn, spzeros, nnz, permute, findnz, fkeep!, ftranspose!,
sparse_hcat, sparse_vcat, sparse_hvcat
sparse_hcat, sparse_vcat, sparse_hvcat, colvals

const LinAlgLeftQs = Union{HessenbergQ,QRCompactWYQ,QRPackedQ}

Expand Down
8 changes: 8 additions & 0 deletions src/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ Supertype for matrix with compressed sparse column (CSC).
"""
abstract type AbstractSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti} end

"""
AbstractSparseMatrixCSR{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}

Supertype for matrix with compressed sparse row (CSR).
"""
abstract type AbstractSparseMatrixCSR{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti} end

const AbstractSparseMatrixCSCOrCSR{Tv,Ti} = Union{AbstractSparseMatrixCSR{Tv,Ti}, AbstractSparseMatrixCSC{Tv,Ti}}

"""
issparse(S)
Expand Down
246 changes: 239 additions & 7 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
const DenseInputVector = Union{StridedVector, BitVector}
const DenseVecOrMat = Union{DenseMatrixUnion, DenseInputVector}

# CSC
matprod_dest(A::SparseMatrixCSCUnion2, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, B::DenseTriangular, TS) =
Expand All @@ -29,6 +30,24 @@
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))

# CSR
matprod_dest(A::SparseMatrixCSRUnion2, B::DenseTriangular, TS) =

Check warning on line 34 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L34

Added line #L34 was not covered by tests
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSRUnion2}, B::DenseTriangular, TS) =

Check warning on line 36 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L36

Added line #L36 was not covered by tests
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSRUnion2, TS) =

Check warning on line 38 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L38

Added line #L38 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSRUnion2, TS) =

Check warning on line 40 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L40

Added line #L40 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::SparseMatrixCSRUnion2, TS) =

Check warning on line 42 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L42

Added line #L42 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:SparseMatrixCSRUnion2}, TS) =

Check warning on line 44 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L44

Added line #L44 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:SparseMatrixCSRUnion2}, TS) =

Check warning on line 46 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L46

Added line #L46 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSRUnion2}, TS) =

Check warning on line 48 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L48

Added line #L48 was not covered by tests
similar(A, TS, (size(A, 1), size(B, 2)))

for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
@eval begin
$op(A::AbstractSparseMatrix, B::$Wrapper{<:Any,<:AbstractSparseMatrix}) = $op(A, sparse(B))
Expand All @@ -54,6 +73,13 @@
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, alpha::Number, beta::Number) =
spdensemul!(C, tA, 'N', A, B, alpha, beta)

generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSRUnion2, B::DenseMatrixUnion, alpha::Number, beta::Number) =

Check warning on line 76 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L76

Added line #L76 was not covered by tests
spdensemul!(C, tA, tB, A, B, alpha, beta)
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSRUnion2, B::AbstractTriangular, alpha::Number, beta::Number) =

Check warning on line 78 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L78

Added line #L78 was not covered by tests
spdensemul!(C, tA, tB, A, B, alpha, beta)
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSRUnion2, B::DenseInputVector, alpha::Number, beta::Number) =

Check warning on line 80 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L80

Added line #L80 was not covered by tests
spdensemul!(C, tA, 'N', A, B, alpha, beta)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
if tA_uc == 'N'
Expand All @@ -74,7 +100,7 @@
return C
end

function _spmatmul!(C, A, B, α, β)
function _spmatmul!(C, A::AbstractSparseMatrixCSC, B, α, β)
size(A, 2) == size(B, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
size(A, 1) == size(C, 1) ||
Expand All @@ -95,7 +121,28 @@
C
end

function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
function _spmatmul!(C, A::AbstractSparseMatrixCSR, B, α, β)
size(A, 2) == size(B, 1) ||

Check warning on line 125 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
size(A, 1) == size(C, 1) ||

Check warning on line 127 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L127

Added line #L127 was not covered by tests
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))"))
size(B, 2) == size(C, 2) ||

Check warning on line 129 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L129

Added line #L129 was not covered by tests
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
nzv = nonzeros(A)
cv = colvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for k in 1:size(C, 1)
@inbounds for row in 1:size(A, 1)
αxj = B[k, row] * α
for j in nzrange(A, row)
C[k, cv[j]] += nzv[j]*αxj
end
end
end
C

Check warning on line 142 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L131-L142

Added lines #L131 - L142 were not covered by tests
end

function _At_or_Ac_mul_B!(tfun::Function, C, A::AbstractSparseMatrixCSC, B, α, β)
size(A, 2) == size(C, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
size(A, 1) == size(B, 1) ||
Expand All @@ -117,6 +164,28 @@
C
end

function _At_or_Ac_mul_B!(tfun::Function, C, A::AbstractSparseMatrixCSR, B, α, β)
size(A, 2) == size(C, 1) ||

Check warning on line 168 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L167-L168

Added lines #L167 - L168 were not covered by tests
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
size(A, 1) == size(B, 1) ||

Check warning on line 170 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L170

Added line #L170 was not covered by tests
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))"))
size(B, 2) == size(C, 2) ||

Check warning on line 172 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L172

Added line #L172 was not covered by tests
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
nzv = nonzeros(A)
cv = colvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for k in 1:size(C, 1)
@inbounds for row in 1:size(A, 1)
tmp = zero(eltype(C))
for j in nzrange(A, row)
tmp += tfun(nzv[j])*B[k,cv[j]]
end
C[k,row] += tmp * α
end
end
C

Check warning on line 186 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L174-L186

Added lines #L174 - L186 were not covered by tests
end

Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
Expand Down Expand Up @@ -167,6 +236,45 @@
C
end

function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSRUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||

Check warning on line 241 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L239-L241

Added lines #L239 - L241 were not covered by tests
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||

Check warning on line 243 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L243

Added line #L243 was not covered by tests
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||

Check warning on line 245 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L245

Added line #L245 was not covered by tests
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
cv = colvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for row in 1:size(A, 1), k in nzrange(A, row)
Aiα = nzv[k] * α
cvk = cv[k]
@simd for multivec_col in 1:nX
C[row, multivec_col] += X[cvk, multivec_col] * Aiα
end
end
C

Check warning on line 257 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L247-L257

Added lines #L247 - L257 were not covered by tests
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSRUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||

Check warning on line 261 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L259-L261

Added lines #L259 - L261 were not covered by tests
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||

Check warning on line 263 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L263

Added line #L263 was not covered by tests
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||

Check warning on line 265 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L265

Added line #L265 was not covered by tests
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
cv = colvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for multivec_col in 1:nX, row in 1:size(A, 1)
@inbounds for k in nzrange(A, row)
C[row, multivec_col] += X[cv[k], multivec_col] * nzv[k] * α
end
end
C

Check warning on line 275 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L267-L275

Added lines #L267 - L275 were not covered by tests
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||
Expand All @@ -188,22 +296,53 @@
C
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSRUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||

Check warning on line 301 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L299-L301

Added lines #L299 - L301 were not covered by tests
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
mA == size(C, 1) ||

Check warning on line 303 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L303

Added line #L303 was not covered by tests
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))"))
size(B, 1) == size(C, 2) ||

Check warning on line 305 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L305

Added line #L305 was not covered by tests
throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
cv = colvals(B)
nzv = nonzeros(B)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for row in 1:size(B, 1), k in nzrange(B, row)
Biα = tfun(nzv[k]) * α
cvk = cv[k]
@simd for multivec_row in 1:nA
C[cvk, multivec_row] += A[row, multivec_row] * Biα
end
end
C

Check warning on line 317 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L307-L317

Added lines #L307 - L317 were not covered by tests
end

# Sparse matrix multiplication as described in [Gustavson, 1978]:
# http://dl.acm.org/citation.cfm?id=355796

const SparseTriangular{Tv,Ti} = Union{UpperTriangular{Tv,<:SparseMatrixCSCUnion{Tv,Ti}},LowerTriangular{Tv,<:SparseMatrixCSCUnion{Tv,Ti}}}
const SparseOrTri{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangular{Tv,Ti}}
const SparseTriangularCSC{Tv,Ti} = Union{UpperTriangular{Tv,<:SparseMatrixCSCUnion{Tv,Ti}},LowerTriangular{Tv,<:SparseMatrixCSCUnion{Tv,Ti}}}
const SparseTriangularCSR{Tv,Ti} = Union{UpperTriangular{Tv,<:SparseMatrixCSRUnion{Tv,Ti}},LowerTriangular{Tv,<:SparseMatrixCSRUnion{Tv,Ti}}}
const SparseTriangular{Tv,Ti} = Union{SparseTriangularCSC{Tv,Ti}, SparseTriangularCSR{Tv,Ti}}
const SparseOrTriCSC{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangularCSC{Tv,Ti}}
const SparseOrTriCSR{Tv,Ti} = Union{SparseMatrixCSRUnion{Tv,Ti},SparseTriangularCSR{Tv,Ti}}
const SparseOrTri{Tv,Ti} = Union{SparseOrTriCSC{Tv,Ti}, SparseOrTriCSR{Tv,Ti}}

*(A::SparseOrTri, B::AbstractSparseVector) = spmatmulv(A, B)
*(A::SparseOrTri, B::SparseColumnView) = spmatmulv(A, B)
*(A::SparseOrTri, B::SparseVectorView) = spmatmulv(A, B)
*(A::SparseMatrixCSCUnion, B::SparseMatrixCSCUnion) = spmatmul(A,B)
*(A::SparseMatrixCSRUnion, B::SparseMatrixCSRUnion) = spmatmul(A,B)
*(A::SparseTriangular, B::SparseMatrixCSCUnion) = spmatmul(A,B)
*(A::SparseTriangular, B::SparseMatrixCSRUnion) = spmatmul(A,B)

Check warning on line 336 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L336

Added line #L336 was not covered by tests
*(A::SparseMatrixCSCUnion, B::SparseTriangular) = spmatmul(A,B)
*(A::SparseMatrixCSRUnion, B::SparseTriangular) = spmatmul(A,B)

Check warning on line 338 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L338

Added line #L338 was not covered by tests
*(A::SparseTriangular, B::SparseTriangular) = spmatmul1(A,B)
*(A::SparseOrTri, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(A, copy(B))
*(A::SparseOrTri, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSR}) = spmatmul(A, copy(B))
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::SparseOrTri) = spmatmul(copy(A), B)
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSR}, B::SparseOrTri) = spmatmul(copy(A), B)
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSR}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSR}) = spmatmul(copy(A), copy(B))

Check warning on line 345 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L345

Added line #L345 was not covered by tests

# Gustavson's matrix multiplication algorithm revisited.
# The result rowval vector is already sorted by construction.
Expand All @@ -213,7 +352,7 @@
# done by a quicksort of the row indices or by a full scan of the dense result vector.
# The last is faster, if more than ≈ 1/32 of the result column is nonzero.
# TODO: extend to SparseMatrixCSCUnion to allow for SubArrays (view(X, :, r)).
function spmatmul(A::SparseOrTri, B::Union{SparseOrTri,AbstractCompressedVector,SubArray{<:Any,<:Any,<:AbstractSparseArray}})
function spmatmul(A::SparseOrTriCSC, B::Union{SparseOrTriCSC,AbstractCompressedVector,SubArray{<:Any,<:Any,<:AbstractSparseArray}})
Tv = promote_op(matprod, eltype(A), eltype(B))
Ti = promote_type(indtype(A), indtype(B))
mA, nA = size(A)
Expand Down Expand Up @@ -248,6 +387,41 @@
C = SparseMatrixCSC(mA, nB, colptrC, rowvalC, nzvalC)
return C
end
function spmatmul(A::MatrixType, B::Union{MatrixType,AbstractCompressedVector,SubArray{<:Any,<:Any,<:AbstractSparseArray}}) where MatrixType <: AbstractSparseMatrixCSR
Tv = promote_op(matprod, eltype(A), eltype(B))
Ti = promote_type(indtype(A), indtype(B))
mA, nA = size(A)
nB = size(B, 2)
mB = size(B, 1)
nA == mB || throw(DimensionMismatch("second dimension of A, $nA, does not match the first dimension of B, $mB"))

nnzC = min(estimate_mulsize(mA, nnz(A), nA, nnz(B), nB) * 11 ÷ 10 + mA, mA*nB)
rowptrC = Vector{Ti}(undef, mA+1)
colvalC = Vector{Ti}(undef, nnzC)
nzvalC = Vector{Tv}(undef, nnzC)

@inbounds begin
jp = 1
xb = fill(false, nB)
for j in 1:mA
if jp + nB - 1 > nnzC
nnzC += max(nB, nnzC>>2)
resize!(colvalC, nnzC)
resize!(nzvalC, nnzC)

Check warning on line 410 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L408-L410

Added lines #L408 - L410 were not covered by tests
end
rowptrC[j] = jp
jp = sprowmul!(colvalC, nzvalC, xb, j, jp, A, B)
end
rowptrC[mA+1] = jp
end

resize!(colvalC, jp - 1)
resize!(nzvalC, jp - 1)

# This modification of Gustavson algorithm has sorted row indices
C = MatrixType(mA, nB, rowptrC, colvalC, nzvalC)
return C
end

# process single rhs column
function spcolmul!(rowvalC, nzvalC, xb, i, ip, A, B)
Expand Down Expand Up @@ -297,6 +471,54 @@
end
return ip
end
# process single rhs row
function sprowmul!(colvalC, nzvalC, xb, j, jp, A, B)
colvalA = colvals(A); nzvalA = nonzeros(A)
colvalB = colvals(B); nzvalB = nonzeros(B)
nB = size(B, 2)
jp0 = jp
k0 = jp - 1
@inbounds begin
for ip in nzrange(A, j)
nzA = nzvalA[ip]
i = colvalA[ip]
for kp in nzrange(B, i)
nzC = nzvalB[kp] * nzA
k = colvalB[kp]
if xb[k]
nzvalC[k+k0] += nzC
else
nzvalC[k+k0] = nzC
xb[k] = true
colvalC[jp] = k
jp += 1
end
end
end
if jp > jp0
if prefer_sort(jp-k0, nB)
# in-place sort of indices. Effort: O(nnz*ln(nnz)).
sort!(colvalC, jp0, jp-1, QuickSort, Base.Order.Forward)
for vp = jp0:jp-1
k = colvalC[vp]
xb[k] = false
nzvalC[vp] = nzvalC[k+k0]
end

Check warning on line 506 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L501-L506

Added lines #L501 - L506 were not covered by tests
else
# scan result vector (effort O(mA))
for k = 1:nB
if xb[k]
xb[k] = false
colvalC[jp0] = k
nzvalC[jp0] = nzvalC[k+k0]
jp0 += 1
end
end
end
end
end
return jp
end

# special cases of same twin Upper/LowerTriangular
spmatmul1(A, B) = spmatmul(A, B)
Expand Down Expand Up @@ -1104,17 +1326,27 @@
end

# row range up to (and including if excl=false) diagonal
function nzrangeup(A, i, excl=false)
function nzrangeup(A::SparseMatrixCSCUnion3, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:(searchsortedlast(view(rv, r1:r2), i - excl) + r1-1)
end
# row range from diagonal (included if excl=false) to end
function nzrangelo(A, i, excl=false)
function nzrangelo(A::SparseMatrixCSCUnion3, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2
end
function nzrangeup(A::SparseMatrixCSRUnion3, i, excl=false)
c = nzrange(A, i); c1 = c.start; c2 = c.stop
cv = colvals(A)
@inbounds c2 < c1 || cv[c1] >= i + excl ? c : (searchsortedfirst(view(cv, c1:c2), i + excl) + c1-1):c2

Check warning on line 1343 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L1340-L1343

Added lines #L1340 - L1343 were not covered by tests
end
function nzrangelo(A::SparseMatrixCSRUnion3, i, excl=false)
c = nzrange(A, i); c1 = c.start; c2 = c.stop
cv = colvals(A)
@inbounds c2 < c1 || cv[c2] <= i - excl ? c : c1:(searchsortedlast(view(cv, c1:c2), i - excl) + c1-1)

Check warning on line 1348 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L1345-L1348

Added lines #L1345 - L1348 were not covered by tests
end

dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) =
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
Expand Down
Loading
Loading