From d28d0fbb6e644c0d37ed2ffdee7e2a51241e23ce Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 28 Aug 2020 17:10:04 -0400 Subject: [PATCH] RFC: Change lowering of destructuring to avoid const prop dependence I'm currently doing some work with inference passes that have const prop (temporarily) disabled and I noticed we actually rely on it quite a bit for basic things. That's not terrible - const prop works pretty well after all, but it still imposes a cost and while I want to support it in my AD use case also, it makes destructuring quite expensive, because everything needs to be inferred twice. This PR is an experiment in changing the lowering to avoid having to const prop the index. Rather than lowering `(a,b,c) = foo()` as: ``` it = foo() a, s = indexed_iterate(it, 1) b, s = indexed_iterate(it, 2) c, s = indexed_iterate(it, 3) ``` we lower as: ``` it = foo() iterate, index = index_and_itereate(it) x = iterate(it) a = index(x, 1) y = iterate(it, y) b = index(y, 2) z = iterate(it, z) c = index(z, 3) ``` For tuples `iterate` would simply return the first argument and `index` would be `getfield`. That way, there is no const prop, since `getfield` is called directly and inference can directly use its tfunc. For the fallback case `iterate` is basically just `Base.iterate`, with just a slight tweak to give an intelligent error for short iterables. On simple functions, there isn't much of a difference in execution time, but benchmarking something more complicated like: ``` function g() a, = getfield(((1,),(2.0,3),("x",),(:x,)), Base.inferencebarrier(1)) nothing end ``` shows about a 20% improvement in end-to-end inference/optimize time, which is substantial. --- base/missing.jl | 1 + base/namedtuple.jl | 2 +- base/pair.jl | 2 +- base/tuple.jl | 46 ++++++++++++++++------- src/common_symbols1.inc | 2 +- src/julia-syntax.scm | 23 ++++++++---- stdlib/Serialization/src/Serialization.jl | 4 +- test/compiler/inference.jl | 5 +-- test/core.jl | 8 ++++ test/dict.jl | 2 - 10 files changed, 64 insertions(+), 31 deletions(-) diff --git a/base/missing.jl b/base/missing.jl index 1d42188a656c01..976e837a548877 100644 --- a/base/missing.jl +++ b/base/missing.jl @@ -69,6 +69,7 @@ convert(::Type{T}, x::T) where {T>:Union{Missing, Nothing}} = x convert(::Type{T}, x) where {T>:Missing} = convert(nonmissingtype_checked(T), x) convert(::Type{T}, x) where {T>:Union{Missing, Nothing}} = convert(nonmissingtype_checked(nonnothingtype_checked(T)), x) +index_and_iterate(::Missing) = throw(MethodError(iterate, (missing,))) # Comparison operators ==(::Missing, ::Missing) = missing diff --git a/base/namedtuple.jl b/base/namedtuple.jl index 669d3b521153f0..023239501040de 100644 --- a/base/namedtuple.jl +++ b/base/namedtuple.jl @@ -112,10 +112,10 @@ firstindex(t::NamedTuple) = 1 lastindex(t::NamedTuple) = nfields(t) getindex(t::NamedTuple, i::Int) = getfield(t, i) getindex(t::NamedTuple, i::Symbol) = getfield(t, i) -indexed_iterate(t::NamedTuple, i::Int, state=1) = (getfield(t, i), i+1) isempty(::NamedTuple{()}) = true isempty(::NamedTuple) = false empty(::NamedTuple) = NamedTuple() +index_and_iterate(t::NamedTuple) = (arg1, getfield) convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T<:Tuple} = nt convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt diff --git a/base/pair.jl b/base/pair.jl index 3ce88177787cc2..85013af63ece34 100644 --- a/base/pair.jl +++ b/base/pair.jl @@ -47,7 +47,7 @@ Pair, => eltype(p::Type{Pair{A, B}}) where {A, B} = Union{A, B} iterate(p::Pair, i=1) = i > 2 ? nothing : (getfield(p, i), i + 1) -indexed_iterate(p::Pair, i::Int, state=1) = (getfield(p, i), i + 1) +index_and_iterate(p::Pair) = (arg1, getfield) hash(p::Pair, h::UInt) = hash(p.second, hash(p.first, h)) diff --git a/base/tuple.jl b/base/tuple.jl index f51507b4700c87..6533cd970e9de1 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -81,21 +81,41 @@ function _maxlength(t::Tuple, t2::Tuple, t3::Tuple...) max(length(t), _maxlength(t2, t3...)) end -# this allows partial evaluation of bounded sequences of next() calls on tuples, -# while reducing to plain next() for arbitrary iterables. -indexed_iterate(t::Tuple, i::Int, state=1) = (@_inline_meta; (getfield(t, i), i+1)) -indexed_iterate(a::Array, i::Int, state=1) = (@_inline_meta; (a[i], i+1)) -function indexed_iterate(I, i) - x = iterate(I) - x === nothing && throw(BoundsError(I, i)) - x -end -function indexed_iterate(I, i, state) - x = iterate(I, state) - x === nothing && throw(BoundsError(I, i)) - x +# this allows partial evaluation of bounded sequences of iterate() calls on tuples, +# while reducing to plain iterate() for arbitrary iterables. + +arg1(a) = a +arg1(a, b) = a +index_and_iterate(t::Tuple) = (arg1, getfield) +index_and_iterate(t::Array) = (arg1, getindex) + +struct BadSlurp + a +end + +function slurp_iterate(a) + @_inline_meta + s = iterate(a) + s === nothing && return BadSlurp(a) + s end +function slurp_iterate(a, b) + @_inline_meta + s = iterate(a, getfield(b, 2)) + s === nothing && return BadSlurp(a) + s +end + +select_first(a::BadSlurp, i) = throw(BoundsError(a.a, i)) +select_first(a, i) = getfield(a, 1) + +index_and_iterate(x) = (slurp_iterate, select_first) + +# Nothing is often union'ed into other things. Kill that as quickly as possible +# to make inference's life easier. +index_and_iterate(::Nothing) = throw(MethodError(iterate, (nothing,))) + # Use dispatch to avoid a branch in first first(::Tuple{}) = throw(ArgumentError("tuple must be non-empty")) first(t::Tuple) = t[1] diff --git a/src/common_symbols1.inc b/src/common_symbols1.inc index d035ab76aa6ad8..b9ff427cb56e76 100644 --- a/src/common_symbols1.inc +++ b/src/common_symbols1.inc @@ -33,7 +33,7 @@ jl_symbol("*"), jl_symbol("bitcast"), jl_symbol("slt_int"), jl_symbol("isempty"), -jl_symbol("indexed_iterate"), +jl_symbol("index_and_iterate"), jl_symbol("size"), jl_symbol("!"), jl_symbol("nothing"), diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 0d73ef4665a38c..fa84bf8b539139 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2049,17 +2049,26 @@ x (make-ssavalue))) (ini (if (eq? x xx) '() (list (sink-assignment xx (expand-forms x))))) (n (length lhss)) + (funcs (make-ssavalue)) + (iterate (make-ssavalue)) + (index (make-ssavalue)) (st (gensy))) `(block ,@ini + ,(lower-tuple-assignment + (list iterate index) + `(call (top index_and_iterate) ,xx)) ,.(map (lambda (i lhs) - (expand-forms - (lower-tuple-assignment - (if (= i (- n 1)) - (list lhs) - (list lhs st)) - `(call (top indexed_iterate) - ,xx ,(+ i 1) ,.(if (eq? i 0) '() `(,st)))))) + (expand-forms + `(block + (= ,st (call ,iterate + ,xx ,.(if (eq? i 0) '() `(,st)))) + ,(if (eventually-call? lhs) + (let ((val (gensy))) + `(block + (= ,val (call ,index ,st ,(+ i 1))) + (= ,lhs ,val))) + `(= ,lhs (call ,index ,st ,(+ i 1))))))) (iota n) lhss) (unnecessary ,xx)))))) diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 5fd532373dc8e2..ee7e08afec12a7 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -66,7 +66,7 @@ const TAGS = Any[ :(=), :(==), :(===), :gotoifnot, :A, :B, :C, :M, :N, :T, :S, :X, :Y, :a, :b, :c, :d, :e, :f, :g, :h, :i, :j, :k, :l, :m, :n, :o, :p, :q, :r, :s, :t, :u, :v, :w, :x, :y, :z, :add_int, :sub_int, :mul_int, :add_float, :sub_float, :new, :mul_float, :bitcast, :start, :done, :next, - :indexed_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc, + :index_and_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc, :pop, :arrayset, :arrayref, :apply_type, :inbounds, :getindex, :setindex!, :Core, :!, :+, :Base, :static_parameter, :convert, :colon, Symbol("#self#"), Symbol("#temp#"), :tuple, Symbol(""), @@ -78,7 +78,7 @@ const TAGS = Any[ @assert length(TAGS) == 255 -const ser_version = 11 # do not make changes without bumping the version #! +const ser_version = 12 # do not make changes without bumping the version #! const NTAGS = length(TAGS) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d79898985317e1..46046d65795773 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -562,10 +562,7 @@ end function g19348(x) a, b = x - g = 1 - g = 2 - c = Base.indexed_iterate(x, g, g) - return a + b + c[1] + return a + b end for (codetype, all_ssa) in Any[ diff --git a/test/core.jl b/test/core.jl index 19f50d9671f959..0b76432d73eaf4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -557,6 +557,14 @@ let (f(), x) = (1, 2) @test x == 2 end +foo23091_cnt = 0 +struct Foo23091; end +Base.iterate(::Foo23091, state...) = (global foo23091_cnt += 1; (1, nothing)) +(g23091(), h23091()) = Foo23091() +@test foo23091_cnt == 2 +g23091(); h23091() +@test foo23091_cnt == 2 + # issue #21900 f21900_cnt = 0 function f21900() diff --git a/test/dict.jl b/test/dict.jl index de455576b2bc4f..770b889e43cd16 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -11,8 +11,6 @@ using Random @test iterate(p, iterate(p, iterate(p)[2])[2]) == nothing @test firstindex(p) == 1 @test lastindex(p) == length(p) == 2 - @test Base.indexed_iterate(p, 1, nothing) == (10,2) - @test Base.indexed_iterate(p, 2, nothing) == (20,3) @test (1=>2) < (2=>3) @test (2=>2) < (2=>3) @test !((2=>3) < (2=>3))