Skip to content

Commit

Permalink
RFC: Change lowering of destructuring to avoid const prop dependence
Browse files Browse the repository at this point in the history
I'm currently doing some work with inference passes that have const
prop (temporarily) disabled and I noticed we actually rely on it
quite a bit for basic things. That's not terrible - const prop works
pretty well after all, but it still imposes a cost and while I want
to support it in my AD use case also, it makes destructuring quite
expensive, because everything needs to be inferred twice. This PR
is an experiment in changing the lowering to avoid having to const
prop the index. Rather than lowering `(a,b,c) = foo()` as:

```
it = foo()
a, s = indexed_iterate(it, 1)
b, s = indexed_iterate(it, 2)
c, s = indexed_iterate(it, 3)
```

we lower as:

```
it = foo()
iterate, index = index_and_itereate(it)
x = iterate(it)
a = index(x, 1)
y = iterate(it, y)
b = index(y, 2)
z = iterate(it, z)
c = index(z, 3)
```

For tuples `iterate` would simply return the first argument and
`index` would be `getfield`. That way, there is no const prop,
since `getfield` is called directly and inference can directly
use its tfunc. For the fallback case `iterate` is basically
just `Base.iterate`, with just a slight tweak to give an intelligent
error for short iterables.

On simple functions, there isn't much of a difference in execution
time, but benchmarking something more complicated like:
```
function g()
  a, = getfield(((1,),(2.0,3),("x",),(:x,)), Base.inferencebarrier(1))
  nothing
end
```

shows about a 20% improvement in end-to-end inference/optimize time,
which is substantial.
  • Loading branch information
Keno committed Aug 30, 2020
1 parent 6de97d5 commit d28d0fb
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 31 deletions.
1 change: 1 addition & 0 deletions base/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ convert(::Type{T}, x::T) where {T>:Union{Missing, Nothing}} = x
convert(::Type{T}, x) where {T>:Missing} = convert(nonmissingtype_checked(T), x)
convert(::Type{T}, x) where {T>:Union{Missing, Nothing}} = convert(nonmissingtype_checked(nonnothingtype_checked(T)), x)

index_and_iterate(::Missing) = throw(MethodError(iterate, (missing,)))

# Comparison operators
==(::Missing, ::Missing) = missing
Expand Down
2 changes: 1 addition & 1 deletion base/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ firstindex(t::NamedTuple) = 1
lastindex(t::NamedTuple) = nfields(t)
getindex(t::NamedTuple, i::Int) = getfield(t, i)
getindex(t::NamedTuple, i::Symbol) = getfield(t, i)
indexed_iterate(t::NamedTuple, i::Int, state=1) = (getfield(t, i), i+1)
isempty(::NamedTuple{()}) = true
isempty(::NamedTuple) = false
empty(::NamedTuple) = NamedTuple()
index_and_iterate(t::NamedTuple) = (arg1, getfield)

convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T<:Tuple} = nt
convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt
Expand Down
2 changes: 1 addition & 1 deletion base/pair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Pair, =>

eltype(p::Type{Pair{A, B}}) where {A, B} = Union{A, B}
iterate(p::Pair, i=1) = i > 2 ? nothing : (getfield(p, i), i + 1)
indexed_iterate(p::Pair, i::Int, state=1) = (getfield(p, i), i + 1)
index_and_iterate(p::Pair) = (arg1, getfield)

hash(p::Pair, h::UInt) = hash(p.second, hash(p.first, h))

Expand Down
46 changes: 33 additions & 13 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,41 @@ function _maxlength(t::Tuple, t2::Tuple, t3::Tuple...)
max(length(t), _maxlength(t2, t3...))
end

# this allows partial evaluation of bounded sequences of next() calls on tuples,
# while reducing to plain next() for arbitrary iterables.
indexed_iterate(t::Tuple, i::Int, state=1) = (@_inline_meta; (getfield(t, i), i+1))
indexed_iterate(a::Array, i::Int, state=1) = (@_inline_meta; (a[i], i+1))
function indexed_iterate(I, i)
x = iterate(I)
x === nothing && throw(BoundsError(I, i))
x
end
function indexed_iterate(I, i, state)
x = iterate(I, state)
x === nothing && throw(BoundsError(I, i))
x
# this allows partial evaluation of bounded sequences of iterate() calls on tuples,
# while reducing to plain iterate() for arbitrary iterables.

arg1(a) = a
arg1(a, b) = a
index_and_iterate(t::Tuple) = (arg1, getfield)
index_and_iterate(t::Array) = (arg1, getindex)

struct BadSlurp
a
end

function slurp_iterate(a)
@_inline_meta
s = iterate(a)
s === nothing && return BadSlurp(a)
s
end

