Skip to content

Commit

Permalink
Fix the type instability causing slowdown and extra memory
Browse files Browse the repository at this point in the history
allocation in sparse vcat. Fixes #7926.
Add a test for vcat of sparse matrices of different element/index types.
  • Loading branch information
ViralBShah committed Feb 15, 2015
1 parent 07f3ee7 commit 3e1d2f2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
49 changes: 31 additions & 18 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1884,48 +1884,61 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end



# Sparse concatenation

function vcat(X::SparseMatrixCSC...)
Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

vcat(map(x->convert(SparseMatrixCSC{Tv,Ti}, x), X)...)
end

function vcat{Tv,Ti<:Integer}(X::SparseMatrixCSC{Tv,Ti}...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
m = sum(mX)
n = nX[1]

for i = 2 : num
if nX[i] != n; throw(DimensionMismatch("")); end
if nX[i] != n
throw(DimensionMismatch("All inputs to vcat should have the same number of columns"))
end
end
m = sum(mX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array(Ti, n + 1)
nnzX = [ nnz(x) for x in X ]
nnz_res = sum(nnzX)
colptr = Array(Ti, n + 1)
rowval = Array(Ti, nnz_res)
nzval = Array(Tv, nnz_res)
nzval = Array(Tv, nnz_res)

colptr[1] = 1
@inbounds for c = 1 : n
for c = 1:n
mX_sofar = 0
rr1 = colptr[c]
ptr_res = colptr[c]
for i = 1 : num
XI = X[i]
rX1 = XI.colptr[c]
rX2 = XI.colptr[c + 1] - 1
rr2 = rr1 + (rX2 - rX1)
Xi = X[i]
colptrXi = Xi.colptr
rowvalXi = Xi.rowval
nzvalXi = Xi.nzval

col_length = (colptrXi[c + 1] - 1) - colptrXi[c]
ptrXi = colptrXi[c]
for k=ptr_res:(ptr_res + col_length)
@inbounds rowval[k] = rowvalXi[ptrXi] + mX_sofar
@inbounds nzval[k] = nzvalXi[ptrXi]
ptrXi += 1
end

rowval[rr1 : rr2] = XI.rowval[rX1 : rX2] .+ mX_sofar
nzval[rr1 : rr2] = XI.nzval[rX1 : rX2]
ptr_res += col_length + 1
mX_sofar += mX[i]
rr1 = rr2 + 1
end
colptr[c + 1] = rr1
colptr[c + 1] = ptr_res
end
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end


function hcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
Expand Down
2 changes: 2 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ do33 = ones(3)

# check vert concatenation
@test all([se33; se33] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)))
se33_32bit = convert(SparseMatrixCSC{Float32,Int32}, se33)
@test all([se33; se33_32bit] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)))

# check h+v concatenation
se44 = speye(4)
Expand Down

0 comments on commit 3e1d2f2

Please sign in to comment.