Skip to content

Commit

Permalink
Avoid using temporary vars for sparse data arrays
Browse files Browse the repository at this point in the history
Just use the accessor functions rowvals, nonzeros, getcolptr whenever
needed instead. Especially in the sparse vector case avoid using findnz,
which creates a copy of the data. Use nonzeroinds and nonzeros instead.
  • Loading branch information
mjacobse committed Apr 21, 2024
1 parent deda6e8 commit b961a7f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 22 deletions.
20 changes: 7 additions & 13 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3908,12 +3908,10 @@ function vcat(X::AbstractSparseMatrixCSC...)
ptr_res = colptr[c]
for i = 1 : num
colptrXi = getcolptr(X[i])
rowvalXi = rowvals(X[i])
nzvalXi = nonzeros(X[i])
col_length = colptrXi[c + 1] - colptrXi[c]
ptr_Xi = colptrXi[c]

stuffcol!(rowval, nzval, ptr_res, rowvalXi, nzvalXi, ptr_Xi, col_length, mX_sofar)
stuffcol!(rowval, nzval, ptr_res, rowvals(X[i]), nonzeros(X[i]), ptr_Xi, col_length, mX_sofar)

ptr_res += col_length
mX_sofar += mX[i]
Expand Down Expand Up @@ -3974,23 +3972,19 @@ end
# Efficient repetition of sparse matrices

function Base.repeat(A::AbstractSparseMatrixCSC, m)
colptr_source = getcolptr(A)
rowval_source = rowvals(A)
nzval_source = nonzeros(A)

nnz_new = nnz(A) * m
colptr = similar(colptr_source, length(colptr_source))
rowval = similar(rowval_source, nnz_new)
nzval = similar(nzval_source, nnz_new)
colptr = similar(getcolptr(A), length(getcolptr(A)))
rowval = similar(rowvals(A), nnz_new)
nzval = similar(nonzeros(A), nnz_new)

colptr[1] = 1
for c = 1 : size(A, 2)
ptr_res = colptr[c]
ptr_source = colptr_source[c]
col_length = colptr_source[c + 1] - ptr_source
ptr_source = getcolptr(A)[c]
col_length = getcolptr(A)[c + 1] - ptr_source
for index_repetition = 0 : (m - 1)
row_offset = index_repetition * size(A, 1)
stuffcol!(rowval, nzval, ptr_res, rowval_source, nzval_source, ptr_source,
stuffcol!(rowval, nzval, ptr_res, rowvals(A), nonzeros(A), ptr_source,
col_length, row_offset)
ptr_res += col_length
end
Expand Down
15 changes: 6 additions & 9 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1314,14 +1314,13 @@ function Base.repeat(v::AbstractSparseVector, m)
nnz_source = nnz(v)
nnz_new = nnz_source * m

nzind_source, nzval_source = findnz(v)
nzind = similar(nzind_source, nnz_new)
nzval = similar(nzval_source, nnz_new)
nzind = similar(nonzeroinds(v), nnz_new)
nzval = similar(nonzeros(v), nnz_new)

ptr_res = 1
for index_repetition = 0:(m-1)
row_offset = index_repetition * length(v)
stuffcol!(nzind, nzval, ptr_res, nzind_source, nzval_source, 1, nnz_source, row_offset)
stuffcol!(nzind, nzval, ptr_res, nonzeroinds(v), nonzeros(v), 1, nnz_source, row_offset)
ptr_res += nnz_source
end
@assert ptr_res == nnz_new + 1
Expand All @@ -1331,11 +1330,9 @@ end

function Base.repeat(v::AbstractSparseVector, m, n)
w = repeat(v, m)
nzind_source, nzval_source = findnz(w)

colptr = Vector{eltype(nzind_source)}(1 .+ nnz(w) * (0:n))
rowval = repeat(nzind_source, n)
nzval = repeat(nzval_source, n)
colptr = Vector{eltype(nonzeroinds(w))}(1 .+ nnz(w) * (0:n))
rowval = repeat(nonzeroinds(w), n)
nzval = repeat(nonzeros(w), n)
SparseMatrixCSC(length(w), n, colptr, rowval, nzval)
end

Expand Down

0 comments on commit b961a7f

Please sign in to comment.