From 7fe75d20267eff841e314a778deb5a6db879a6bc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 8 Feb 2019 10:56:47 -0500 Subject: [PATCH] fix overflow in CartesianIndices iteration This allows LLVM to vectorize the 1D CartesianIndices case, as well as fixing an overflow bug for: ```julia CartesianIndices(((typemax(Int64)-2):typemax(Int64),)) ``` Co-authored-by: Yingbo Ma --- base/multidimensional.jl | 59 +++++++++++++++++++++++++++------------- test/cartesian.jl | 49 +++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 19 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index f03e641af10fc..72078fc74d8c6 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -95,6 +95,9 @@ module IteratorsMD # access to index tuple Tuple(index::CartesianIndex) = index.I + # equality + Base.:(==)(a::CartesianIndex{N}, b::CartesianIndex{N}) where N = a.I == b.I + # zeros and ones zero(::CartesianIndex{N}) where {N} = zero(CartesianIndex{N}) zero(::Type{CartesianIndex{N}}) where {N} = CartesianIndex(ntuple(x -> 0, Val(N))) @@ -142,11 +145,13 @@ module IteratorsMD # nextind and prevind with CartesianIndex function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N} iter = CartesianIndices(axes(a)) - return CartesianIndex(inc(i.I, first(iter).I, last(iter).I)) + _, I = inc((), i.I, first(iter).I, last(iter).I) + return I end function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N} iter = CartesianIndices(axes(a)) - return CartesianIndex(dec(i.I, last(iter).I, first(iter).I)) + _, I = dec((), i.I, first(iter).I, last(iter).I) + return I end # Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred, @@ -334,20 +339,26 @@ module IteratorsMD iterfirst, iterfirst end @inline function iterate(iter::CartesianIndices, state) - nextstate = CartesianIndex(inc(state.I, first(iter).I, last(iter).I)) - nextstate.I[end] > last(iter.indices[end]) && return nothing - nextstate, nextstate + return inc((), state.I, first(iter).I, last(iter).I) end # increment & carry - @inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = () - @inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,) - @inline function inc(state, start, stop) + # increment post check to avoid integer overflow + @inline inc(out, ::Tuple{}, ::Tuple{}, ::Tuple{}) = nothing + @inline function inc(out, state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) + if state[1] < stop[1] + nextstate = CartesianIndex(out..., state[1]+1) + return nextstate, nextstate + end + return nothing + end + + @inline function inc(out, state, start, stop) if state[1] < stop[1] - return (state[1]+1,tail(state)...) + nextstate = CartesianIndex(out..., state[1]+1, tail(state)...) + return nextstate, nextstate end - newtail = inc(tail(state), tail(start), tail(stop)) - (start[1], newtail...) + return inc((out..., start[1]), tail(state), tail(start), tail(stop)) end # 0-d cartesian ranges are special-cased to iterate once and only once @@ -414,21 +425,31 @@ module IteratorsMD iterfirst, iterfirst end @inline function iterate(r::Reverse{<:CartesianIndices}, state) - nextstate = CartesianIndex(dec(state.I, last(r.itr).I, first(r.itr).I)) - nextstate.I[end] < first(r.itr.indices[end]) && return nothing + valid, I = dec(state.I, last(r.itr).I, first(r.itr).I) + valid || return nothing + nextstate = CartesianIndex(I) nextstate, nextstate end # decrement & carry - @inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = () - @inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]-1,) - @inline function dec(state, start, stop) + # increment post check to avoid integer overflow + @inline dec(out, ::Tuple{}, ::Tuple{}, ::Tuple{}) = nothing + @inline function dec(out, state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) if state[1] > stop[1] - return (state[1]-1,tail(state)...) + nextstate = CartesianIndex(out..., state[1]-1) + return nextstate, nextstate end - newtail = dec(tail(state), tail(start), tail(stop)) - (start[1], newtail...) + return nothing end + + @inline function dec(out, state, start, stop) + if state[1] > stop[1] + nextstate = CartesianIndex(out..., state[1]-1, tail(state)...) + return nextstate, nextstate + end + return dec((out..., start[1]), tail(state), tail(start), tail(stop)) + end + # 0-d cartesian ranges are special-cased to iterate once and only once iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true) diff --git a/test/cartesian.jl b/test/cartesian.jl index 40badf6bb24bb..7de79bc6a407b 100644 --- a/test/cartesian.jl +++ b/test/cartesian.jl @@ -15,3 +15,52 @@ ex = Base.Cartesian.exprresolve(:(if 5 > 4; :x; else :y; end)) # can't convert higher-dimensional indices to Int @test_throws MethodError convert(Int, CartesianIndex(42, 1)) end + +@testset "CartesianIndices overflow" begin + I = CartesianIndices((1:typemax(Int),)) + i = last(I) + @test iterate(I, i) === nothing + + I = CartesianIndices((1:(typemax(Int)-1),)) + i = CartesianIndex(typemax(Int)) + @test iterate(I, i) === nothing + + I = CartesianIndices((1:typemax(Int), 1:typemax(Int))) + i = last(I) + @test iterate(I, i) === nothing + + i = CartesianIndex(typemax(Int), 1) + @test iterate(I, i) === (CartesianIndex(1, 2), CartesianIndex(1,2)) + + # reverse cartesian indices + I = CartesianIndices((typemin(Int):(typemin(Int)+3),)) + i = last(I) + @test iterate(I, i) === nothing +end + +@testset "CartesianIndices iteration" begin + I = CartesianIndices((2:4, 0:1, 1:1, 3:5)) + indices = Vector{eltype(I)}() + for i in I + push!(indices, i) + end + @test length(I) == length(indices) + @test vec(I) == indices + + empty!(indices) + I = Iterators.reverse(I) + for i in I + push!(indices, i) + end + @test length(I) == length(indices) + @test vec(collect(I)) == indices + + # test invalid state + I = CartesianIndices((2:4, 3:5)) + @test iterate(I, CartesianIndex(typemax(Int), 3))[1] == CartesianIndex(2,4) + @test iterate(I, CartesianIndex(typemax(Int), 4))[1] == CartesianIndex(2,5) + @test iterate(I, CartesianIndex(typemax(Int), 5)) === nothing + + @test iterate(I, CartesianIndex(3, typemax(Int)))[1] == CartesianIndex(4,typemax(Int)) + @test iterate(I, CartesianIndex(4, typemax(Int))) === nothing +end