From 3b2407adbaf1ea9e4aaf14399826862ffdf13db2 Mon Sep 17 00:00:00 2001
From: Tim Holy
Date: Tue, 19 Jan 2021 02:51:19 -0500
Subject: [PATCH] Use lispy tuples in cat (fixes #21673)

The `cat` pipeline has long had poor inferrability.
Together with #39292 and #39294, this should basically put an end to
that problem.

Together, at least in simple cases these make the performance of `cat`
essentially equivalent to the manual version. In other words,
the `test1` and `test2` of #21673 benchmark very similarly.
---
 base/abstractarray.jl | 37 ++++++++++++++++---------------------
 1 file changed, 16 insertions(+), 21 deletions(-)

diff --git a/base/abstractarray.jl b/base/abstractarray.jl
index 30072363a34c3..fd65b9990d642 100644
--- a/base/abstractarray.jl
+++ b/base/abstractarray.jl
@@ -1639,29 +1639,24 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
     return __cat(A, shape, catdims, X...)
 end
 
-function __cat(A, shape::NTuple{M}, catdims, X...) where M
-    N = M::Int
-    offsets = zeros(Int, N)
-    inds = Vector{UnitRange{Int}}(undef, N)
-    concat = copyto!(zeros(Bool, N), catdims)
-    for x in X
-        for i = 1:N
-            if concat[i]
-                inds[i] = offsets[i] .+ cat_indices(x, i)
-                offsets[i] += cat_size(x, i)
-            else
-                inds[i] = 1:shape[i]
-            end
-        end
-        I::NTuple{N, UnitRange{Int}} = (inds...,)
-        if x isa AbstractArray
-            A[I...] = x
-        else
-            fill!(view(A, I...), x)
-        end
+# Why isn't this called `__cat!`?
+__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
+
+function __cat_offset!(A, shape, catdims, offsets, x, X...)
+    inds = ntuple(length(offsets)) do i
+        (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
+    end
+    if x isa AbstractArray
+        A[inds...] = x
+    else
+        fill!(view(A, inds...), x)
+    end
+    newoffsets = ntuple(length(offsets)) do i
+        (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
     end
-    return A
+    return __cat_offset!(A, shape, catdims, newoffsets, X...)
 end
+__cat_offset!(A, shape, catdims, offsets) = return A
 
 """
     vcat(A...)