Skip to content

Commit

Permalink
improve inferrabilities by telling the compiler relational invariants (
Browse files Browse the repository at this point in the history
…JuliaLang#40594)

Our compiler doesn't understand these relations automatically yet.
  • Loading branch information
aviatesk authored and antoine-levitt committed May 9, 2021
1 parent e107332 commit fe8936f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
30 changes: 18 additions & 12 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,11 @@ module IteratorsMD

# comparison
@inline isless(I1::CartesianIndex{N}, I2::CartesianIndex{N}) where {N} = _isless(0, I1.I, I2.I)
@inline function _isless(ret, I1::NTuple{N,Int}, I2::NTuple{N,Int}) where N
newret = ifelse(ret==0, icmp(I1[N], I2[N]), ret)
_isless(newret, Base.front(I1), Base.front(I2))
@inline function _isless(ret, I1::Tuple{Int,Vararg{Int,N}}, I2::Tuple{Int,Vararg{Int,N}}) where {N}
newret = ifelse(ret==0, icmp(last(I1), last(I2)), ret)
t1, t2 = Base.front(I1), Base.front(I2)
# avoid dynamic dispatch by telling the compiler relational invariants
return isa(t1, Tuple{}) ? _isless(newret, (), ()) : _isless(newret, t1, t2::Tuple{Int,Vararg{Int}})
end
_isless(ret, ::Tuple{}, ::Tuple{}) = ifelse(ret==1, true, false)
icmp(a, b) = ifelse(isless(a,b), 1, ifelse(a==b, 0, -1))
Expand Down Expand Up @@ -168,6 +170,7 @@ module IteratorsMD
error("iteration is deliberately unsupported for CartesianIndex. Use `I` rather than `I...`, or use `Tuple(I)...`")

# Iteration
const OrdinalRangeInt = OrdinalRange{Int, Int}
"""
CartesianIndices(sz::Dims) -> R
CartesianIndices((istart:[istep:]istop, jstart:[jstep:]jstop, ...)) -> R
Expand Down Expand Up @@ -262,13 +265,13 @@ module IteratorsMD
For cartesian to linear index conversion, see [`LinearIndices`](@ref).
"""
struct CartesianIndices{N,R<:NTuple{N,OrdinalRange{Int, Int}}} <: AbstractArray{CartesianIndex{N},N}
struct CartesianIndices{N,R<:NTuple{N,OrdinalRangeInt}} <: AbstractArray{CartesianIndex{N},N}
indices::R
end

CartesianIndices(::Tuple{}) = CartesianIndices{0,typeof(())}(())
function CartesianIndices(inds::NTuple{N,OrdinalRange{<:Integer, <:Integer}}) where {N}
indices = map(r->convert(OrdinalRange{Int, Int}, r), inds)
indices = map(r->convert(OrdinalRangeInt, r), inds)
CartesianIndices{N, typeof(indices)}(indices)
end

Expand Down Expand Up @@ -394,19 +397,21 @@ module IteratorsMD
# `iterate` returns `Union{Nothing, Tuple}`, we explicitly pass a `valid` flag to eliminate
# the type instability inside the core `__inc` logic, and this gives better runtime performance.
__inc(::Tuple{}, ::Tuple{}) = false, ()
@inline function __inc(state::Tuple{Int}, indices::Tuple{<:OrdinalRange})
@inline function __inc(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt})
rng = indices[1]
I = state[1] + step(rng)
valid = __is_valid_range(I, rng) && state[1] != last(rng)
return valid, (I, )
end
@inline function __inc(state, indices)
@inline function __inc(state::Tuple{Int,Int,Vararg{Int,N}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt,N}}) where {N}
rng = indices[1]
I = state[1] + step(rng)
if __is_valid_range(I, rng) && state[1] != last(rng)
return true, (I, tail(state)...)
end
valid, I = __inc(tail(state), tail(indices))
t1, t2 = tail(state), tail(indices)
# avoid dynamic dispatch by telling the compiler relational invariants
valid, I = isa(t1, Tuple{Int}) ? __inc(t1, t2::Tuple{OrdinalRangeInt}) : __inc(t1, t2::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}})
return valid, (first(rng), I...)
end

Expand Down Expand Up @@ -505,20 +510,21 @@ module IteratorsMD

# decrement post check to avoid integer overflow
@inline __dec(::Tuple{}, ::Tuple{}) = false, ()
@inline function __dec(state::Tuple{Int}, indices::Tuple{<:OrdinalRange})
@inline function __dec(state::Tuple{Int}, indices::Tuple{OrdinalRangeInt})
rng = indices[1]
I = state[1] - step(rng)
valid = __is_valid_range(I, rng) && state[1] != first(rng)
return valid, (I,)
end

@inline function __dec(state, indices)
@inline function __dec(state::Tuple{Int,Int,Vararg{Int,N}}, indices::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt,N}}) where {N}
rng = indices[1]
I = state[1] - step(rng)
if __is_valid_range(I, rng) && state[1] != first(rng)
return true, (I, tail(state)...)
end
valid, I = __dec(tail(state), tail(indices))
t1, t2 = tail(state), tail(indices)
# avoid dynamic dispatch by telling the compiler relational invariants
valid, I = isa(t1, Tuple{Int}) ? __dec(t1, t2::Tuple{OrdinalRangeInt}) : __dec(t1, t2::Tuple{OrdinalRangeInt,OrdinalRangeInt,Vararg{OrdinalRangeInt}})
return valid, (last(rng), I...)
end

Expand Down
12 changes: 8 additions & 4 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,14 @@ filter(f, t::Any32) = Tuple(filter(f, collect(t)))

## comparison ##

isequal(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _isequal(t1, t2)
_isequal(t1::Tuple{}, t2::Tuple{}) = true
_isequal(t1::Tuple{Any}, t2::Tuple{Any}) = isequal(t1[1], t2[1])
_isequal(t1::Tuple, t2::Tuple) = isequal(t1[1], t2[1]) && _isequal(tail(t1), tail(t2))
isequal(t1::Tuple, t2::Tuple) = length(t1) == length(t2) && _isequal(t1, t2)
_isequal(::Tuple{}, ::Tuple{}) = true
function _isequal(t1::Tuple{Any,Vararg{Any,N}}, t2::Tuple{Any,Vararg{Any,N}}) where {N}
isequal(t1[1], t2[1]) || return false
t1, t2 = tail(t1), tail(t2)
# avoid dynamic dispatch by telling the compiler relational invariants
return isa(t1, Tuple{}) ? true : _isequal(t1, t2::Tuple{Any,Vararg{Any}})
end
function _isequal(t1::Any32, t2::Any32)
for i = 1:length(t1)
if !isequal(t1[i], t2[i])
Expand Down

0 comments on commit fe8936f

Please sign in to comment.