Skip to content

Commit

Permalink
Provide specialized typed_*cat(::Type{SparseMatrixCSC}, ...)
Browse files Browse the repository at this point in the history
And make hcat/vcat/hvcat returning sparse matrices just call the
corresponding typed_*cat function.
  • Loading branch information
martinholters committed Jul 6, 2016
1 parent 2049f19 commit e262163
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 26 deletions.
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
hcat, hvcat, imag, indmax, ishermitian, kron, length, log, log1p, max, min,
maximum, minimum, norm, one, promote_eltype, real, reinterpret, reshape, rot180,
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
triu, vcat, vec, permute!
triu, vcat, vec, permute!, typed_hcat, typed_vcat, typed_hvcat

import Base.Broadcast: eltype_plus, broadcast_shape

Expand Down
54 changes: 29 additions & 25 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3075,7 +3075,10 @@ end

# Sparse concatenation

function vcat(X::SparseMatrixCSC...)
promote_indtype() = Base.Bottom
promote_indtype{Tv,Ti}(X::SparseMatrixCSC{Tv,Ti}, Xs::SparseMatrixCSC...) = promote_type(Ti, promote_indtype(Xs...))

function typed_vcat{Tv,Ti}(::Type{SparseMatrixCSC{Tv,Ti}}, X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
Expand All @@ -3088,13 +3091,6 @@ function vcat(X::SparseMatrixCSC...)
end
end

Tv = eltype(X[1].nzval)
Ti = eltype(X[1].rowval)
for i = 2:length(X)
Tv = promote_type(Tv, eltype(X[i].nzval))
Ti = promote_type(Ti, eltype(X[i].rowval))
end

nnzX = [ nnz(x) for x in X ]
nnz_res = sum(nnzX)
colptr = Array{Ti}(n + 1)
Expand Down Expand Up @@ -3135,7 +3131,7 @@ end
end


function hcat(X::SparseMatrixCSC...)
function typed_hcat{Tv,Ti}(::Type{SparseMatrixCSC{Tv,Ti}}, X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
Expand All @@ -3145,9 +3141,6 @@ function hcat(X::SparseMatrixCSC...)
end
n = sum(nX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array{Ti}(n + 1)
nnzX = [ nnz(x) for x in X ]
nnz_res = sum(nnzX)
Expand All @@ -3173,29 +3166,40 @@ function hcat(X::SparseMatrixCSC...)
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

for d in (:h, :v)
let cat_name = Symbol(d, :cat), typed_cat_name = Symbol(:typed_, cat_name)
@eval begin
$(cat_name)(Xin::Union{Vector, Matrix, SparseMatrixCSC}...) = $(typed_cat_name)(SparseMatrixCSC, Xin...)

# Sparse/dense concatenation

function hcat(Xin::Union{Vector, Matrix, SparseMatrixCSC}...)
X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin]
hcat(X...)
$(typed_cat_name){T<:SparseMatrixCSC}(::Type{T}) = T(0, 0)
$(typed_cat_name){T<:SparseMatrixCSC}(::Type{T}, A::AbstractVecOrMat...) =
$(typed_cat_name)(T, map(hcat, A)...)
$(typed_cat_name){T<:SparseMatrixCSC}(::Type{T}, A::AbstractMatrix...) =
$(typed_cat_name)(T, map(SparseMatrixCSC, A)...)
$(typed_cat_name)(::Type{SparseMatrixCSC}, X::SparseMatrixCSC...) =
$(typed_cat_name)(SparseMatrixCSC{promote_eltype(X...)}, X...)
$(typed_cat_name){Tv}(::Type{SparseMatrixCSC{Tv}}, X::SparseMatrixCSC...) =
$(typed_cat_name)(SparseMatrixCSC{Tv,promote_indtype(X...)}, X...)
end
end
end

function vcat(Xin::Union{Vector, Matrix, SparseMatrixCSC}...)
X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin]
vcat(X...)
end
typed_vcat{T<:SparseMatrixCSC}(::Type{T}, A::AbstractVector...) = typed_vcat(T, map(hcat, A)...)

function hvcat(rows::Tuple{Vararg{Int}}, X::Union{Vector, Matrix, SparseMatrixCSC}...)
hvcat(rows::Tuple{Vararg{Int}}, X::Union{Vector, Matrix, SparseMatrixCSC}...) = typed_hvcat(SparseMatrixCSC, rows, X...)
typed_hvcat{T<:SparseMatrixCSC}(::Type{T}, rows::Tuple{Vararg{Int}}) = T(length(rows),0)
typed_hvcat{T<:SparseMatrixCSC}(::Type{T}, rows::Tuple{Vararg{Int}}, X::AbstractMatrix...) =
invoke(typed_hvcat, (Type{T}, Tuple{Vararg{Int}}, Vararg{AbstractVecOrMat}), T, rows, X...)
function typed_hvcat{T<:SparseMatrixCSC}(::Type{T}, rows::Tuple{Vararg{Int}}, X::AbstractVecOrMat...)
nbr = length(rows) # number of block rows

tmp_rows = Array{SparseMatrixCSC}(nbr)
tmp_rows = Array{T}(nbr)
k = 0
@inbounds for i = 1 : nbr
tmp_rows[i] = hcat(X[(1 : rows[i]) + k]...)
tmp_rows[i] = typed_hcat(T, X[(1 : rows[i]) + k]...)
k += rows[i]
end
vcat(tmp_rows...)
typed_vcat(T, tmp_rows...)
end

"""
Expand Down

0 comments on commit e262163

Please sign in to comment.