Skip to content

Commit

Permalink
Make :enter a proper node type (JuliaLang#52300)
Browse files Browse the repository at this point in the history
This is a prepratory commit in anticipation of giving :enter additional
responsibilities of entering and restoring dynamic scopes for
ScopedValue (c.f. JuliaLang#51352). This commit simply turns `:enter` into its
own node type (like the other terminators). The changes are largely
mechanical from the `Expr` version, but will make it easier to add
additional semantics in a follow up PR.
  • Loading branch information
Keno authored and mkitti committed Dec 9, 2023
1 parent 1a7a10d commit aa5e994
Show file tree
Hide file tree
Showing 24 changed files with 191 additions and 174 deletions.
8 changes: 6 additions & 2 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ eval(Core, quote
ReturnNode(@nospecialize val) = $(Expr(:new, :ReturnNode, :val))
ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable
GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))
EnterNode(dest::Int) = $(Expr(:new, :EnterNode, :dest))
LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))
function LineNumberNode(l::Int, @nospecialize(f))
isa(f, String) && (f = Symbol(f))
Expand Down Expand Up @@ -626,12 +627,12 @@ module IR
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, SlotNumber, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct, InterConditional
Const, PartialStruct, InterConditional, EnterNode

using Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, SlotNumber, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct, InterConditional
Const, PartialStruct, InterConditional, EnterNode

end # module IR

Expand Down Expand Up @@ -965,4 +966,7 @@ arraysize(a::Array) = a.size
arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i) : 1
export arrayref, arrayset, arraysize, const_arrayref

