Skip to content

Commit

Permalink
fix overflow in CartesianIndices iteration
Browse files Browse the repository at this point in the history
This allows LLVM to vectorize the 1D CartesianIndices case,
as well as fixing an overflow bug for:

```julia
CartesianIndices(((typemax(Int64)-2):typemax(Int64),))
```

Co-authored-by: Yingbo Ma <mayingbo5@gmail.com>
  • Loading branch information
vchuravy and YingboMa committed Feb 9, 2019
1 parent a0bb006 commit d7805f7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 17 deletions.
46 changes: 29 additions & 17 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ module IteratorsMD
# access to index tuple
Tuple(index::CartesianIndex) = index.I

# equality
import Base: ==
==(a::CartesianIndex{N}, b::CartesianIndex{N}) where N = a.I == b.I

# zeros and ones
zero(::CartesianIndex{N}) where {N} = zero(CartesianIndex{N})
zero(::Type{CartesianIndex{N}}) where {N} = CartesianIndex(ntuple(x -> 0, Val(N)))
Expand Down Expand Up @@ -142,11 +146,13 @@ module IteratorsMD
# nextind and prevind with CartesianIndex
function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
iter = CartesianIndices(axes(a))
return CartesianIndex(inc(i.I, first(iter).I, last(iter).I))
valid, I = inc(i.I, first(iter).I, last(iter).I)
return CartesianIndex(I)
end
function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
iter = CartesianIndices(axes(a))
return CartesianIndex(dec(i.I, last(iter).I, first(iter).I))
valid, I = dec(i.I, first(iter).I, last(iter).I)
return CartesianIndex(I)
end

# Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred,
Expand Down Expand Up @@ -334,20 +340,25 @@ module IteratorsMD
iterfirst, iterfirst
end
@inline function iterate(iter::CartesianIndices, state)
nextstate = CartesianIndex(inc(state.I, first(iter).I, last(iter).I))
nextstate.I[end] > last(iter.indices[end]) && return nothing
nextstate, nextstate
# If we increment before the condition check, we run the
# risk of an integer overflow.
valid, I = inc(state.I, first(iter).I, last(iter).I)
valid || return nothing
nextstate = CartesianIndex(I)
return nextstate, nextstate
end

# increment & carry
@inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
@inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,)
@inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = true, ()
@inline function inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
state[1] < stop[1], (state[1]+1,)
end
@inline function inc(state, start, stop)
if state[1] < stop[1]
return (state[1]+1,tail(state)...)
return true, (state[1]+1,tail(state)...)
end
newtail = inc(tail(state), tail(start), tail(stop))
(start[1], newtail...)
valid, newtail = inc(tail(state), tail(start), tail(stop))
valid, (start[1], newtail...)
end

# 0-d cartesian ranges are special-cased to iterate once and only once
Expand Down Expand Up @@ -414,20 +425,21 @@ module IteratorsMD
iterfirst, iterfirst
end
@inline function iterate(r::Reverse{<:CartesianIndices}, state)
nextstate = CartesianIndex(dec(state.I, last(r.itr).I, first(r.itr).I))
nextstate.I[end] < first(r.itr.indices[end]) && return nothing
valid, I = dec(state.I, last(r.itr).I, first(r.itr).I)
valid || return nothing
nextstate = CartesianIndex(I)
nextstate, nextstate
end

# decrement & carry
@inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
@inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]-1,)
@inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = true, ()
@inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = state[1] > stop[1], (state[1]-1,)
@inline function dec(state, start, stop)
if state[1] > stop[1]
return (state[1]-1,tail(state)...)
return true, (state[1]-1,tail(state)...)
end
newtail = dec(tail(state), tail(start), tail(stop))
(start[1], newtail...)
valid, newtail = dec(tail(state), tail(start), tail(stop))
valid, (start[1], newtail...)
end
# 0-d cartesian ranges are special-cased to iterate once and only once
iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true)
Expand Down
49 changes: 49 additions & 0 deletions test/cartesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,52 @@ ex = Base.Cartesian.exprresolve(:(if 5 > 4; :x; else :y; end))
# can't convert higher-dimensional indices to Int
@test_throws MethodError convert(Int, CartesianIndex(42, 1))
end

@testset "CartesianIndices overflow" begin
I = CartesianIndices((1:typemax(Int),))
i = last(I)
@test iterate(I, i) === nothing

I = CartesianIndices((1:(typemax(Int)-1),))
i = CartesianIndex(typemax(Int))
@test iterate(I, i) === nothing

I = CartesianIndices((1:typemax(Int), 1:typemax(Int)))
i = last(I)
@test iterate(I, i) === nothing

i = CartesianIndex(typemax(Int), 1)
@test iterate(I, i) === (CartesianIndex(1, 2), CartesianIndex(1,2))

# reverse cartesian indices
I = CartesianIndices((typemin(Int):(typemin(Int)+3),))
i = last(I)
@test iterate(I, i) === nothing
end

@testset "CartesianIndices iteration" begin
I = CartesianIndices((2:4, 0:1, 1:1, 3:5))
indices = Vector{eltype(I)}()
for i in I
push!(indices, i)
end
@test length(I) == length(indices)
@test vec(I) == indices

empty!(indices)
I = Iterators.reverse(I)
for i in I
push!(indices, i)
end
@test length(I) == length(indices)
@test vec(collect(I)) == indices

# test invalid state
I = CartesianIndices((2:4, 3:5))
@test iterate(I, CartesianIndex(typemax(Int), 3))[1] == CartesianIndex(2,4)
@test iterate(I, CartesianIndex(typemax(Int), 4))[1] == CartesianIndex(2,5)
@test iterate(I, CartesianIndex(typemax(Int), 5)) === nothing

@test iterate(I, CartesianIndex(3, typemax(Int)))[1] == CartesianIndex(4,typemax(Int))
@test iterate(I, CartesianIndex(4, typemax(Int))) === nothing
end

0 comments on commit d7805f7

Please sign in to comment.