function slurp_iterate(a, b)
@_inline_meta
s = iterate(a, getfield(b, 2))
s === nothing && return BadSlurp(a)
s
end

select_first(a::BadSlurp, i) = throw(BoundsError(a.a, i))
select_first(a, i) = getfield(a, 1)

index_and_iterate(x) = (slurp_iterate, select_first)

# Nothing is often union'ed into other things. Kill that as quickly as possible
# to make inference's life easier.
index_and_iterate(::Nothing) = throw(MethodError(iterate, (nothing,)))

# Use dispatch to avoid a branch in first
first(::Tuple{}) = throw(ArgumentError("tuple must be non-empty"))
first(t::Tuple) = t[1]
Expand Down
2 changes: 1 addition & 1 deletion src/common_symbols1.inc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jl_symbol("*"),
jl_symbol("bitcast"),
jl_symbol("slt_int"),
jl_symbol("isempty"),
jl_symbol("indexed_iterate"),
jl_symbol("index_and_iterate"),
jl_symbol("size"),
jl_symbol("!"),
jl_symbol("nothing"),
Expand Down
23 changes: 16 additions & 7 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2049,17 +2049,26 @@
x (make-ssavalue)))
(ini (if (eq? x xx) '() (list (sink-assignment xx (expand-forms x)))))
(n (length lhss))
(funcs (make-ssavalue))
(iterate (make-ssavalue))
(index (make-ssavalue))
(st (gensy)))
`(block
,@ini
,(lower-tuple-assignment
(list iterate index)
`(call (top index_and_iterate) ,xx))
,.(map (lambda (i lhs)
(expand-forms
(lower-tuple-assignment
(if (= i (- n 1))
(list lhs)
(list lhs st))
`(call (top indexed_iterate)
,xx ,(+ i 1) ,.(if (eq? i 0) '() `(,st))))))
(expand-forms
`(block
(= ,st (call ,iterate
,xx ,.(if (eq? i 0) '() `(,st))))
,(if (eventually-call? lhs)
(let ((val (gensy)))
`(block
(= ,val (call ,index ,st ,(+ i 1)))
(= ,lhs ,val)))
`(= ,lhs (call ,index ,st ,(+ i 1)))))))
(iota n)
lhss)
(unnecessary ,xx))))))
Expand Down
4 changes: 2 additions & 2 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const TAGS = Any[
:(=), :(==), :(===), :gotoifnot, :A, :B, :C, :M, :N, :T, :S, :X, :Y, :a, :b, :c, :d, :e, :f,
:g, :h, :i, :j, :k, :l, :m, :n, :o, :p, :q, :r, :s, :t, :u, :v, :w, :x, :y, :z, :add_int,
:sub_int, :mul_int, :add_float, :sub_float, :new, :mul_float, :bitcast, :start, :done, :next,
:indexed_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc,
:index_and_iterate, :getfield, :meta, :eq_int, :slt_int, :sle_int, :ne_int, :push_loc, :pop_loc,
:pop, :arrayset, :arrayref, :apply_type, :inbounds, :getindex, :setindex!, :Core, :!, :+,
:Base, :static_parameter, :convert, :colon, Symbol("#self#"), Symbol("#temp#"), :tuple, Symbol(""),

Expand All @@ -78,7 +78,7 @@ const TAGS = Any[

@assert length(TAGS) == 255

const ser_version = 11 # do not make changes without bumping the version #!
const ser_version = 12 # do not make changes without bumping the version #!

const NTAGS = length(TAGS)

Expand Down
5 changes: 1 addition & 4 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,7 @@ end

function g19348(x)
a, b = x
g = 1
g = 2
c = Base.indexed_iterate(x, g, g)
return a + b + c[1]
return a + b
end

for (codetype, all_ssa) in Any[
Expand Down
8 changes: 8 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,14 @@ let (f(), x) = (1, 2)
@test x == 2
end

foo23091_cnt = 0
struct Foo23091; end
Base.iterate(::Foo23091, state...) = (global foo23091_cnt += 1; (1, nothing))
(g23091(), h23091()) = Foo23091()
@test foo23091_cnt == 2
g23091(); h23091()
@test foo23091_cnt == 2

# issue #21900
f21900_cnt = 0
function f21900()
Expand Down
2 changes: 0 additions & 2 deletions test/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ using Random
@test iterate(p, iterate(p, iterate(p)[2])[2]) == nothing
@test firstindex(p) == 1
@test lastindex(p) == length(p) == 2
@test Base.indexed_iterate(p, 1, nothing) == (10,2)
@test Base.indexed_iterate(p, 2, nothing) == (20,3)
@test (1=>2) < (2=>3)
@test (2=>2) < (2=>3)
@test !((2=>3) < (2=>3))
Expand Down

0 comments on commit d28d0fb

Please sign in to comment.