Skip to content

Commit

Permalink
cat redesign
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Feb 19, 2015
1 parent ab5c4cf commit dbbccfa
Show file tree
Hide file tree
Showing 9 changed files with 413 additions and 480 deletions.
298 changes: 0 additions & 298 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -594,303 +594,6 @@ end

get(A::AbstractArray, I::RangeVecIntList, default) = get!(similar(A, typeof(default), map(length, I)...), A, I, default)


## Concatenation ##

promote_eltype() = Bottom
promote_eltype(v1, vs...) = promote_type(eltype(v1), promote_eltype(vs...))

#TODO: ERROR CHECK
cat(catdim::Integer) = Array(Any, 0)

vcat() = Array(Any, 0)
hcat() = Array(Any, 0)

## cat: special cases
hcat{T}(X::T...) = T[ X[j] for i=1, j=1:length(X) ]
hcat{T<:Number}(X::T...) = T[ X[j] for i=1, j=1:length(X) ]
vcat{T}(X::T...) = T[ X[i] for i=1:length(X) ]
vcat{T<:Number}(X::T...) = T[ X[i] for i=1:length(X) ]

function vcat(X::Number...)
T = promote_typeof(X...)
hvcat_fill(Array(T,length(X)), X)
end

function hcat(X::Number...)
T = promote_typeof(X...)
hvcat_fill(Array(T,1,length(X)), X)
end

function vcat{T}(V::AbstractVector{T}...)
n = 0
for Vk in V
n += length(Vk)
end
a = similar(full(V[1]), n)
pos = 1
for k=1:length(V)
Vk = V[k]
p1 = pos+length(Vk)-1
a[pos:p1] = Vk
pos = p1+1
end
a
end

function hcat{T}(A::AbstractVecOrMat{T}...)
nargs = length(A)
nrows = size(A[1], 1)
ncols = 0
dense = true
for j = 1:nargs
Aj = A[j]
if size(Aj, 1) != nrows
throw(ArgumentError("number of rows must match"))
end
dense &= isa(Aj,Array)
nd = ndims(Aj)
ncols += (nd==2 ? size(Aj,2) : 1)
end
B = similar(full(A[1]), nrows, ncols)
pos = 1
if dense
for k=1:nargs
Ak = A[k]
n = length(Ak)
copy!(B, pos, Ak, 1, n)
pos += n
end
else
for k=1:nargs
Ak = A[k]
p1 = pos+(isa(Ak,AbstractMatrix) ? size(Ak, 2) : 1)-1
B[:, pos:p1] = Ak
pos = p1+1
end
end
return B
end

function vcat{T}(A::AbstractMatrix{T}...)
nargs = length(A)
nrows = sum(a->size(a, 1), A)::Int
ncols = size(A[1], 2)
for j = 2:nargs
if size(A[j], 2) != ncols
throw(ArgumentError("number of columns must match"))
end
end
B = similar(full(A[1]), nrows, ncols)
pos = 1
for k=1:nargs
Ak = A[k]
p1 = pos+size(Ak,1)-1
B[pos:p1, :] = Ak
pos = p1+1
end
return B
end

## cat: general case

function cat(catdims, X...)
T = promote_type(map(x->isa(x,AbstractArray) ? eltype(x) : typeof(x), X)...)
cat_t(catdims, T, X...)
end

function cat_t(catdims, typeC::Type, X...)
catdims = collect(catdims)
nargs = length(X)
ndimsX = Int[isa(a,AbstractArray) ? ndims(a) : 0 for a in X]
ndimsC = max(maximum(ndimsX), maximum(catdims))
catsizes = zeros(Int,(nargs,length(catdims)))
dims2cat = zeros(Int,ndimsC)
for k = 1:length(catdims)
dims2cat[catdims[k]]=k
end

dimsC = Int[d <= ndimsX[1] ? size(X[1],d) : 1 for d=1:ndimsC]
for k = 1:length(catdims)
catsizes[1,k] = dimsC[catdims[k]]
end
for i = 2:nargs
for d = 1:ndimsC
currentdim = (d <= ndimsX[i] ? size(X[i],d) : 1)
if dims2cat[d]==0
dimsC[d] == currentdim || throw(DimensionMismatch("mismatch in dimension $(d)"))
else
dimsC[d] += currentdim
catsizes[i,dims2cat[d]] = currentdim
end
end
end

C = similar(isa(X[1],AbstractArray) ? full(X[1]) : [X[1]], typeC, tuple(dimsC...))
if length(catdims)>1
fill!(C,0)
end

offsets = zeros(Int,length(catdims))
for i=1:nargs
cat_one = [ dims2cat[d]==0 ? (1:dimsC[d]) : (offsets[dims2cat[d]]+(1:catsizes[i,dims2cat[d]])) for d=1:ndimsC]
C[cat_one...] = X[i]
for k = 1:length(catdims)
offsets[k] += catsizes[i,k]
end
end
return C
end

vcat(X...) = cat(1, X...)
hcat(X...) = cat(2, X...)

typed_vcat(T::Type, X...) = cat_t(1, T, X...)
typed_hcat(T::Type, X...) = cat_t(2, T, X...)

cat{T}(catdims, A::AbstractArray{T}...) = cat_t(catdims, T, A...)

