From b0059267171ddbb250ecb8d5ff1d252d9386b74b Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 9 Feb 2022 17:13:03 +0800 Subject: [PATCH 1/2] Fix `stride(A, i)` for 0-dim inputs --- base/abstractarray.jl | 2 +- base/reinterpretarray.jl | 2 ++ test/abstractarray.jl | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 5c4188b0f1471..524b4daa53fc7 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -546,7 +546,7 @@ julia> stride(A,3) function stride(A::AbstractArray, k::Integer) st = strides(A) k ≤ ndims(A) && return st[k] - return sum(st .* size(A)) + return ndims(A) == 0 ? 1 : sum(st .* size(A)) end @inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...) diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index 08638d26249a6..02707f4b02780 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -149,6 +149,8 @@ StridedMatrix{T} = StridedArray{T,2} StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}} strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...) +stride(A::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, k::Integer) = + k ≤ ndims(A) ? strides(A)[k] : length(A) function strides(a::ReshapedReinterpretArray) ap = parent(a) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a33cf53698d1c..300728a4b5f33 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1580,6 +1580,19 @@ end end end +@testset "stride for 0 dims array #44087" begin + struct Fill44087 <: AbstractArray{Int,0} + a::Int + end + # `stride` shouldn't work if `strides` is not defined. + @test_throws MethodError stride(Fill44087(1), 1) + # It is intentionally to only check the return type. (The value is somehow arbitrary) + @test stride(fill(1), 1) isa Int + @test stride(reinterpret(Float64, fill(Int64(1))), 1) isa Int + @test stride(reinterpret(reshape, Float64, fill(Int64(1))), 1) isa Int + @test stride(Base.ReshapedArray(fill(1), (), ()), 1) isa Int +end + @testset "to_indices inference (issue #42001 #44059)" begin @test (@inferred to_indices([], ntuple(Returns(CartesianIndex(1)), 32))) == ntuple(Returns(1), 32) @test (@inferred to_indices([], ntuple(Returns(CartesianIndices(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32) From eb6b43ab36659abc88945ec465c9814f7c56181c Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sat, 12 Feb 2022 01:39:55 +0800 Subject: [PATCH 2/2] Replace `sum( .* )` with for loop. --- base/abstractarray.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 524b4daa53fc7..e42b5d8c204f7 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -546,7 +546,13 @@ julia> stride(A,3) function stride(A::AbstractArray, k::Integer) st = strides(A) k ≤ ndims(A) && return st[k] - return ndims(A) == 0 ? 1 : sum(st .* size(A)) + ndims(A) == 0 && return 1 + sz = size(A) + s = st[1] * sz[1] + for i in 2:ndims(A) + s += st[i] * sz[i] + end + return s end @inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)