Skip to content

Commit

Permalink
add dcat and blkdiag, update tests and perf
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Feb 10, 2015
1 parent ae1d156 commit 02cb581
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
53 changes: 48 additions & 5 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,9 @@ function typed_cat(T::Type, catdim::Integer, X...)
C=cat_similar(X[1], T, catsize(n, catdim, N, X[1]))
offset = 0
for i = 1:length(X)
r = offset+catrange(catdim, X[i])
cat_fill!(C, catdim, r, X[i])
x = X[i]
r = offset+catrange(catdim, x)
cat_fill!(C, catdim, r, x)
offset = last(r)
end
C
Expand All @@ -790,8 +791,7 @@ function typed_hvcat(T::Type, rows::(Int...), X...)
k = 1
N = catndims(X[k])
for i = 2:length(X)
x = X[i]
Ni = catndims(x)
Ni = catndims(X[i])
Ni > N && (k=i;N=Ni)
end

Expand All @@ -818,7 +818,7 @@ end

hvcat{T}(rows::(Int...), X::T...) = typed_hvcat(cattype(T), rows, X...)
hvcat{T}(rows::(Int...), X::AbstractArray{T}...) = typed_hvcat(T, rows, X...)
hvcat(rows::(Int...), X...) = typed_hcat(promote_cattypeof(X...), rows, X...)
hvcat(rows::(Int...), X...) = typed_hvcat(promote_cattypeof(X...), rows, X...)

function hvcat(nbc::Integer, A...)
# nbc = # of block columns
Expand All @@ -830,6 +830,49 @@ function hvcat(nbc::Integer, A...)
hvcat(ntuple(i->nbc, nbr), A...)
end

## dcat: cat along diagonals
function typed_dcat(T::Type, catdims, X...)
dims = Int[d for d in catdims]
M = length(dims)
N = maximum(dims)
for i = 1:length(X)
N = max(N,catndims(X[i]))
end

catsizes = zeros(Int, N)
for i = 1:length(X)
x = X[i]
for j = 1:M
catdim = dims[j]
catsizes[catdim] += catlength(catdim,x)
end
end
for d = 1:N
catsizes[d] == 0 && (catsizes[d]=catlength(d,X[1]))
end
C=cat_similar(X[1], T, tuple(catsizes...))
M > 1 && fill!(C,zero(T))

offsets = zeros(Int, M)
catranges = [1:size(C,d) for d=1:N]
for i = 1:length(X)
x=X[i]
for j=1:M
catdim = dims[j]
catranges[catdim] = offsets[j] + (1:catlength(catdim, x))
offsets[j] = last(catranges[catdim])
end
C[catranges...] = x
end
C
end

dcat{T}(catdims, X::T...) = typed_dcat(cattype(T), catdims, X...)
dcat{T}(catdims, X::AbstractArray{T}...) = typed_dcat(T, catdims, X...)
dcat(catdims, X...) = typed_dcat(promote_cattypeof(X...), catdims, X...)

blkdiag(X...) = dcat([1,2], X...)

## Reductions and scans ##

function isequal(A::AbstractArray, B::AbstractArray)
Expand Down
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ export

# arrays
bitbroadcast,
blkdiag,
broadcast!,
broadcast!_function,
broadcast,
Expand All @@ -523,6 +524,7 @@ export
cumsum,
cumsum!,
cumsum_kbn,
dcat,
eachindex,
extrema,
fill!,
Expand Down
4 changes: 2 additions & 2 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ tmp = zeros(Int,map(maximum,rng)...)
tmp[rng...] = A[rng...]
@test tmp == cat(3,zeros(Int,2,3),[0 0 0; 0 47 52],zeros(Int,2,3),[0 0 0; 0 127 132])

@test cat([1,2],1,2,3.,4.,5.) == diagm([1,2,3.,4.,5.])
@test dcat([1,2],1,2,3.,4.,5.) == diagm([1,2,3.,4.,5.])
blk = [1 2;3 4]
tmp = cat([1,3],blk,blk)
tmp = dcat([1,3],blk,blk)
@test tmp[1:2,1:2,1] == blk
@test tmp[1:2,1:2,2] == zero(blk)
@test tmp[3:4,1:2,1] == zero(blk)
Expand Down

0 comments on commit 02cb581

Please sign in to comment.