Skip to content

Commit

Permalink
Use lispy tuples in cat (fixes #21673) (#39314)
Browse files Browse the repository at this point in the history
The `cat` pipeline has long had poor inferrability.
Together with #39292 and #39294, this should basically put an
end to that problem.

Together, at least in simple cases these make the performance
of `cat` essentially equivalent to the manual version.
In other words, the `test1` and `test2` of #21673 benchmark
very similarly.

(cherry picked from commit 78d55e2)
  • Loading branch information
timholy authored and staticfloat committed Dec 22, 2022
1 parent 83c6f8d commit 58bdaf5
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1645,28 +1645,29 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
return __cat(A, shape, catdims, X...)
end

function __cat(A, shape::NTuple{M}, catdims, X...) where M
N = M::Int
offsets = zeros(Int, N)
inds = Vector{UnitRange{Int}}(undef, N)
concat = copyto!(zeros(Bool, N), catdims)
for x in X
for i = 1:N
if concat[i]
inds[i] = offsets[i] .+ cat_indices(x, i)
offsets[i] += cat_size(x, i)
else
inds[i] = 1:shape[i]
end
end
I::NTuple{N, UnitRange{Int}} = (inds...,)
if x isa AbstractArray
A[I...] = x
else
fill!(view(A, I...), x)
end
# Why isn't this called `__cat!`?
__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)

function __cat_offset!(A, shape, catdims, offsets, x, X...)
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
return __cat_offset!(A, shape, catdims, newoffsets, X...)
end
__cat_offset!(A, shape, catdims, offsets) = A

function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
end
if x isa AbstractArray
A[inds...] = x
else
fill!(view(A, inds...), x)
end
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
end
return A
return newoffsets
end

"""
Expand Down

0 comments on commit 58bdaf5

Please sign in to comment.