Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reduce(cat(dims=4), A), with efficient method for simple cases #37196

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 115 additions & 6 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,34 @@ reduce(::typeof(vcat), A::AbstractVector{<:AbstractVecOrMat}) =
reduce(::typeof(hcat), A::AbstractVector{<:AbstractVecOrMat}) =
_typed_hcat(mapreduce(eltype, promote_type, A), A)

function _typed_cat(::Type{T}, A::AbstractArray{<:AbstractArray}, valg::Val=Val(0)) where {T}
ax1 = axes(first(A))
gap = ntuple(_->1, valg) # trivial dimensions
dense = true
for j in eachindex(A)
Aj = A[j]
if axes(Aj) != ax1
throw(ArgumentError("expected arrays of consistent size, got $(UnitRange.(axes(Aj))) for element $j, compared to $(UnitRange.(ax1)) for the first"))
end
dense &= isa(Aj, Array)
end
B = similar(first(A), T, ax1..., gap..., axes(A)...)
if dense
off = 1
for a in A
copyto!(B, off, a, 1, length(a))
off += length(a)
end
else
colons = map(_->Colon(), ax1)
for J in CartesianIndices(A)
ints = Tuple(J)
@inbounds B[colons..., gap..., ints...] = A[J]
end
end
return B
end

## cat: general case

# helper functions
Expand Down Expand Up @@ -1754,19 +1782,35 @@ typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2))
"""
cat(A...; dims=dims)

Concatenate the input arrays along the specified dimensions in the iterable `dims`. For
dimensions not in `dims`, all input arrays should have the same size, which will also be the
Concatenate the input arrays along the specified dimensions in the iterable `dims`.

For dimensions not in `dims`, all input arrays should have the same size, which will also be the
size of the output array along that dimension. For dimensions in `dims`, the size of the
output array is the sum of the sizes of the input arrays along that dimension. If `dims` is
a single number, the different arrays are tightly stacked along that dimension. If `dims` is
an iterable containing several dimensions, this allows one to construct block diagonal
matrices and their higher-dimensional analogues by simultaneously increasing several
dimensions for every new input array and putting zero blocks elsewhere. For example,
`cat(matrices...; dims=(1,2))` builds a block diagonal matrix, i.e. a block matrix with
`matrices[1]`, `matrices[2]`, ... as diagonal blocks and matching zero blocks away from the
diagonal.
dimensions for every new input array and putting zero blocks elsewhere.

# Examples
```jldoctest
julia> cat(ones(2,2), fill(pi,2), zeros(2,3,1); dims=2)
2×6×1 Array{Float64, 3}:
[:, :, 1] =
1.0 1.0 3.14159 0.0 0.0 0.0
1.0 1.0 3.14159 0.0 0.0 0.0

julia> cat([true], trues(2,2), trues(2,4); dims=(1,2))
5×7 Matrix{Bool}:
1 0 0 0 0 0 0
0 1 1 0 0 0 0
0 1 1 0 0 0 0
0 0 0 1 1 1 1
0 0 0 1 1 1 1
```
"""
@inline cat(A...; dims) = _cat(dims, A...)

_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims)

# The specializations for 1 and 2 inputs are important
Expand All @@ -1785,6 +1829,71 @@ typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2))
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2))
typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2))

