From bb0bff9a3358ad4eb2d8c09368397b5bd0f1d58e Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sun, 25 Apr 2021 20:11:53 +0900 Subject: [PATCH] improve inferrabilities by telling the compiler relational invariants Our compiler doesn't understand these relations automatically yet. --- base/multidimensional.jl | 28 +++++++++++++++++----------- base/tuple.jl | 12 ++++++++---- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 5ef46bb1da468..87309308d411d 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -125,9 +125,11 @@ module IteratorsMD # comparison @inline isless(I1::CartesianIndex{N}, I2::CartesianIndex{N}) where {N} = _isless(0, I1.I, I2.I) - @inline function _isless(ret, I1::NTuple{N,Int}, I2::NTuple{N,Int}) where N - newret = ifelse(ret==0, icmp(I1[N], I2[N]), ret) - _isless(newret, Base.front(I1), Base.front(I2)) + @inline function _isless(ret, I1::Tuple{Int,Vararg{Int,N}}, I2::Tuple{Int,Vararg{Int,N}}) where {N} + newret = ifelse(ret==0, icmp(last(I1), last(I2)), ret) + t1, t2 = Base.front(I1), Base.front(I2) + # avoid dynamic dispatch by telling the compiler relational invariants + return isa(t1, Tuple{}) ? _isless(newret, (), ()) : _isless(newret, t1, t2::Tuple{Int,Vararg{Int}}) end _isless(ret, ::Tuple{}, ::Tuple{}) = ifelse(ret==1, true, false) icmp(a, b) = ifelse(isless(a,b), 1, ifelse(a==b, 0, -1)) @@ -168,6 +170,7 @@ module IteratorsMD error("iteration is deliberately unsupported for CartesianIndex. Use `I` rather than `I...`, or use `Tuple(I)...`") # Iteration + const ORI = OrdinalRange{Int, Int} """ CartesianIndices(sz::Dims) -> R CartesianIndices((istart:[istep:]istop, jstart:[jstep:]jstop, ...)) -> R @@ -262,7 +265,7 @@ module IteratorsMD For cartesian to linear index conversion, see [`LinearIndices`](@ref). """ - struct CartesianIndices{N,R<:NTuple{N,OrdinalRange{Int, Int}}} <: AbstractArray{CartesianIndex{N},N} + struct CartesianIndices{N,R<:NTuple{N,ORI}} <: AbstractArray{CartesianIndex{N},N} indices::R end @@ -394,19 +397,21 @@ module IteratorsMD # `iterate` returns `Union{Nothing, Tuple}`, we explicitly pass a `valid` flag to eliminate # the type instability inside the core `__inc` logic, and this gives better runtime performance. __inc(::Tuple{}, ::Tuple{}) = false, () - @inline function __inc(state::Tuple{Int}, indices::Tuple{<:OrdinalRange}) + @inline function __inc(state::Tuple{Int}, indices::Tuple{ORI}) rng = indices[1] I = state[1] + step(rng) valid = __is_valid_range(I, rng) && state[1] != last(rng) return valid, (I, ) end - @inline function __inc(state, indices) + @inline function __inc(state::Tuple{Int,Int,Vararg{Int,N}}, indices::Tuple{ORI,ORI,Vararg{ORI,N}}) where {N} rng = indices[1] I = state[1] + step(rng) if __is_valid_range(I, rng) && state[1] != last(rng) return true, (I, tail(state)...) end - valid, I = __inc(tail(state), tail(indices)) + t1, t2 = tail(state), tail(indices) + # avoid dynamic dispatch by telling the compiler relational invariants + valid, I = isa(t1, Tuple{Int}) ? __inc(t1, t2::Tuple{ORI}) : __inc(t1, t2::Tuple{ORI,ORI,Vararg{ORI}}) return valid, (first(rng), I...) end @@ -505,20 +510,21 @@ module IteratorsMD # decrement post check to avoid integer overflow @inline __dec(::Tuple{}, ::Tuple{}) = false, () - @inline function __dec(state::Tuple{Int}, indices::Tuple{<:OrdinalRange}) + @inline function __dec(state::Tuple{Int}, indices::Tuple{ORI}) rng = indices[1] I = state[1] - step(rng) valid = __is_valid_range(I, rng) && state[1] != first(rng) return valid, (I,) end - - @inline function __dec(state, indices) + @inline function __dec(state::Tuple{Int,Int,Vararg{Int,N}}, indices::Tuple{ORI,ORI,Vararg{ORI,N}}) where {N} rng = indices[1] I = state[1] - step(rng) if __is_valid_range(I, rng) && state[1] != first(rng) return true, (I, tail(state)...) end - valid, I = __dec(tail(state), tail(indices)) + t1, t2 = tail(state), tail(indices) + # avoid dynamic dispatch by telling the compiler relational invariants + valid, I = isa(t1, Tuple{Int}) ? __dec(t1, t2::Tuple{ORI}) : __dec(t1, t2::Tuple{ORI,ORI,Vararg{ORI}}) return valid, (last(rng), I...) end diff --git a/base/tuple.jl b/base/tuple.jl index 6d0ab5157f4c7..ad08bf42f5e1a 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -357,10 +357,14 @@ filter(f, t::Any32) = Tuple(filter(f, collect(t))) ## comparison ## -isequal(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _isequal(t1, t2) -_isequal(t1::Tuple{}, t2::Tuple{}) = true -_isequal(t1::Tuple{Any}, t2::Tuple{Any}) = isequal(t1[1], t2[1]) -_isequal(t1::Tuple, t2::Tuple) = isequal(t1[1], t2[1]) && _isequal(tail(t1), tail(t2)) +isequal(t1::Tuple, t2::Tuple) = length(t1) == length(t2) && _isequal(t1, t2) +_isequal(::Tuple{}, ::Tuple{}) = true +function _isequal(t1::Tuple{Any,Vararg{Any,N}}, t2::Tuple{Any,Vararg{Any,N}}) where {N} + isequal(t1[1], t2[1]) || return false + t1, t2 = tail(t1), tail(t2) + # avoid dynamic dispatch by telling the compiler relational invariants + return isa(t1, Tuple{}) ? true : _isequal(t1, t2::Tuple{Any,Vararg{Any}}) +end function _isequal(t1::Any32, t2::Any32) for i = 1:length(t1) if !isequal(t1[i], t2[i])