# For convenience
EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest)

ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)
11 changes: 6 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3102,8 +3102,8 @@ function update_exc_bestguess!(@nospecialize(exct), frame::InferenceState, 𝕃
handler_frame = frame.handlers[cur_hand]
if !(𝕃ₚ, exct, handler_frame.exct)
handler_frame.exct = tmerge(𝕃ₚ, handler_frame.exct, exct)
enter = frame.src.code[handler_frame.enter_idx]::Expr
exceptbb = block_for_inst(frame.cfg, enter.args[1]::Int)
enter = frame.src.code[handler_frame.enter_idx]::EnterNode
exceptbb = block_for_inst(frame.cfg, enter.catch_dest)
push!(frame.ip, exceptbb)
end
end
Expand All @@ -3114,8 +3114,8 @@ function propagate_to_error_handler!(currstate::VarTable, frame::InferenceState,
# exception handler, BEFORE applying any state changes.
cur_hand = frame.handler_at[frame.currpc][1]
if cur_hand != 0
enter = frame.src.code[frame.handlers[cur_hand].enter_idx]::Expr
exceptbb = block_for_inst(frame.cfg, enter.args[1]::Int)
enter = frame.src.code[frame.handlers[cur_hand].enter_idx]::EnterNode
exceptbb = block_for_inst(frame.cfg, enter.catch_dest)
if update_bbstate!(𝕃ᵢ, frame, exceptbb, currstate)
push!(frame.ip, exceptbb)
end
Expand Down Expand Up @@ -3256,8 +3256,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
ssavaluetypes[frame.currpc] = Any
@goto find_next_bb
elseif isexpr(stmt, :enter)
elseif isa(stmt, EnterNode)
ssavaluetypes[currpc] = Any
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
@goto fallthrough
elseif isexpr(stmt, :leave)
ssavaluetypes[currpc] = Any
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
if isa(stmt, EnterNode)
l = stmt.catch_dest
push!(handlers, TryCatchFrame(Bottom, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
Expand Down Expand Up @@ -396,15 +396,15 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, EnterNode)
l = stmt.catch_dest
# We assigned a handler number above. Here we just merge that
# with out current handler information.
handler_at[l] = (cur_stacks[1], handler_at[l][2])
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
l = stmt.args[1]::Int
# We assigned a handler number above. Here we just merge that
# with out current handler information.
handler_at[l] = (cur_stacks[1], handler_at[l][2])
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif head === :leave
if head === :leave
l = 0
for j = 1:length(stmt.args)
arg = stmt.args[j]
Expand All @@ -415,7 +415,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
if enter_stmt === nothing
continue
end
@assert isexpr(enter_stmt, :enter) "malformed :leave"
@assert isa(enter_stmt, EnterNode) "malformed :leave"
end
l += 1
end
Expand Down
37 changes: 19 additions & 18 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
isa(stmt, PiNode) && return (true, true, true)
isa(stmt, PhiNode) && return (true, true, true)
isa(stmt, ReturnNode) && return (true, false, true)
isa(stmt, EnterNode) && return (true, false, true)
isa(stmt, GotoNode) && return (true, false, true)
isa(stmt, GotoIfNot) && return (true, false, (𝕃ₒ, argextype(stmt.cond, src), Bool))
if isa(stmt, GlobalRef)
Expand Down Expand Up @@ -761,7 +762,7 @@ end
function ((; sv)::ScanStmt)(inst::Instruction, lstmt::Int, bb::Int)
stmt = inst[:stmt]

if isexpr(stmt, :enter)
if isa(stmt, EnterNode)
# try/catch not yet modeled
give_up_refinements!(sv)
return nothing
Expand Down Expand Up @@ -971,8 +972,8 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
expr = nothing
end
code[i] = expr
elseif isexpr(expr, :enter)
catchdest = expr.args[1]::Int
elseif isa(expr, EnterNode)
catchdest = expr.catch_dest
if catchdest in sv.unreachable
cfg_delete_edge!(sv.cfg, block_for_inst(sv.cfg, i), block_for_inst(sv.cfg, catchdest))
code[i] = nothing
Expand Down Expand Up @@ -1239,12 +1240,6 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
return cost
elseif head === :copyast
return 100
elseif head === :enter
# try/catch is a couple function calls,
# but don't inline functions with try/catch
# since these aren't usually performance-sensitive functions,
# and llvm is more likely to miscompile them when these functions get large
return typemax(Int)
end
return 0
end
Expand All @@ -1263,6 +1258,12 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
thiscost = dst(stmt.label) < line ? 40 : 0
elseif stmt isa GotoIfNot
thiscost = dst(stmt.dest) < line ? 40 : 0
elseif stmt isa EnterNode
# try/catch is a couple function calls,
# but don't inline functions with try/catch
# since these aren't usually performance-sensitive functions,
# and llvm is more likely to miscompile them when these functions get large
thiscost = typemax(Int)
end
return thiscost
end
Expand Down Expand Up @@ -1359,19 +1360,19 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
i += 1
end
end
elseif isa(el, EnterNode)
tgt = el.catch_dest
was_deleted = labelchangemap[tgt] == typemin(Int)
if was_deleted
body[i] = nothing
else
body[i] = EnterNode(el, tgt + labelchangemap[tgt])
end
elseif isa(el, Expr)
if el.head === :(=) && el.args[2] isa Expr
el = el.args[2]::Expr
end
if el.head === :enter
tgt = el.args[1]::Int
was_deleted = labelchangemap[tgt] == typemin(Int)
if was_deleted
body[i] = nothing
else
el.args[1] = tgt + labelchangemap[tgt]
end
elseif !is_meta_expr_head(el.head)
if !is_meta_expr_head(el.head)
args = el.args
for i = 1:length(args)
el = args[i]
Expand Down
9 changes: 6 additions & 3 deletions base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ function analyze_escapes(ir::IRCode, nargs::Int, 𝕃ₒ::AbstractLattice, get_e
elseif is_meta_expr_head(head)
# meta expressions doesn't account for any usages
continue
elseif head === :enter || head === :leave || head === :the_exception || head === :pop_exception
elseif head === :leave || head === :the_exception || head === :pop_exception
# ignore these expressions since escapes via exceptions are handled by `escape_exception!`
# `escape_exception!` conservatively propagates `AllEscape` anyway,
# and so escape information imposed on `:the_exception` isn't computed
Expand All @@ -666,6 +666,9 @@ function analyze_escapes(ir::IRCode, nargs::Int, 𝕃ₒ::AbstractLattice, get_e
else
add_conservative_changes!(astate, pc, stmt.args)
end
elseif isa(stmt, EnterNode)
# Handled via escape_exception!
continue
elseif isa(stmt, ReturnNode)
if isdefined(stmt, :val)
add_escape_change!(astate, stmt.val, ReturnEscape(pc))
Expand Down Expand Up @@ -728,10 +731,10 @@ function compute_frameinfo(ir::IRCode)
for idx in 1:nstmts+nnewnodes
inst = ir[SSAValue(idx)]
stmt = inst[:stmt]
if isexpr(stmt, :enter)
if isa(stmt, EnterNode)
@assert idx nstmts "try/catch inside new_nodes unsupported"
tryregions === nothing && (tryregions = UnitRange{Int}[])
leave_block = stmt.args[1]::Int
leave_block = stmt.catch_dest
leave_pc = first(ir.cfg.blocks[leave_block].stmts)
push!(tryregions, idx:leave_pc)
elseif arrayinfo !== nothing
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
elseif isa(stmt′, GotoNode)
stmt′ = GotoNode(stmt′.label + bb_offset)
elseif isa(stmt′, Expr) && stmt′.head === :enter
stmt′ = Expr(:enter, stmt′.args[1]::Int + bb_offset)
elseif isa(stmt′, EnterNode)
stmt′ = EnterNode(stmt′, stmt′.catch_dest + bb_offset)
elseif isa(stmt′, GotoIfNot)
stmt′ = GotoIfNot(stmt′.cond, stmt′.dest + bb_offset)
elseif isa(stmt′, PhiNode)
Expand Down Expand Up @@ -710,8 +710,8 @@ function batch_inline!(ir::IRCode, todo::Vector{Pair{Int,Any}}, propagate_inboun
end
elseif isa(stmt, GotoNode)
compact[idx] = GotoNode(state.bb_rename[stmt.label])
elseif isa(stmt, Expr) && stmt.head === :enter
compact[idx] = Expr(:enter, state.bb_rename[stmt.args[1]::Int])
elseif isa(stmt, EnterNode)
compact[idx] = EnterNode(stmt, state.bb_rename[stmt.catch_dest])
elseif isa(stmt, GotoIfNot)
compact[idx] = GotoIfNot(stmt.cond, state.bb_rename[stmt.dest])
elseif isa(stmt, PhiNode)
Expand Down
41 changes: 25 additions & 16 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Core.PhiNode() = Core.PhiNode(Int32[], Any[])

isterminator(@nospecialize(stmt)) = isa(stmt, GotoNode) || isa(stmt, GotoIfNot) || isa(stmt, ReturnNode)
isterminator(@nospecialize(stmt)) = isa(stmt, GotoNode) || isa(stmt, GotoIfNot) || isa(stmt, ReturnNode) || isa(stmt, EnterNode)

struct CFG
blocks::Vector{BasicBlock}
Expand Down Expand Up @@ -60,16 +60,16 @@ block_for_inst(cfg::CFG, inst::Int) = block_for_inst(cfg.index, inst)
# This is a fake dest to force the next stmt to start a bb
idx < length(stmts) && push!(jump_dests, idx+1)
push!(jump_dests, stmt.label)
elseif isa(stmt, EnterNode)
# :enter starts/ends a BB
push!(jump_dests, idx)
push!(jump_dests, idx+1)
# The catch block is a jump dest
push!(jump_dests, stmt.catch_dest)
elseif isa(stmt, Expr)
if stmt.head === :leave
# :leave terminates a BB
push!(jump_dests, idx+1)
elseif stmt.head === :enter
# :enter starts/ends a BB
push!(jump_dests, idx)
push!(jump_dests, idx+1)
# The catch block is a jump dest
push!(jump_dests, stmt.args[1]::Int)
end
end
if isa(stmt, PhiNode)
Expand Down Expand Up @@ -125,11 +125,11 @@ function compute_basic_blocks(stmts::Vector{Any})
push!(blocks[block′].preds, num)
push!(b.succs, block′)
end
elseif isexpr(terminator, :enter)
elseif isa(terminator, EnterNode)
# :enter gets a virtual edge to the exception handler and
# the exception handler gets a virtual edge from outside
# the function.
block′ = block_for_inst(basic_block_index, terminator.args[1]::Int)
block′ = block_for_inst(basic_block_index, terminator.catch_dest)
push!(blocks[block′].preds, num)
push!(blocks[block′].preds, 0)
push!(b.succs, block′)
Expand Down Expand Up @@ -456,6 +456,10 @@ struct UndefToken end; const UNDEF_TOKEN = UndefToken()
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
return stmt.val
elseif isa(stmt, EnterNode)
isdefined(stmt, :scope) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
return stmt.scope
elseif isa(stmt, PiNode)
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
Expand Down Expand Up @@ -510,6 +514,9 @@ end
elseif isa(stmt, GotoIfNot)
op == 1 || throw(BoundsError())
stmt = GotoIfNot(v, stmt.dest)
elseif isa(stmt, EnterNode)
op == 1 || throw(BoundsError())
stmt = EnterNode(stmt.catch_dest, v)
elseif isa(stmt, ReturnNode)
op == 1 || throw(BoundsError())
stmt = typeof(stmt)(v)
Expand Down Expand Up @@ -544,7 +551,7 @@ end
function userefs(@nospecialize(x))
relevant = (isa(x, Expr) && is_relevant_expr(x)) ||
isa(x, GotoIfNot) || isa(x, ReturnNode) || isa(x, SSAValue) || isa(x, NewSSAValue) ||
isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode) || isa(x, EnterNode)
return UseRefIterator(x, relevant)
end

Expand Down Expand Up @@ -1379,13 +1386,15 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result[result_idx][:stmt] = GotoIfNot(cond, label)
result_idx += 1
end
elseif cfg_transforms_enabled && isa(stmt, EnterNode)
label = bb_rename_succ[stmt.catch_dest]
@assert label > 0
ssa_rename[idx] = SSAValue(result_idx)
result[result_idx][:stmt] = EnterNode(stmt, label)
result_idx += 1
elseif isa(stmt, Expr)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)::Expr
if cfg_transforms_enabled && isexpr(stmt, :enter)
label = bb_rename_succ[stmt.args[1]::Int]
@assert label > 0
stmt.args[1] = label
elseif isexpr(stmt, :throw_undef_if_not)
if isexpr(stmt, :throw_undef_if_not)
cond = stmt.args[2]
if isa(cond, Bool) && cond === true
# cond was folded to true - this statement
Expand Down Expand Up @@ -1445,7 +1454,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
ssa_rename[idx] = SSAValue(result_idx)
result[result_idx][:stmt] = stmt
result_idx += 1
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot)
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot) || isa(stmt, EnterNode)
ssa_rename[idx] = SSAValue(result_idx)
result[result_idx][:stmt] = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)
result_idx += 1
Expand Down
9 changes: 4 additions & 5 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function kill_terminator_edges!(irsv::IRInterpretationState, term_idx::Int, bb::
elseif isa(stmt, ReturnNode)
# Nothing to do
else
@assert !isexpr(stmt, :enter)
@assert !isa(stmt, EnterNode)
kill_edge!(irsv, bb, bb+1)
end
end
Expand Down Expand Up @@ -222,8 +222,8 @@ function process_terminator!(@nospecialize(stmt), bb::Int, bb_ip::BitSetBoundedM
backedge || push!(bb_ip, stmt.dest)
push!(bb_ip, bb+1)
return backedge
elseif isexpr(stmt, :enter)
dest = stmt.args[1]::Int
elseif isa(stmt, EnterNode)
dest = stmt.catch_dest
@assert dest > bb
push!(bb_ip, dest)
push!(bb_ip, bb+1)
Expand Down Expand Up @@ -329,8 +329,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
delete!(ssa_refined, idx)
end
check_ret!(stmt, idx)
is_terminator_or_phi = (isa(stmt, PhiNode) || isa(stmt, GotoNode) ||
isa(stmt, GotoIfNot) || isa(stmt, ReturnNode) || isexpr(stmt, :enter))
is_terminator_or_phi = (isa(stmt, PhiNode) || isterminator(stmt))
if typ === Bottom && !(idx == lstmt && is_terminator_or_phi)
return true
end
Expand Down
Loading

0 comments on commit aa5e994

Please sign in to comment.