Skip to content

Commit

Permalink
Merge pull request #16371 from KristofferC/kc/sparse_mat_immut
Browse files Browse the repository at this point in the history
RFC: make SparseMatrixCSC immutable
  • Loading branch information
ViralBShah committed May 18, 2016
2 parents 36cfe3c + 8aa787e commit fa41010
Showing 1 changed file with 64 additions and 59 deletions.
123 changes: 64 additions & 59 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Assumes that row values in rowval for each column are sorted
# issorted(rowval[colptr[i]:(colptr[i+1]-1)]) == true

type SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
immutable SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
Expand Down Expand Up @@ -2276,22 +2276,25 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector

nnzA = nnz(A) + lenI * length(J)

colptr = A.colptr
rowvalA = rowval = A.rowval
nzvalA = nzval = A.nzval

rowidx = 1
nadd = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(nadd > 0) && (colptr[col] = colptr[col] + nadd)
rrange = nzrange(A, col)
if nadd > 0
A.colptr[col] = A.colptr[col] + nadd
end

if col in J
if isempty(rrange) # set new vals only
nincl = lenI
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
r = rowidx:(rowidx+nincl-1)
rowvalA[r] = I
Expand All @@ -2317,8 +2320,10 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
old_ptr += 1
else
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
nadd += 1
end
Expand All @@ -2331,8 +2336,10 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
if old_ptr > old_stop
if new_ptr <= new_stop
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
r = rowidx:(rowidx+(new_stop-new_ptr))
rowvalA[r] = I[new_ptr:new_stop]
Expand Down Expand Up @@ -2361,12 +2368,9 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
end

if nadd > 0
colptr[n+1] = rowidx
A.colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end
Expand All @@ -2387,14 +2391,16 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
return A
end

colptr = A.colptr
rowval = rowvalA = A.rowval
nzval = nzvalA = A.nzval
rowidx = 1
ndel = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(ndel > 0) && (colptr[col] = colptr[col] - ndel)
rrange = nzrange(A, col)
if ndel > 0
A.colptr[col] = A.colptr[col] - ndel
end

if isempty(rrange) || !(col in J)
nincl = length(rrange)
if(ndel > 0) && !isempty(rrange)
Expand All @@ -2406,8 +2412,8 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
for ridx in rrange
if rowval[ridx] in I
if ndel == 0
rowvalA = copy(rowval)
nzvalA = copy(nzval)
rowval = copy(rowvalA)
nzval = copy(nzvalA)
end
ndel += 1
else
Expand All @@ -2422,12 +2428,9 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
end

if ndel > 0
colptr[n+1] = rowidx
A.colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end
Expand Down Expand Up @@ -2480,11 +2483,14 @@ function setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixC
colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval

nnzS = nnz(A) + nnz(B)
colptrS = Array(Ti, n+1)
rowvalS = Array(Ti, nnzS)
nzvalS = Array(Tv, nnzS)

colptrS[1] = 1
colptrS = copy(A.colptr)
rowvalS = copy(A.rowval)
nzvalS = copy(A.nzval)

resize!(rowvalA, nnzS)
resize!(nzvalA, nnzS)

colB = 1
asgn_col = J[colB]

Expand All @@ -2497,73 +2503,70 @@ function setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixC

# Copy column of A if it is not being assigned into
if colB > nJ || col != J[colB]
colptrS[col+1] = colptrS[col] + (colptrA[col+1]-colptrA[col])
colptrA[col+1] = colptrA[col] + (colptrS[col+1]-colptrS[col])

for k = colptrA[col]:colptrA[col+1]-1
rowvalS[ptrS] = rowvalA[k]
nzvalS[ptrS] = nzvalA[k]
for k = colptrS[col]:colptrS[col+1]-1
rowvalA[ptrS] = rowvalS[k]
nzvalA[ptrS] = nzvalS[k]
ptrS += 1
end
continue
end

ptrA::Int = colptrA[col]
stopA::Int = colptrA[col+1]
ptrA::Int = colptrS[col]
stopA::Int = colptrS[col+1]
ptrB::Int = colptrB[colB]
stopB::Int = colptrB[colB+1]

while ptrA < stopA && ptrB < stopB
rowA = rowvalA[ptrA]
rowA = rowvalS[ptrA]
rowB = I[rowvalB[ptrB]]
if rowA < rowB
if ~I_asgn[rowA]
rowvalS[ptrS] = rowA
nzvalS[ptrS] = nzvalA[ptrA]
rowvalA[ptrS] = rowA
nzvalA[ptrS] = nzvalS[ptrA]
ptrS += 1
end
ptrA += 1
elseif rowB < rowA
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
else
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
ptrA += 1
end
end

while ptrA < stopA
rowA = rowvalA[ptrA]
rowA = rowvalS[ptrA]
if ~I_asgn[rowA]
rowvalS[ptrS] = rowA
nzvalS[ptrS] = nzvalA[ptrA]
rowvalA[ptrS] = rowA
nzvalA[ptrS] = nzvalS[ptrA]
ptrS += 1
end
ptrA += 1
end

while ptrB < stopB
rowB = I[rowvalB[ptrB]]
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
end

colptrS[col+1] = ptrS
colptrA[col+1] = ptrS
colB += 1
end

deleteat!(rowvalS, colptrS[end]:length(rowvalS))
deleteat!(nzvalS, colptrS[end]:length(nzvalS))
deleteat!(rowvalA, colptrA[end]:length(rowvalA))
deleteat!(nzvalA, colptrA[end]:length(nzvalA))

A.colptr = colptrS
A.rowval = rowvalS
A.nzval = nzvalS
return A
end

Expand Down Expand Up @@ -2619,10 +2622,12 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
colptrA = copy(colptrB)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalA = copy(rowvalB)
nzvalA = copy(nzvalB)
resize!(rowvalB, length(rowvalA)+memreq)
resize!(nzvalB, length(rowvalA)+memreq)
end
if mode == 1
rowvalB[bidx] = row
Expand Down Expand Up @@ -2675,7 +2680,6 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end
Expand Down Expand Up @@ -2746,10 +2750,12 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
colptrA = copy(colptrB)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalA = copy(rowvalB)
nzvalA = copy(nzvalB)
resize!(rowvalB, length(rowvalA)+memreq)
resize!(nzvalB, length(rowvalA)+memreq)
end
if mode == 1
rowvalB[bidx] = row
Expand Down Expand Up @@ -2786,7 +2792,6 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end
Expand Down

0 comments on commit fa41010

Please sign in to comment.