"""
cat(; dims)

Returns a `Base.Fix1` function, equivalent to `(A...,) -> cat(A...; dims)`.

For some dimensions, `reduce(cat(;dims), A)` is done by an efficient method:
* `dims=1` is `reduce(vcat, A)`.
* `dims=2` is equivalent to `reduce(hcat, A)`.
* `dims=N+1`, when `A` a vector of `N`-dimensional arrays, is `reduce(cat, A)`.

In general `reduce(cat, A)` acts on an `M`-array containing `N`-arrays of uniform size
to return one `N+M`-array. This has `size(out) == (size(first(A))..., size(A)...)`,
and elements `out[J, I] == A[I][J]` where `I in CartesianIndices(A)`
and `J in CartesianIndices(first(A))`.

!!! compat "Julia 1.6"
These methods require at least Julia 1.6.

# Examples
```jldoctest
julia> reduce(cat(dims=3), [ones(2,4), fill(√2,2,4), [2 4 8 16; 32 64 128 256]])
2×4×3 Array{Float64, 3}:
[:, :, 1] =
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0

[:, :, 2] =
1.41421 1.41421 1.41421 1.41421
1.41421 1.41421 1.41421 1.41421

[:, :, 3] =
2.0 4.0 8.0 16.0
32.0 64.0 128.0 256.0

julia> reduce(cat(dims=4), [ones(2,3) for _ in 1:5]) |> size
(2, 3, 1, 5)

julia> reduce(cat, [rand(3,5) for μ in 1:7, ν in 0:10]) |> size
(3, 5, 7, 11)

julia> mapreduce(float, cat(dims=[1,2]), [1, fill(2,2,2), [3 4 5]])
4×6 Matrix{Float64}:
1.0 0.0 0.0 0.0 0.0 0.0
0.0 2.0 2.0 0.0 0.0 0.0
0.0 2.0 2.0 0.0 0.0 0.0
0.0 0.0 0.0 3.0 4.0 5.0
```
"""
cat(; dims) = length(dims)==1 ? Fix1(_cat, Val(dims...)) : Fix1(_cat, Tuple(dims))

reduce(::Fix1{typeof(_cat), Val{1}}, A::AbstractVector{<:AbstractVecOrMat}) = reduce(vcat, A)
reduce(::Fix1{typeof(_cat), Val{2}}, A::AbstractVector{<:AbstractVecOrMat}) = reduce(hcat, A)
reduce(::Fix1{typeof(_cat), Val{2}}, A::AbstractVector{<:AbstractVector}) = reduce(cat, A)

function reduce(f::Fix1{typeof(_cat), Val{M}}, A::AbstractVector{<:AbstractArray{<:Any,N}}) where {M,N}
if M > N
_typed_cat(mapreduce(eltype, promote_type, A), A, Val(M-N-1))
else
mapreduce(identity, f, A)
end
end

reduce(::typeof(cat), A::AbstractArray{<:AbstractArray}) =
_typed_cat(mapreduce(eltype, promote_type, A), A)

# 2d horizontal and vertical concatenation

function hvcat(nbc::Integer, as...)
Expand Down
1 change: 1 addition & 0 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,7 @@ struct Fix1{F,T} <: Function
end

(f::Fix1)(y) = f.f(f.x, y)
(f::Fix1)(ys...) = f.f(f.x, ys...)

"""
Fix2(f, x)
Expand Down
4 changes: 3 additions & 1 deletion test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,11 @@ function test_cat(::Type{TestAbstractArray})
D = [1:24...]
i = rand(1:10)

@test cat(;dims=i) == Any[]
@test Base.typed_hcat(Float64) == Vector{Float64}()
@test Base.typed_vcat(Float64) == Vector{Float64}()

@test_skip cat(;dims=i) == Any[]

@test vcat() == Any[]
@test hcat() == Any[]
@test vcat(1, 1.0, 3, 3.0) == [1.0, 1.0, 3.0, 3.0]
Expand Down
4 changes: 4 additions & 0 deletions test/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ end
fy = Base.Fix2(/, y)
@test fx(y) == x / y
@test fy(x) == x / y

gx = Base.Fix1(*, x) # Vararg
@test gx() == x
@test gx(y, y) == x * y^2
end

@testset "curried comparisons" begin
Expand Down
12 changes: 12 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,18 @@ test18695(r) = sum( t^2 for t in r )
end
end

@testset "reduce(cat, A) for arrays" begin
for args in ([1:2], [[1, 2]], [1:2, 3:4], [[3, 4, 5], 1:3], [1:2, [3.5, 4.5]],
[[1 2; 3 4], [5 6; 7 8]])
X = reduce(cat, args)
Y = cat(args...; dims=ndims(args[1])+1)
@test X == Y
@test typeof(X) === typeof(Y)
end
@test_throws ArgumentError reduce(cat, [1:2, [1, 2], 1:3])
@test_throws MethodError reduce(cat, [[5 6; 7 8], [1, 2]])
end

# offset axes
i = Base.Slice(-3:3)
x = [j^2 for j in i]
Expand Down