Skip to content

Commit

Permalink
optimizer: use count checking framework (#44794)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol authored May 17, 2022
1 parent eb4c757 commit bad3e39
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 78 deletions.
3 changes: 1 addition & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1379,9 +1379,7 @@ function inline_const_if_inlineable!(inst::Instruction)
end

function assemble_inline_todo!(ir::IRCode, state::InliningState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Pair{Int, Any}[]
et = state.et

for idx in 1:length(ir.stmts)
simpleres = process_simple!(ir, idx, state, todo)
Expand Down Expand Up @@ -1586,6 +1584,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
end
end
end
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
Expand Down
104 changes: 41 additions & 63 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ struct UndefToken end; const UNDEF_TOKEN = UndefToken()
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
return stmt.val
elseif isa(stmt, Union{SSAValue, NewSSAValue})
op == 1 || return OOB_TOKEN
return stmt
elseif isa(stmt, UpsilonNode)
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
Expand Down Expand Up @@ -430,6 +433,9 @@ end
elseif isa(stmt, ReturnNode)
op == 1 || throw(BoundsError())
stmt = typeof(stmt)(v)
elseif isa(stmt, Union{SSAValue, NewSSAValue})
op == 1 || throw(BoundsError())
stmt = v
elseif isa(stmt, UpsilonNode)
op == 1 || throw(BoundsError())
stmt = typeof(stmt)(v)
Expand Down Expand Up @@ -457,7 +463,7 @@ end

function userefs(@nospecialize(x))
relevant = (isa(x, Expr) && is_relevant_expr(x)) ||
isa(x, GotoIfNot) || isa(x, ReturnNode) ||
isa(x, GotoIfNot) || isa(x, ReturnNode) || isa(x, SSAValue) || isa(x, NewSSAValue) ||
isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
return UseRefIterator(x, relevant)
end
Expand All @@ -480,50 +486,10 @@ end

# This function is used from the show code, which may have a different
# `push!`/`used` type since it's in Base.
function scan_ssa_use!(push!, used, @nospecialize(stmt))
if isa(stmt, SSAValue)
push!(used, stmt.id)
end
for useref in userefs(stmt)
val = useref[]
if isa(val, SSAValue)
push!(used, val.id)
end
end
end
scan_ssa_use!(push!, used, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)

# Manually specialized copy of the above with push! === Compiler.push!
function scan_ssa_use!(used::IdSet, @nospecialize(stmt))
if isa(stmt, SSAValue)
push!(used, stmt.id)
end
for useref in userefs(stmt)
val = useref[]
if isa(val, SSAValue)
push!(used, val.id)
end
end
end

function ssamap(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
op[] = f(val)
end
end
return urs[]
end

function foreachssa(f, @nospecialize(stmt))
for op in userefs(stmt)
val = op[]
if isa(val, SSAValue)
f(val)
end
end
end
scan_ssa_use!(used::IdSet, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)

function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::Bool=false)
node = add!(ir.new_nodes, pos, attach_after)
Expand Down Expand Up @@ -751,20 +717,13 @@ end

function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
needs_late_fixup = false
if isa(v, SSAValue)
compact.used_ssas[v.id] += 1
elseif isa(v, NewSSAValue)
compact.new_new_used_ssas[v.id] += 1
needs_late_fixup = true
else
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
needs_late_fixup = true
end
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
needs_late_fixup = true
end
end
return needs_late_fixup
Expand Down Expand Up @@ -931,6 +890,27 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int)
return compact
end

__set_check_ssa_counts(onoff::Bool) = __check_ssa_counts__[] = onoff
const __check_ssa_counts__ = fill(false)

function _oracle_check(compact::IncrementalCompact)
observed_used_ssas = Core.Compiler.find_ssavalue_uses1(compact)
for i = 1:length(observed_used_ssas)
if observed_used_ssas[i] != compact.used_ssas[i]
return observed_used_ssas
end
end
return nothing
end

function oracle_check(compact::IncrementalCompact)
maybe_oracle_used_ssas = _oracle_check(compact)
if maybe_oracle_used_ssas !== nothing
@eval Main (compact = $compact; oracle_used_ssas = $maybe_oracle_used_ssas)
error("Oracle check failed, inspect Main.compact and Main.oracle_used_ssas")
end
end

getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id)
function getindex(view::TypesView, idx::Int)
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
Expand Down Expand Up @@ -1425,7 +1405,6 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
# result_idx is not, incremented, but that's ok and expected
compact.result[old_result_idx] = compact.ir.stmts[idx]
result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, active_bb, true)
stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx][:inst]
compact.result_idx = result_idx
if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx)
finish_current_bb!(compact, active_bb, old_result_idx)
Expand Down Expand Up @@ -1464,11 +1443,7 @@ function maybe_erase_unused!(
callback(val)
end
if effect_free
if isa(stmt, SSAValue)
kill_ssa_value(stmt)
else
foreachssa(kill_ssa_value, stmt)
end
foreachssa(kill_ssa_value, stmt)
inst[:inst] = nothing
return true
end
Expand Down Expand Up @@ -1570,6 +1545,9 @@ end
function complete(compact::IncrementalCompact)
result_bbs = resize!(compact.result_bbs, compact.active_result_bb-1)
cfg = CFG(result_bbs, Int[first(result_bbs[i].stmts) for i in 2:length(result_bbs)])
if __check_ssa_counts__[]
oracle_check(compact)
end
return IRCode(compact.ir, compact.result, cfg, compact.new_new_nodes)
end

Expand Down
9 changes: 0 additions & 9 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1151,15 +1151,6 @@ function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact
end
end

function count_uses(@nospecialize(stmt), uses::Vector{Int})
for ur in userefs(stmt)
use = ur[]
if isa(use, SSAValue)
uses[use.id] += 1
end
end
end

function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
Expand Down
3 changes: 0 additions & 3 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ function make_ssa!(ci::CodeInfo, code::Vector{Any}, idx, slot, @nospecialize(typ
end

function new_to_regular(@nospecialize(stmt), new_offset::Int)
if isa(stmt, NewSSAValue)
return SSAValue(stmt.id + new_offset)
end
urs = userefs(stmt)
for op in urs
val = op[]
Expand Down
53 changes: 53 additions & 0 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ end
# SSAValues/Slots #
###################

function ssamap(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
op[] = f(val)
end
end
return urs[]
end

function foreachssa(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
f(val)
end
end
end

function find_ssavalue_uses(body::Vector{Any}, nvals::Int)
uses = BitSet[ BitSet() for i = 1:nvals ]
for line in 1:length(body)
Expand Down Expand Up @@ -333,6 +354,38 @@ end
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id :
isa(s, Argument) ? (s::Argument).n : (s::TypedSlot).id

######################
# IncrementalCompact #
######################

# specifically meant to be used with body1 = compact.result and body2 = compact.new_new_nodes, with nvals == length(compact.used_ssas)
function find_ssavalue_uses1(compact)
body1, body2 = compact.result.inst, compact.new_new_nodes.stmts.inst
nvals = length(compact.used_ssas)
nbody1 = length(body1)
nbody2 = length(body2)

uses = zeros(Int, nvals)
function increment_uses(ssa::SSAValue)
uses[ssa.id] += 1
end

for line in 1:(nbody1 + nbody2)
# index into the right body
if line <= nbody1
isassigned(body1, line) || continue
e = body1[line]
else
line -= nbody1
isassigned(body2, line) || continue
e = body2[line]
end

foreachssa(increment_uses, e)
end
return uses
end

###########
# options #
###########
Expand Down
65 changes: 64 additions & 1 deletion test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using Base.Meta
using Core.IR
const Compiler = Core.Compiler
using .Compiler: CFG, BasicBlock
using .Compiler: CFG, BasicBlock, NewSSAValue

make_bb(preds, succs) = BasicBlock(Compiler.StmtRange(0, 0), preds, succs)

Expand Down Expand Up @@ -334,3 +334,66 @@ f_if_typecheck() = (if nothing; end; unsafe_load(Ptr{Int}(0)))
stderr = IOBuffer()
success(pipeline(Cmd(cmd); stdout=stdout, stderr=stderr)) && isempty(String(take!(stderr)))
end

let
function test_useref(stmt, v, op)
if isa(stmt, Expr)
@test stmt.args[op] === v
elseif isa(stmt, GotoIfNot)
@test stmt.cond === v
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode)
@test stmt.val === v
elseif isa(stmt, SSAValue) || isa(stmt, NewSSAValue)
@test stmt === v
elseif isa(stmt, PiNode)
@test stmt.val === v && stmt.typ === typeof(stmt)
elseif isa(stmt, PhiNode) || isa(stmt, PhiCNode)
@test stmt.values[op] === v
end
end

function _test_userefs(@nospecialize stmt)
ex = Expr(:call, :+, Core.SSAValue(3), 1)
urs = Core.Compiler.userefs(stmt)::Core.Compiler.UseRefIterator
it = Core.Compiler.iterate(urs)
while it !== nothing
ur = getfield(it, 1)::Core.Compiler.UseRef
op = getfield(it, 2)::Int
v1 = Core.Compiler.getindex(ur)
# set to dummy expression and then back to itself to test `_useref_setindex!`
v2 = Core.Compiler.setindex!(ur, ex)
test_useref(v2, ex, op)
Core.Compiler.setindex!(ur, v1)
@test Core.Compiler.getindex(ur) === v1
it = Core.Compiler.iterate(urs, op)
end
end

function test_userefs(body)
for stmt in body
_test_userefs(stmt)
end
end

# this isn't valid code, we just care about looking at a variety of IR nodes
body = Any[
Expr(:enter, 11),
Expr(:call, :+, SSAValue(3), 1),
Expr(:throw_undef_if_not, :expected, false),
Expr(:leave, 1),
Expr(:(=), SSAValue(1), Expr(:call, :+, SSAValue(3), 1)),
UpsilonNode(),
UpsilonNode(SSAValue(2)),
PhiCNode(Any[SSAValue(5), SSAValue(7), SSAValue(9)]),
PhiCNode(Any[SSAValue(6)]),
PhiNode(Int32[8], Any[SSAValue(7)]),
PiNode(SSAValue(6), GotoNode),
GotoIfNot(SSAValue(3), 10),
GotoNode(5),
SSAValue(7),
NewSSAValue(9),
ReturnNode(SSAValue(11)),
]

test_userefs(body)
end

0 comments on commit bad3e39

Please sign in to comment.