Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: more efficient cat #10037

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 54 additions & 77 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,102 +662,79 @@ function vcat{T}(A::AbstractMatrix{T}...)
end

## cat: general case
function cat(catdims, X...)
catdims = collect(catdims)
nargs = length(X)
ndimsX = Int[isa(a,AbstractArray) ? ndims(a) : 0 for a in X]
ndimsC = max(maximum(ndimsX), maximum(catdims))
catsizes = zeros(Int,(nargs,length(catdims)))
dims2cat = zeros(Int,ndimsC)
for k = 1:length(catdims)
dims2cat[catdims[k]]=k
end

typeC = isa(X[1],AbstractArray) ? eltype(X[1]) : typeof(X[1])
dimsC = Int[d <= ndimsX[1] ? size(X[1],d) : 1 for d=1:ndimsC]
for k = 1:length(catdims)
catsizes[1,k] = dimsC[catdims[k]]
end
for i = 2:nargs
typeC = promote_type(typeC, isa(X[i], AbstractArray) ? eltype(X[i]) : typeof(X[i]))
for d = 1:ndimsC
currentdim = (d <= ndimsX[i] ? size(X[i],d) : 1)
if dims2cat[d]==0
dimsC[d] == currentdim || throw(DimensionMismatch("mismatch in dimension $(d)"))
else
dimsC[d] += currentdim
catsizes[i,dims2cat[d]] = currentdim
function cat(catdim::Integer, X...)
catsize = 0
ndimsC = catdim
typeC = Bottom

for i=1:length(X)
if isa(X[i], AbstractArray)
if ndimsC < ndims(X[i])
ndimsC = ndims(X[i])
end
catsize += size(X[i], catdim)
typeC = promote_type(typeC, eltype(X[i]))
else
catsize += 1
typeC = promote_type(typeC, typeof(X[i]))
end
end

C = similar(isa(X[1],AbstractArray) ? full(X[1]) : [X[1]], typeC, tuple(dimsC...))
if length(catdims)>1
fill!(C,0)
end

offsets = zeros(Int,length(catdims))
for i=1:nargs
cat_one = [ dims2cat[d]==0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]])) for d=1:ndimsC]
C[cat_one...] = X[i]
for k = 1:length(catdims)
offsets[k] += catsizes[i,k]
end
if isa(X[1],AbstractArray)
dimsC = ntuple(d->(d==catdim ? catsize : size(X[1],d)), ndimsC)
C = similar(full(X[1]), typeC, dimsC)
else
dimsC = ntuple(d->(d==catdim ? catsize : 1), ndimsC)
C = Array(typeC, dimsC)
end
return C
cat!(C, catdim, X...)
end

vcat(X...) = cat(1, X...)
hcat(X...) = cat(2, X...)

cat{T}(catdims, A::AbstractArray{T}...) = cat_t(catdims, T, A...)

cat(catdims, A::AbstractArray...) =
cat_t(catdims, promote_eltype(A...), A...)
function cat{T}(catdim::Integer, X::AbstractArray{T}...)
catsize = 0
ndimsC = catdim

function cat_t(catdims, typeC, A::AbstractArray...)
catdims = collect(catdims)
nargs = length(A)
ndimsA = Int[ndims(a) for a in A]
ndimsC = max(maximum(ndimsA), maximum(catdims))
catsizes = zeros(Int,(nargs,length(catdims)))
dims2cat = zeros(Int,ndimsC)
for k = 1:length(catdims)
dims2cat[catdims[k]]=k
end

dimsC = Int[d <= ndimsA[1] ? size(A[1],d) : 1 for d=1:ndimsC]
for k = 1:length(catdims)
catsizes[1,k] = dimsC[catdims[k]]
end
for i = 2:nargs
for d = 1:ndimsC
currentdim = (d <= ndimsA[i] ? size(A[i],d) : 1)
if dims2cat[d]==0
dimsC[d] == currentdim || throw(DimensionMismatch("mismatch in dimension $(d)"))
else
dimsC[d] += currentdim
catsizes[i,dims2cat[d]] = currentdim
end
for i=1:length(X)
if ndimsC < ndims(X[i])
ndimsC = ndims(X[i])
end
catsize += size(X[i], catdim)
end

C = similar(full(A[1]), typeC, tuple(dimsC...))
if length(catdims)>1
fill!(C,0)
end
dimsC = ntuple(d->(d==catdim ? catsize : size(X[1],d)), ndimsC)
C = similar(full(X[1]), T, dimsC)
cat!(C, catdim, X...)
end

offsets = zeros(Int,length(catdims))
for i=1:nargs
cat_one = [ dims2cat[d]==0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]])) for d=1:ndimsC]
C[cat_one...] = A[i]
for k = 1:length(catdims)
offsets[k] += catsizes[i,k]
end
function cat!(C::AbstractArray, catdim::Integer, X...)
ndimsC = ndims(C)

offset = 0
for i=1:length(X)
catsize = isa(X[i], AbstractArray) ? size(X[i], catdim) : 1
catrange = offset + (1:catsize)
cat_insert!(C,catdim,catrange,X[i])
offset += catsize
end

size(C,catdim) == offset || throw(DimensionMismatch("mismatch in dimension $(catdim)"))
return C
end

stagedfunction cat_insert!{T,N}(C::AbstractArray{T,N}, catdim::Int, catrange::UnitRange{Int}, X)
ranges = [d==N ? :(catrange) : :(1:size(C,$d)) for d=1:N]
ex = :(C[$(ranges...)] = X)
for n = N-1:-1:1
ranges = [d==n ? :(catrange) : :(1:size(C,$d)) for d=1:N]
ex = Expr(:if,:(catdim==$n),:(C[$(ranges...)] = X),ex)
end
return ex
end


vcat(A::AbstractArray...) = cat(1, A...)
hcat(A::AbstractArray...) = cat(2, A...)

Expand Down
8 changes: 0 additions & 8 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,6 @@ tmp = zeros(Int,map(maximum,rng)...)
tmp[rng...] = A[rng...]
@test tmp == cat(3,zeros(Int,2,3),[0 0 0; 0 47 52],zeros(Int,2,3),[0 0 0; 0 127 132])

@test cat([1,2],1,2,3.,4.,5.) == diagm([1,2,3.,4.,5.])
blk = [1 2;3 4]
tmp = cat([1,3],blk,blk)
@test tmp[1:2,1:2,1] == blk
@test tmp[1:2,1:2,2] == zero(blk)
@test tmp[3:4,1:2,1] == zero(blk)
@test tmp[3:4,1:2,2] == blk

x = rand(2,2)
b = x[1,:]
@test isequal(size(b), (1, 2))
Expand Down