From e387e9c65c4ea5837c8c41bdd698a3ffff1f1c2f Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 19 Jan 2021 13:19:56 -0600 Subject: [PATCH] Improve inferability of shape::Dims for cat (#39294) `cat` is often called with Varargs or heterogenous inputs, and inference almost always fails. Even when all the arrays are of the same type, if the number of varargs isn't known inference typically fails. The culprit is probably #36454. This reduces the number of failures considerably, by avoiding creation of vararg length tuples in the shape-inference pipeline. --- base/abstractarray.jl | 8 +++++++- test/abstractarray.jl | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 30072363a34c3..1f1120740e99a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d) cat_similar(A, ::Type{T}, shape) where T = Array{T}(undef, shape) cat_similar(A::AbstractArray, ::Type{T}, shape) where T = similar(A, T, shape) +# These are for backwards compatibility (even though internal) cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape function cat_shape(dims, shapes::Tuple) out_shape = () @@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple) end return out_shape end +# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454) +cat_size_shape(dims) = ntuple(zero, Val(length(dims))) +@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...) +_cat_size_shape(dims, shape) = shape +@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...) _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = () _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape @@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) @inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) @inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) - shape = cat_shape(catdims, map(cat_size, X)) + shape = cat_size_shape(catdims, X...) A = cat_similar(X[1], T, shape) if count(!iszero, catdims)::Int > 1 fill!(A, zero(T)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index f00f1f80332bb..52af916acbdac 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -692,6 +692,12 @@ function test_cat(::Type{TestAbstractArray}) # 36041 @test_throws MethodError cat(["a"], ["b"], dims=[1, 2]) @test cat([1], [1], dims=[1, 2]) == I(2) + + # inferrability + As = [zeros(2, 2) for _ = 1:2] + @test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2) + cat3v(As) = cat(As...; dims=Val(3)) + @test @inferred(cat3v(As)) == zeros(2, 2, 2) end function test_ind2sub(::Type{TestAbstractArray})