From cf83c396ca238f168b3ccc4767e4a1a51d553470 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Tue, 3 Feb 2015 01:04:27 +0100 Subject: [PATCH] more efficient cat --- base/abstractarray.jl | 131 +++++++++++++++++------------------------- test/arrayops.jl | 8 --- 2 files changed, 54 insertions(+), 85 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 5cb33ef0ddb77..91c1c6ecbb556 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -662,102 +662,79 @@ function vcat{T}(A::AbstractMatrix{T}...) end ## cat: general case -function cat(catdims, X...) - catdims = collect(catdims) - nargs = length(X) - ndimsX = Int[isa(a,AbstractArray) ? ndims(a) : 0 for a in X] - ndimsC = max(maximum(ndimsX), maximum(catdims)) - catsizes = zeros(Int,(nargs,length(catdims))) - dims2cat = zeros(Int,ndimsC) - for k = 1:length(catdims) - dims2cat[catdims[k]]=k - end - - typeC = isa(X[1],AbstractArray) ? eltype(X[1]) : typeof(X[1]) - dimsC = Int[d <= ndimsX[1] ? size(X[1],d) : 1 for d=1:ndimsC] - for k = 1:length(catdims) - catsizes[1,k] = dimsC[catdims[k]] - end - for i = 2:nargs - typeC = promote_type(typeC, isa(X[i], AbstractArray) ? eltype(X[i]) : typeof(X[i])) - for d = 1:ndimsC - currentdim = (d <= ndimsX[i] ? size(X[i],d) : 1) - if dims2cat[d]==0 - dimsC[d] == currentdim || throw(DimensionMismatch("mismatch in dimension $(d)")) - else - dimsC[d] += currentdim - catsizes[i,dims2cat[d]] = currentdim +function cat(catdim::Integer, X...) + catsize = 0 + ndimsC = catdim + typeC = Bottom + + for i=1:length(X) + if isa(X[i], AbstractArray) + if ndimsC < ndims(X[i]) + ndimsC = ndims(X[i]) end + catsize += size(X[i], catdim) + typeC = promote_type(typeC, eltype(X[i])) + else + catsize += 1 + typeC = promote_type(typeC, typeof(X[i])) end end - C = similar(isa(X[1],AbstractArray) ? full(X[1]) : [X[1]], typeC, tuple(dimsC...)) - if length(catdims)>1 - fill!(C,0) - end - - offsets = zeros(Int,length(catdims)) - for i=1:nargs - cat_one = [ dims2cat[d]==0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]])) for d=1:ndimsC] - C[cat_one...] = X[i] - for k = 1:length(catdims) - offsets[k] += catsizes[i,k] - end + if isa(X[1],AbstractArray) + dimsC = ntuple(d->(d==catdim ? catsize : size(X[1],d)), ndimsC) + C = similar(full(X[1]), typeC, dimsC) + else + dimsC = ntuple(d->(d==catdim ? catsize : 1), ndimsC) + C = Array(typeC, dimsC) end - return C + cat!(C, catdim, X...) end vcat(X...) = cat(1, X...) hcat(X...) = cat(2, X...) -cat{T}(catdims, A::AbstractArray{T}...) = cat_t(catdims, T, A...) - -cat(catdims, A::AbstractArray...) = - cat_t(catdims, promote_eltype(A...), A...) +function cat{T}(catdim::Integer, X::AbstractArray{T}...) + catsize = 0 + ndimsC = catdim -function cat_t(catdims, typeC, A::AbstractArray...) - catdims = collect(catdims) - nargs = length(A) - ndimsA = Int[ndims(a) for a in A] - ndimsC = max(maximum(ndimsA), maximum(catdims)) - catsizes = zeros(Int,(nargs,length(catdims))) - dims2cat = zeros(Int,ndimsC) - for k = 1:length(catdims) - dims2cat[catdims[k]]=k - end - - dimsC = Int[d <= ndimsA[1] ? size(A[1],d) : 1 for d=1:ndimsC] - for k = 1:length(catdims) - catsizes[1,k] = dimsC[catdims[k]] - end - for i = 2:nargs - for d = 1:ndimsC - currentdim = (d <= ndimsA[i] ? size(A[i],d) : 1) - if dims2cat[d]==0 - dimsC[d] == currentdim || throw(DimensionMismatch("mismatch in dimension $(d)")) - else - dimsC[d] += currentdim - catsizes[i,dims2cat[d]] = currentdim - end + for i=1:length(X) + if ndimsC < ndims(X[i]) + ndimsC = ndims(X[i]) end + catsize += size(X[i], catdim) end - C = similar(full(A[1]), typeC, tuple(dimsC...)) - if length(catdims)>1 - fill!(C,0) - end + dimsC = ntuple(d->(d==catdim ? catsize : size(X[1],d)), ndimsC) + C = similar(full(X[1]), T, dimsC) + cat!(C, catdim, X...) +end - offsets = zeros(Int,length(catdims)) - for i=1:nargs - cat_one = [ dims2cat[d]==0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]])) for d=1:ndimsC] - C[cat_one...] = A[i] - for k = 1:length(catdims) - offsets[k] += catsizes[i,k] - end +function cat!(C::AbstractArray, catdim::Integer, X...) + ndimsC = ndims(C) + + offset = 0 + for i=1:length(X) + catsize = isa(X[i], AbstractArray) ? size(X[i], catdim) : 1 + catrange = offset + (1:catsize) + cat_insert!(C,catdim,catrange,X[i]) + offset += catsize end + + size(C,catdim) == offset || throw(DimensionMismatch("mismatch in dimension $(catdim)")) return C end +stagedfunction cat_insert!{T,N}(C::AbstractArray{T,N}, catdim::Int, catrange::UnitRange{Int}, X) + ranges = [d==N ? :(catrange) : :(1:size(C,$d)) for d=1:N] + ex = :(C[$(ranges...)] = X) + for n = N-1:-1:1 + ranges = [d==n ? :(catrange) : :(1:size(C,$d)) for d=1:N] + ex = Expr(:if,:(catdim==$n),:(C[$(ranges...)] = X),ex) + end + return ex +end + + vcat(A::AbstractArray...) = cat(1, A...) hcat(A::AbstractArray...) = cat(2, A...) diff --git a/test/arrayops.jl b/test/arrayops.jl index 9d2f16b513116..69d74c21d53af 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -84,14 +84,6 @@ tmp = zeros(Int,map(maximum,rng)...) tmp[rng...] = A[rng...] @test tmp == cat(3,zeros(Int,2,3),[0 0 0; 0 47 52],zeros(Int,2,3),[0 0 0; 0 127 132]) -@test cat([1,2],1,2,3.,4.,5.) == diagm([1,2,3.,4.,5.]) -blk = [1 2;3 4] -tmp = cat([1,3],blk,blk) -@test tmp[1:2,1:2,1] == blk -@test tmp[1:2,1:2,2] == zero(blk) -@test tmp[3:4,1:2,1] == zero(blk) -@test tmp[3:4,1:2,2] == blk - x = rand(2,2) b = x[1,:] @test isequal(size(b), (1, 2))