Skip to content

Commit

Permalink
Some getfield related fixes in inference
Browse files Browse the repository at this point in the history
* Use the field index passed in in `lift_leaves`

    The caller has already done all the computation including bound checking.
    The `field` computed in this function is also affecting all the following iterations
    which is almost certainly wrong.

* Remove unnecessary type check on `field` in `lift_leaves` since it is always `Int`

* Move a branch disabling `return nothing` higher up

* Remove some duplicated calculation on field index in `getfield_elim_pass!`

* Fix `try_compute_fieldidx` to return `nothing` for non-`Int` `Integer` field index.

    This can cause `getfield_nothrow` to return incorrect result.
    It also gives the caller worse type info about the return value.

* Teach `getfield_nothrow` that `isbits` field cannot be undefined and getfield on such field cannot throw.

    This is already handled in `isdefined_tfunc`.

* Fix a few wrong use of `isbits` in dead branches

----

Ref #26948 (fa02d34)
Ref #27126 (9100329)
  • Loading branch information
yuyichao committed Sep 6, 2020
1 parent 483b637 commit 8a6dae4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
18 changes: 8 additions & 10 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
else
def = compact[leaf]
end
if is_tuple_call(compact, def) && isa(field, Int) && 1 <= field < length(def.args)
if is_tuple_call(compact, def) && 1 <= field < length(def.args)
lifted = def.args[1+field]
if is_old(compact, leaf) && isa(lifted, SSAValue)
lifted = OldSSAValue(lifted.id)
Expand All @@ -307,8 +307,6 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
end
(isa(typ, DataType) && (!typ.abstract)) || return nothing
@assert !typ.mutable
field = try_compute_fieldidx_expr(typ, stmt)
field === nothing && return nothing
if length(def.args) < 1 + field
ftyp = fieldtype(typ, field)
if !isbitstype(ftyp)
Expand All @@ -323,7 +321,7 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
compact[leaf] = nothing
for i = (length(def.args) + 1):(1+field)
ftyp = fieldtype(typ, i - 1)
isbits(ftyp) || return nothing
isbitstype(ftyp) || return nothing
push!(def.args, insert_node!(compact, leaf, result_t, Expr(:new, ftyp)))
end
compact[leaf] = def
Expand All @@ -342,22 +340,22 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
else
typ = compact_exprtype(compact, leaf)
if !isa(typ, Const)
# Disabled since #27126
return nothing
# If the leaf is an old ssa value, insert a getfield here
# We will revisit this getfield later when compaction gets
# to the appropriate point.
# N.B.: This can be a bit dangerous because it can lead to
# infinite loops if we accidentally insert a node just ahead
# of where we are
if is_old(compact, leaf) && (isa(field, Int) || isa(field, Symbol))
if is_old(compact, leaf)
(isa(typ, DataType) && (!typ.abstract)) || return nothing
@assert !typ.mutable
# If there's the potential for an undefref error on access, we cannot insert a getfield
if field > typ.ninitialized && !isbits(fieldtype(typ, field))
return nothing
if field > typ.ninitialized && !isbitstype(fieldtype(typ, field))
lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, make_MaybeUndef(result_t), Expr(:call, :unchecked_getfield, SSAValue(leaf.id), field), true))
maybe_undef = true
else
return nothing
lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, result_t, Expr(:call, getfield, SSAValue(leaf.id), field), true))
end
continue
Expand Down Expand Up @@ -671,7 +669,7 @@ function getfield_elim_pass!(ir::IRCode)

isempty(leaves) && continue

field = try_compute_fieldidx_expr(struct_typ, stmt)
field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, stmt, result_t, field, leaves)
Expand Down Expand Up @@ -806,7 +804,7 @@ function getfield_elim_pass!(ir::IRCode)
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
end
if !isbitstype(fieldtype(typ, fidx))
if !isbitstype(ftyp)
for (use, list) in preserve_uses
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
end
Expand Down
8 changes: 6 additions & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,8 @@ function try_compute_fieldidx(typ::DataType, @nospecialize(field))
if isa(field, Symbol)
field = fieldindex(typ, field, false)
field == 0 && return nothing
elseif isa(field, Integer)
elseif isa(field, Int)
# Numerical field name can only be of type `Int`
max_fields = fieldcount_noerror(typ)
max_fields === nothing && return nothing
(1 <= field <= max_fields) || return nothing
Expand Down Expand Up @@ -706,7 +707,8 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), @nospecialize
return false
end

s = unwrap_unionall(widenconst(s00))
s0 = widenconst(s00)
s = unwrap_unionall(s0)
if isa(s, Union)
return getfield_nothrow(rewrap(s.a, s00), name, inbounds) &&
getfield_nothrow(rewrap(s.b, s00), name, inbounds)
Expand All @@ -723,6 +725,8 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), @nospecialize
field = try_compute_fieldidx(s, name.val)
field === nothing && return false
field <= s.ninitialized && return true
# `try_compute_fieldidx` already check for field index bound.
!isvatuple(s) && isbitstype(fieldtype(s0, field)) && return true
end

return false
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,16 @@ let K = rand(2,2)
@test test_29253(K) == 2
end

function no_op_refint(r)
r[]
return
end
let code = code_typed(no_op_refint,Tuple{Base.RefValue{Int}})[1].first.code
@test length(code) == 1
@test isa(code[1], Core.ReturnNode)
@test code[1].val === nothing
end

# check getfield elim handling of GlobalRef
const _some_coeffs = (1,[2],3,4)
splat_from_globalref(x) = (x, _some_coeffs...,)
Expand Down

0 comments on commit 8a6dae4

Please sign in to comment.