cat(catdims, A::AbstractArray...) = cat_t(catdims, promote_eltype(A...), A...)

vcat(A::AbstractArray...) = cat(1, A...)
hcat(A::AbstractArray...) = cat(2, A...)

typed_vcat(T::Type, A::AbstractArray...) = cat_t(1, T, A...)
typed_hcat(T::Type, A::AbstractArray...) = cat_t(2, T, A...)

# 2d horizontal and vertical concatenation

function hvcat(nbc::Integer, as...)
# nbc = # of block columns
n = length(as)
if mod(n,nbc) != 0
throw(ArgumentError("all rows must have the same number of block columns"))
end
nbr = div(n,nbc)
hvcat(ntuple(nbr, i->nbc), as...)
end

function hvcat{T}(rows::(Int...), as::AbstractMatrix{T}...)
nbr = length(rows) # number of block rows

nc = 0
for i=1:rows[1]
nc += size(as[i],2)
end

nr = 0
a = 1
for i = 1:nbr
nr += size(as[a],1)
a += rows[i]
end

out = similar(full(as[1]), T, nr, nc)

a = 1
r = 1
for i = 1:nbr
c = 1
szi = size(as[a],1)
for j = 1:rows[i]
Aj = as[a+j-1]
szj = size(Aj,2)
if size(Aj,1) != szi
throw(ArgumentError("mismatched height in block row $(i)"))
end
if c-1+szj > nc
throw(ArgumentError("block row $(i) has mismatched number of columns"))
end
out[r:r-1+szi, c:c-1+szj] = Aj
c += szj
end
if c != nc+1
throw(ArgumentError("block row $(i) has mismatched number of columns"))
end
r += szi
a += rows[i]
end
out
end

hvcat(rows::(Int...)) = []

function hvcat{T<:Number}(rows::(Int...), xs::T...)
nr = length(rows)
nc = rows[1]

a = Array(T, nr, nc)
if length(a) != length(xs)
throw(ArgumentError("argument count does not match specified shape"))
end
k = 1
@inbounds for i=1:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns"))
end
for j=1:nc
a[i,j] = xs[k]
k += 1
end
end
a
end

function hvcat_fill(a, xs)
k = 1
nr, nc = size(a,1), size(a,2)
for i=1:nr
@inbounds for j=1:nc
a[i,j] = xs[k]
k += 1
end
end
a
end

function typed_hvcat(T::Type, rows::(Int...), xs::Number...)
nr = length(rows)
nc = rows[1]
for i = 2:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns"))
end
end
len = length(xs)
if nr*nc != len
throw(ArgumentError("argument count $(len) does not match specified shape $((nr,nc))"))
end
hvcat_fill(Array(T, nr, nc), xs)
end

function hvcat(rows::(Int...), xs::Number...)
T = promote_typeof(xs...)
typed_hvcat(T, rows, xs...)
end

# fallback definition of hvcat in terms of hcat and vcat
function hvcat(rows::(Int...), as...)
nbr = length(rows) # number of block rows
rs = cell(nbr)
a = 1
for i = 1:nbr
rs[i] = hcat(as[a:a-1+rows[i]]...)
a += rows[i]
end
vcat(rs...)
end

function typed_hvcat(T::Type, rows::(Int...), as...)
nbr = length(rows) # number of block rows
rs = cell(nbr)
a = 1
for i = 1:nbr
rs[i] = hcat(as[a:a-1+rows[i]]...)
a += rows[i]
end
T[rs...;]
end

## Reductions and scans ##

function isequal(A::AbstractArray, B::AbstractArray)
Expand Down Expand Up @@ -1496,4 +1199,3 @@ function randsubseq!(S::AbstractArray, A::AbstractArray, p::Real)
end

randsubseq{T}(A::AbstractArray{T}, p::Real) = randsubseq!(T[], A, p)

34 changes: 4 additions & 30 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,36 +1045,10 @@ function reverse!(v::StridedVector, s=1, n=length(v))
v
end

function vcat{T}(arrays::Vector{T}...)
n = 0
for a in arrays
n += length(a)
end
arr = Array(T, n)
ptr = pointer(arr)
offset = 0
if isbits(T)
elsz = sizeof(T)
else
elsz = div(WORD_SIZE,8)
end
for a in arrays
nba = length(a)*elsz
ccall(:memcpy, Ptr{Void}, (Ptr{Void}, Ptr{Void}, UInt),
ptr+offset, a, nba)
offset += nba
end
return arr
end

function hcat{T}(V::Vector{T}...)
height = length(V[1])
for j = 2:length(V)
if length(V[j]) != height
throw(DimensionMismatch("vectors must have same lengths"))
end
end
[ V[j][i]::T for i=1:length(V[1]), j=1:length(V) ]
vcat_fill!{T}(C::Vector{T}, catrange, x::Vector) = copy!(C,first(catrange), x,1,length(x))
hcat_fill!{T}(C::Matrix{T}, catrange, x::VecOrMat) = begin
size(C,1)==size(x,1) || throw(ArgumentError("number of rows must match"))
copy!(C,size(C,1)*(first(catrange)-1)+1,x,1,length(x))
end

## find ##
Expand Down
Loading

0 comments on commit dbbccfa

Please sign in to comment.