From 9100329d189b3d2ac6d1e03bfc7431cbe8d18c00 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 10 Apr 2018 20:07:28 -0400 Subject: [PATCH] Introduce improved SROA pass This is a rebased and fixed version of the improved SROA pass from #26778. There's a decent piece of new infrastructure wrapped up in this: The ability to insert new nodes during compaction. This is a bit tricky because it requires tracking which version of the statements buffer a given SSAValue belongs to. At the moment this is done mostly manually, but I'm hoping to clean that up in the future. The idea of the new SROA pass is fairly straightforward: Given a use of an interesting value, it traces through all phi nodes, finding all leaves, applies whatever transformation to those leaves and then re-inserts a phi nest corresponding to the phi nest of the original value. --- base/compiler/ssair/driver.jl | 10 +- base/compiler/ssair/ir.jl | 341 ++++++++++++++--- base/compiler/ssair/passes.jl | 644 +++++++++++++++++++++++++++------ base/compiler/ssair/queries.jl | 2 +- base/compiler/ssair/show.jl | 32 +- base/compiler/ssair/verify.jl | 6 +- base/compiler/tfuncs.jl | 74 +++- src/julia-syntax.scm | 8 +- 8 files changed, 909 insertions(+), 208 deletions(-) diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 18dfb82161653..d0fbaa11d6bc2 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -17,7 +17,7 @@ include("compiler/ssair/passes.jl") include("compiler/ssair/inlining2.jl") include("compiler/ssair/verify.jl") include("compiler/ssair/legacy.jl") -@isdefined(Base) && include("compiler/ssair/show.jl") +#@isdefined(Base) && include("compiler/ssair/show.jl") function normalize_expr(stmt::Expr) if stmt.head === :gotoifnot @@ -159,14 +159,20 @@ end function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode}, sv::OptimizationState) ir = just_construct_ssa(ci, copy(ci.code), nargs, linetable) + #@Base.show ("after_construct", ir) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) #@timeit "verify 1" verify_ir(ir) @timeit "Inlining" ir = ssa_inlining_pass!(ir, linetable, sv) #@timeit "verify 2" verify_ir(ir) @timeit "domtree 2" domtree = construct_domtree(ir.cfg) + ir = compact!(ir) + #@Base.show ("before_sroa", ir) @timeit "SROA" ir = getfield_elim_pass!(ir, domtree) - @timeit "compact 2" ir = compact!(ir) + #@Base.show ir.new_nodes + #@Base.show ("after_sroa", ir) + ir = adce_pass!(ir) + #@Base.show ("after_adce", ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) #@Base.show ir diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 97a202855a1f6..0b92e5f0b615f 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -413,6 +413,12 @@ mutable struct IncrementalCompact # This could be Stateful, but bootstrapping doesn't like that perm::Vector{Int} new_nodes_idx::Int + # This supports insertion while compacting + new_new_nodes::Vector{NewNode} # New nodes that were before the compaction point at insertion time + # TODO: Switch these two to a min-heap of some sort + pending_nodes::Vector{NewNode} # New nodes that were after the compaction point at insertion time + pending_perm::Vector{Int} + # State idx::Int result_idx::Int active_result_bb::Int @@ -427,7 +433,12 @@ mutable struct IncrementalCompact used_ssas = fill(0, new_len) ssa_rename = Any[SSAValue(i) for i = 1:new_len] late_fixup = Vector{Int}() - return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, 1, 1, 1) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, 1, 1) end # For inlining @@ -437,8 +448,14 @@ mutable struct IncrementalCompact ssa_rename = Any[SSAValue(i) for i = 1:new_len] used_ssas = fill(0, new_len) late_fixup = Vector{Int}() - return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, parent.result_bbs, - ssa_rename, parent.used_ssas, late_fixup, perm, 1, 1, result_offset, parent.active_result_bb) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, + parent.result_bbs, ssa_rename, parent.used_ssas, + late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, result_offset, parent.active_result_bb) end end @@ -455,7 +472,30 @@ function getindex(compact::IncrementalCompact, idx::Int) end end +function getindex(compact::IncrementalCompact, ssa::SSAValue) + @assert ssa.id < compact.result_idx + return compact.result[ssa.id] +end + +function getindex(compact::IncrementalCompact, ssa::OldSSAValue) + id = ssa.id + if id <= length(compact.ir.stmts) + return compact.ir.stmts[id] + end + id -= length(compact.ir.stmts) + if id <= length(compact.ir.new_nodes) + return compact.ir.new_nodes[id].node + end + id -= length(compact.ir.new_nodes) + return compact.pending_nodes[id].node +end + +function getindex(compact::IncrementalCompact, ssa::NewSSAValue) + return compact.new_new_nodes[ssa.id].node +end + function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) + needs_late_fixup = isa(v, NewSSAValue) if isa(v, SSAValue) compact.used_ssas[v.id] += 1 else @@ -463,44 +503,116 @@ function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) val = ops[] if isa(val, SSAValue) compact.used_ssas[val.id] += 1 + elseif isa(val, NewSSAValue) + needs_late_fixup = true end end end + needs_late_fixup +end + +function resort_pending!(compact) + sort!(compact.pending_perm, DEFAULT_STABLE, Order.By(x->compact.pending_nodes[x].pos)) +end + +function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @nospecialize(val), attach_after::Bool=false) + if isa(before, SSAValue) + if before.id < compact.result_idx + count_added_node!(compact, val) + line = compact.result_lines[before.id] + push!(compact.new_new_nodes, NewNode(before.id, attach_after, typ, val, line)) + return NewSSAValue(length(compact.new_new_nodes)) + else + line = compact.ir.lines[before.id] + push!(compact.pending_nodes, NewNode(before.id, attach_after, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + end + elseif isa(before, OldSSAValue) + pos = before.id + if pos > length(compact.ir.stmts) + #@assert attach_after + entry = compact.pending_nodes[pos - length(compact.ir.stmts) - length(compact.ir.new_nodes)] + pos, attach_after = entry.pos, entry.attach_after + end + line = compact.ir.lines[pos] + push!(compact.pending_nodes, NewNode(pos, attach_after, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + elseif isa(before, NewSSAValue) + before_entry = compact.new_new_nodes[before.id] + push!(compact.new_new_nodes, NewNode(before_entry.pos, attach_after, typ, val, before_entry.line)) + return NewSSAValue(length(compact.new_new_nodes)) + else + error("Unsupported") + end end -function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int) +function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int, reverse_affinity=false) if compact.result_idx > length(compact.result) @assert compact.result_idx == length(compact.result) + 1 resize!(compact, compact.result_idx) end + refinish = false + if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts) && reverse_affinity + compact.active_result_bb -= 1 + refinish = true + end compact.result[compact.result_idx] = val compact.result_types[compact.result_idx] = typ compact.result_lines[compact.result_idx] = ltable_idx compact.result_flags[compact.result_idx] = 0x00 - count_added_node!(compact, val) + if count_added_node!(compact, val) + push!(compact.late_fixup, compact.result_idx) + end ret = SSAValue(compact.result_idx) compact.result_idx += 1 + refinish && finish_current_bb!(compact) ret end function getindex(view::TypesView, v::OldSSAValue) - return view.ir.ir.types[v.id] + id = v.id + if id <= length(view.ir.ir.types) + return view.ir.ir.types[id] + end + id -= length(view.ir.ir.types) + if id <= length(view.ir.ir.new_nodes) + return view.ir.ir.new_nodes[id].typ + end + id -= length(view.ir.ir.new_nodes) + return view.ir.pending_nodes[id].typ +end + +function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) + @assert idx.id < compact.result_idx + (compact.result[idx.id] === v) && return + # Kill count for current uses + for ops in userefs(compact.result[idx.id]) + val = ops[] + if isa(val, SSAValue) + @assert compact.used_ssas[val.id] >= 1 + compact.used_ssas[val.id] -= 1 + end + end + compact.result[idx.id] = v + # Add count for new use + if count_added_node!(compact, v) + push!(compact.late_fixup, idx.id) + end end function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int) if idx < compact.result_idx - (compact.result[idx] === v) && return - # Kill count for current uses - for ops in userefs(compact.result[idx]) - val = ops[] - if isa(val, SSAValue) - @assert compact.used_ssas[val.id] >= 1 - compact.used_ssas[val.id] -= 1 - end - end - compact.result[idx] = v - # Add count for new use - count_added_node!(compact, v) + compact[SSAValue(idx)] = v else compact.ir.stmts[idx] = v end @@ -509,10 +621,14 @@ end function getindex(view::TypesView, idx) isa(idx, SSAValue) && (idx = idx.id) - ir = view.ir - if isa(ir, IncrementalCompact) - if idx < ir.result_idx - return ir.result_types[idx] + if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx + return view.ir.result_types[idx] + else + ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir + if idx <= length(ir.types) + return ir.types[idx] + else + return ir.new_nodes[idx - length(ir.types)].typ end ir = ir.ir end @@ -528,47 +644,103 @@ function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int}) return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm)) end +function getindex(view::TypesView, idx::NewSSAValue) + @assert isa(view.ir, IncrementalCompact) + compact = view.ir + compact.new_new_nodes[idx.id].typ +end + function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}, processed_idx::Int, result_idx::Int, - ssa_rename::Vector{Any}, used_ssas::Vector{Int}) + ssa_rename::Vector{Any}, used_ssas::Vector{Int}, + do_rename_ssa::Bool) values = Vector{Any}(undef, length(old_values)) for i = 1:length(old_values) isassigned(old_values, i) || continue val = old_values[i] if isa(val, SSAValue) + if do_rename_ssa + if val.id > processed_idx + push!(late_fixup, result_idx) + val = OldSSAValue(val.id) + else + val = renumber_ssa2(val, ssa_rename, used_ssas, do_rename_ssa) + end + else + used_ssas[val.id] += 1 + end + elseif isa(val, OldSSAValue) if val.id > processed_idx push!(late_fixup, result_idx) - val = OldSSAValue(val.id) else - val = renumber_ssa!(val, ssa_rename, true, used_ssas) + # Always renumber these. do_rename_ssa applies only to actual SSAValues + val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, true) end + elseif isa(val, NewSSAValue) + push!(late_fixup, result_idx) end values[i] = val end return values end +function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssa::Vector{Int}, do_rename_ssa::Bool) + id = val.id + if id > length(ssanums) + return val + end + if do_rename_ssa + val = ssanums[id] + end + if isa(val, SSAValue) && used_ssa !== nothing + used_ssa[val.id] += 1 + end + return val +end + +function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssa::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool) + urs = userefs(stmt) + for op in urs + val = op[] + if isa(val, OldSSAValue) || isa(val, NewSSAValue) + push!(late_fixup, result_idx) + end + if isa(val, SSAValue) + val = renumber_ssa2(val, ssanums, used_ssa, do_rename_ssa) + end + if isa(val, OldSSAValue) || isa(val, NewSSAValue) + push!(late_fixup, result_idx) + end + op[] = val + end + return urs[] +end + function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{Any}, late_fixup::Vector{Int}, used_ssas::Vector{Int}, @nospecialize(stmt), - idx::Int, processed_idx::Int) + idx::Int, processed_idx::Int, do_rename_ssa::Bool) ssa_rename[idx] = SSAValue(result_idx) if stmt === nothing ssa_rename[idx] = stmt + elseif isa(stmt, OldSSAValue) + ssa_rename[idx] = ssa_rename[stmt.id] elseif isa(stmt, GotoNode) || isa(stmt, GlobalRef) result[result_idx] = stmt result_idx += 1 elseif isa(stmt, Expr) || isa(stmt, PiNode) || isa(stmt, GotoIfNot) || isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) - result[result_idx] = renumber_ssa!(stmt, ssa_rename, true, used_ssas) + result[result_idx] = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa) result_idx += 1 elseif isa(stmt, PhiNode) - result[result_idx] = PhiNode(stmt.edges, process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas)) + result[result_idx] = PhiNode(stmt.edges, process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)) result_idx += 1 elseif isa(stmt, PhiCNode) - result[result_idx] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas)) + result[result_idx] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)) result_idx += 1 elseif isa(stmt, SSAValue) # identity assign, replace uses of this ssa value with its result - stmt = ssa_rename[stmt.id] + if do_rename_ssa + stmt = ssa_rename[stmt.id] + end ssa_rename[idx] = stmt else # Constant assign, replace uses of this ssa value with its result @@ -576,9 +748,10 @@ function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{ end return result_idx end -function process_node!(compact::IncrementalCompact, result_idx::Int, @nospecialize(stmt), idx::Int, processed_idx::Int) +function process_node!(compact::IncrementalCompact, result_idx::Int, @nospecialize(stmt), idx::Int, processed_idx::Int, do_rename_ssa::Bool) return process_node!(compact.result, result_idx, compact.ssa_rename, - compact.late_fixup, compact.used_ssas, stmt, idx, processed_idx) + compact.late_fixup, compact.used_ssas, stmt, idx, processed_idx, + do_rename_ssa) end function resize!(compact::IncrementalCompact, nnewnodes) @@ -621,12 +794,12 @@ function attach_after_stmt_after(compact::IncrementalCompact, idx::Int) entry.pos == idx && entry.attach_after end -function process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) +function process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, do_rename_ssa) old_result_idx = compact.result_idx bb = compact.ir.cfg.blocks[active_bb] compact.result_types[old_result_idx] = new_node_entry.typ compact.result_lines[old_result_idx] = new_node_entry.line - result_idx = process_node!(compact, old_result_idx, new_node_entry.node, new_idx, idx) + result_idx = process_node!(compact, old_result_idx, new_node_entry.node, new_idx, idx, do_rename_ssa) compact.result_idx = result_idx # If this instruction has reverse affinity and we were at the end of a basic block, # finish it now. @@ -651,14 +824,21 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) compact.new_nodes_idx += 1 new_node_entry = compact.ir.new_nodes[new_idx] new_idx += length(compact.ir.stmts) - return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) + return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, true) + elseif !isempty(compact.pending_perm) && + (entry = compact.pending_nodes[compact.pending_perm[1]]; + entry.attach_after ? entry.pos == idx - 1 : entry.pos == idx) + new_idx = popfirst!(compact.pending_perm) + new_node_entry = compact.pending_nodes[new_idx] + new_idx += length(compact.ir.stmts) + length(compact.ir.new_nodes) + return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, false) end # This will get overwritten in future iterations if # result_idx is not, incremented, but that's ok and expected compact.result_types[old_result_idx] = compact.ir.types[idx] compact.result_lines[old_result_idx] = compact.ir.lines[idx] compact.result_flags[old_result_idx] = compact.ir.flags[idx] - result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx) + result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, true) stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx] compact.result_idx = result_idx if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx) @@ -673,22 +853,29 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb) end -function maybe_erase_unused!(extra_worklist, compact, idx) - effect_free = stmt_effect_free(compact.result[idx], compact, compact.ir.mod) +function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing) + stmt = compact.result[idx] + stmt === nothing && return false + effect_free = stmt_effect_free(stmt, compact, compact.ir.mod) if effect_free - for ops in userefs(compact.result[idx]) + for ops in userefs(stmt) val = ops[] - if isa(val, SSAValue) + # If the pass we ran inserted new nodes, it's possible for those + # to be outside our used_ssas count. + if isa(val, SSAValue) && val.id <= length(compact.used_ssas) if compact.used_ssas[val.id] == 1 if val.id < idx push!(extra_worklist, val.id) end end compact.used_ssas[val.id] -= 1 + callback(val) end end compact.result[idx] = nothing + return true end + return false end function fixup_phinode_values!(compact, old_values) @@ -701,47 +888,85 @@ function fixup_phinode_values!(compact, old_values) if isa(val, SSAValue) compact.used_ssas[val.id] += 1 end + elseif isa(val, NewSSAValue) + val = SSAValue(length(compact.result) + val.id) end values[i] = val end values end +function fixup_node(compact, @nospecialize(stmt)) + if isa(stmt, PhiNode) + return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, PhiCNode) + return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, NewSSAValue) + return SSAValue(length(compact.result) + stmt.id) + else + urs = userefs(stmt) + urs === () && return stmt + for ur in urs + val = ur[] + if isa(val, NewSSAValue) + ur[] = SSAValue(length(compact.result) + val.id) + end + end + return urs[] + end +end + function just_fixup!(compact) for idx in compact.late_fixup stmt = compact.result[idx] - if isa(stmt, PhiNode) - compact.result[idx] = PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) - else - stmt = stmt::PhiCNode - compact.result[idx] = PhiCNode(fixup_phinode_values!(compact, stmt.values)) + new_stmt = fixup_node(compact, stmt) + (stmt !== new_stmt) && (compact.result[idx] = new_stmt) + end + for idx in 1:length(compact.new_new_nodes) + node = compact.new_new_nodes[idx] + new_stmt = fixup_node(compact, node.node) + if node.node !== new_stmt + compact.new_new_nodes[idx] = NewNode( + node.pos, node.attach_after, node.typ, + new_stmt, node.line) end end end -function finish(compact::IncrementalCompact) - just_fixup!(compact) - # Record this somewhere? - result_idx = compact.result_idx - resize!(compact.result, result_idx-1) - resize!(compact.result_types, result_idx-1) - resize!(compact.result_lines, result_idx-1) - resize!(compact.result_flags, result_idx-1) - bb = compact.result_bbs[end] - compact.result_bbs[end] = BasicBlock(bb, - StmtRange(first(bb.stmts), result_idx-1)) +function simple_dce!(compact) # Perform simple DCE for unused values extra_worklist = Int[] for (idx, nused) in Iterators.enumerate(compact.used_ssas) - idx >= result_idx && break + idx >= compact.result_idx && break nused == 0 || continue maybe_erase_unused!(extra_worklist, compact, idx) end while !isempty(extra_worklist) maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist)) end +end + +function non_dce_finish!(compact::IncrementalCompact) + result_idx = compact.result_idx + resize!(compact.result, result_idx-1) + resize!(compact.result_types, result_idx-1) + resize!(compact.result_lines, result_idx-1) + resize!(compact.result_flags, result_idx-1) + just_fixup!(compact) + bb = compact.result_bbs[end] + compact.result_bbs[end] = BasicBlock(bb, + StmtRange(first(bb.stmts), result_idx-1)) +end + +function finish(compact::IncrementalCompact) + non_dce_finish!(compact) + simple_dce!(compact) + complete(compact) +end + +function complete(compact) cfg = CFG(compact.result_bbs, Int[first(bb.stmts) for bb in compact.result_bbs[2:end]]) - return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, NewNode[]) + return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes) end function compact!(code::IRCode) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index a7e84e2a9486c..b3e32c436eeb2 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -101,73 +101,110 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks, du, phin end end -function walk_to_def(compact::IncrementalCompact, @nospecialize(def), intermediaries=IdSet{Int}(), allow_phinode::Bool=true, phi_locs=Tuple{Int, Int}[]) - if !isa(def, SSAValue) - return (def, 0) +function simple_walk(compact::IncrementalCompact, defssa::Union{SSAValue, NewSSAValue, OldSSAValue}, pi_callback=(pi,idx)->nothing) + while true + if isa(defssa, OldSSAValue) && already_inserted(compact, defssa) + rename = compact.ssa_rename[defssa.id] + if isa(rename, Union{SSAValue, OldSSAValue, NewSSAValue}) + defssa = rename + continue + end + return rename + end + def = compact[defssa] + if isa(def, PiNode) + pi_callback(def, defssa) + if isa(def.val, SSAValue) + if isa(defssa, OldSSAValue) && !already_inserted(compact, defssa) + defssa = OldSSAValue(def.val.id) + else + defssa = def.val + end + else + return def.val + end + elseif isa(def, Union{SSAValue, OldSSAValue, NewSSAValue}) + pi_callback(def, defssa) + defssa = def + elseif isa(def, Union{PhiNode, PhiCNode, Expr, GlobalRef}) + return defssa + else + return def + end end - orig_defidx = defidx = def.id +end + +function simple_walk_constraint(compact, defidx, typeconstraint = types(compact)[defidx]) + callback = (pi, _)->isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, pi.typ)) + def = simple_walk(compact, defidx, callback) + def, typeconstraint +end + +""" + walk_to_defs(compact, val, intermediaries) + +Starting at `val` walk use-def chains to get all the leaves feeding into +this val (pruning those leaves rules out by path conditions). +""" +function walk_to_defs(compact, defssa, typeconstraint, visited_phinodes=Any[]) # Step 2: Figure out what the struct is defined as - def = compact[defidx] - typeconstraint = types(compact)[defidx] + def = compact[defssa] ## Track definitions through PiNode/PhiNode found_def = false ## Track which PhiNodes, SSAValue intermediaries ## we forwarded through. - while true - if isa(def, PiNode) - push!(intermediaries, defidx) - typeconstraint = typeintersect(typeconstraint, def.typ) - if isa(def.val, SSAValue) - defidx = def.val.id - def = compact[defidx] - else - def = def.val - end - continue - elseif isa(def, FastForward) - append!(phi_locs, def.phi_locs) - def = def.to - elseif isa(def, PhiNode) - # For now, we don't track setfields structs through phi nodes - allow_phinode || break - push!(intermediaries, defidx) + visited = IdSet{Any}() + worklist = Tuple{Any, Any}[] + leaves = Any[] + push!(worklist, (defssa, typeconstraint)) + while !isempty(worklist) + defssa, typeconstraint = pop!(worklist) + push!(visited, defssa) + def = compact[defssa] + if isa(def, PhiNode) + push!(visited_phinodes, defssa) possible_predecessors = let def=def, typeconstraint=typeconstraint collect(Iterators.filter(1:length(def.edges)) do n isassigned(def.values, n) || return false - value = def.values[n] - edge_typ = widenconst(compact_exprtype(compact, value)) + val = def.values[n] + if isa(defssa, OldSSAValue) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + edge_typ = widenconst(compact_exprtype(compact, val)) return typeintersect(edge_typ, typeconstraint) !== Union{} end) end - # For now, only look at unique predecessors - if length(possible_predecessors) == 1 - n = possible_predecessors[1] + for n in possible_predecessors pred = def.edges[n] val = def.values[n] - if isa(val, SSAValue) - push!(phi_locs, (pred, defidx)) - defidx = val.id - def = compact[defidx] - elseif def == val + if isa(defssa, OldSSAValue) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + if isa(val, Union{SSAValue, OldSSAValue, NewSSAValue}) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + if isa(new_def, Union{SSAValue, OldSSAValue, NewSSAValue}) + if !(new_def in visited) + push!(worklist, (new_def, new_constraint)) + end + continue + end + val = new_def + end + if def == val # This shouldn't really ever happen, but # patterns like this can occur in dead code, # so bail out. break else - def = val + push!(leaves, val) end continue end - elseif isa(def, SSAValue) - push!(intermediaries, defidx) - defidx = def.id - def = compact[def.id] - continue + else + push!(leaves, defssa) end - found_def = true - break end - found_def ? (def, defidx) : nothing + leaves end function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) @@ -178,24 +215,318 @@ function process_immutable_preserve(new_preserves::Vector{Any}, compact::Increme end end -struct FastForward - to::SSAValue - phi_locs::Vector{Tuple{Int, Int}} +function already_inserted(compact::IncrementalCompact, old::OldSSAValue) + id = old.id + if id < length(compact.ir.stmts) + return id < compact.idx + end + id -= length(compact.ir.stmts) + if id < length(compact.ir.new_nodes) + error() + end + id -= length(compact.ir.new_nodes) + @assert id <= length(compact.pending_nodes) + return !(id in compact.pending_perm) +end + +function is_pending(compact::IncrementalCompact, old::OldSSAValue) + return old.id > length(compact.ir.stmts) + length(compact.ir.new_nodes) +end + +function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt), + @nospecialize(result_t), field::Int, leaves::Vector{Any}) + # For every leaf, the lifted value + lifted_leaves = IdDict{Any, Any}() + maybe_undef = false + for leaf in leaves + leaf_key = leaf + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + if isa(leaf, OldSSAValue) && already_inserted(compact, leaf) + leaf = compact.ssa_rename[leaf.id] + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + leaf = simple_walk(compact, leaf) + end + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + def = compact[leaf] + else + def = leaf + end + else + def = compact[leaf] + end + if is_tuple_call(compact.ir, def) && isa(field, Int) && 1 <= field < length(def.args) + lifted = def.args[1+field] + if isa(leaf, OldSSAValue) && isa(lifted, SSAValue) + lifted = OldSSAValue(lifted.id) + end + lifted_leaves[leaf_key] = RefValue{Any}(lifted) + continue + elseif isexpr(def, :new) + typ = def.typ + if isa(typ, UnionAll) + typ = unwrap_unionall(typ) + end + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + field = try_compute_fieldidx_expr(typ, stmt) + field === nothing && return nothing + if length(def.args) < 1 + field + ftyp = fieldtype(typ, field) + if !isbits(ftyp) + # On this branch, this will be a guaranteed UndefRefError. + # We use the regular undef mechanic to lift this to a boolean slot + maybe_undef = true + lifted_leaves[leaf_key] = nothing + continue + end + return nothing + # Expand the Expr(:new) to include it's element Expr(:new) nodes up until the one we want + compact[leaf] = nothing + for i = (length(def.args) + 1):(1+field) + ftyp = fieldtype(typ, i - 1) + isbits(ftyp) || return nothing + push!(def.args, insert_node!(compact, leaf, result_t, Expr(:new, ftyp))) + end + compact[leaf] = def + end + lifted = def.args[1+field] + if isa(leaf, OldSSAValue) && isa(lifted, SSAValue) + lifted = OldSSAValue(lifted.id) + end + lifted_leaves[leaf_key] = RefValue{Any}(lifted) + continue + else + typ = compact_exprtype(compact, leaf) + if !isa(typ, Const) + # If the leaf is an old ssa value, insert a getfield here + # We will revisit this getfield later when compaction gets + # to the appropriate point. + # N.B.: This can be a bit dangerous because it can lead to + # infinite loops if we accidentally insert a node just ahead + # of where we are + if isa(leaf, OldSSAValue) && (isa(field, Int) || isa(field, Symbol)) + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + # If there's the potential for an undefref error on access, we cannot insert a getfield + if field > typ.ninitialized && !isbits(fieldtype(typ, field)) + return nothing + lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, make_MaybeUndef(result_t), Expr(:call, :unchecked_getfield, SSAValue(leaf.id), field), true)) + maybe_undef = true + else + return nothing + lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, result_t, Expr(:call, getfield, SSAValue(leaf.id), field), true)) + end + continue + end + return nothing + end + leaf = typ.val + # Fall through to below + end + elseif isa(leaf, QuoteNode) + leaf = leaf.value + elseif isa(leaf, Union{Argument, Expr}) + return nothing + end + isimmutable(leaf) || return nothing + isdefined(leaf, field) || return nothing + val = getfield(leaf, field) + is_inlineable_constant(val) || return nothing + lifted_leaves[leaf_key] = RefValue{Any}(quoted(val)) + end + lifted_leaves, maybe_undef +end + +make_MaybeUndef(typ) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ) + +const AnySSAValue = Union{SSAValue, OldSSAValue, NewSSAValue} + +function lift_comparison!(compact::IncrementalCompact, idx::Int, + @nospecialize(c1), @nospecialize(c2), stmt::Expr, + lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}) + if isa(c1, Const) + cmp = c1 + typeconstraint = widenconst(c2) + val = stmt.args[3] + else + cmp = c2 + typeconstraint = widenconst(c1) + val = stmt.args[2] + end + + is_type_only = isdefined(typeof(cmp), :instance) + + if isa(val, Union{OldSSAValue, SSAValue}) + val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) + end + + visited_phinodes = Any[] + if isa(val, Union{OldSSAValue, SSAValue, NewSSAValue}) && isa(compact[val], PhiNode) + leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes) + else + leaves = [val] + end + + # Let's check if we evaluate the comparison for each one of the leaves + lifted_leaves = IdDict{Any, Any}() + for leaf in leaves + r = egal_tfunc(compact_exprtype(compact, leaf), cmp) + if isa(r, Const) + lifted_leaves[leaf] = RefValue{Any}(r.val) + else + # TODO: In some cases it might be profitable to hoist the === + # here. + return + end + end + + lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, lifted_leaves, val) + + #global assertion_counter + #assertion_counter::Int += 1 + #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true) + #return + compact[idx] = lifted_val +end + +struct LiftedPhi + ssa::AnySSAValue + node::Any + need_argupdate::Bool end -function getfield_elim_pass!(ir::IRCode, domtree::DomTree) +function perform_lifting!(compact::IncrementalCompact, + visited_phinodes::Vector{Any}, @nospecialize(cache_key), + lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, + @nospecialize(result_t), lifted_leaves::IdDict{Any, Any}, @nospecialize(stmt_val)) + reverse_mapping = IdDict{Any, Any}(ssa => id for (id, ssa) in enumerate(visited_phinodes)) + + # Insert PhiNodes + lifted_phis = LiftedPhi[] + for item in visited_phinodes + if (item, cache_key) in keys(lifting_cache) + ssa = lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)] + push!(lifted_phis, LiftedPhi(ssa, compact[ssa], false)) + continue + end + n = PhiNode() + ssa = insert_node!(compact, item, result_t, n) + lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)] = ssa + push!(lifted_phis, LiftedPhi(ssa, n, true)) + end + + # Fix up arguments + for (old_node_ssa, lf) in zip(visited_phinodes, lifted_phis) + old_node = compact[old_node_ssa] + new_node = lf.node + lf.need_argupdate || continue + for i = 1:length(old_node.edges) + edge = old_node.edges[i] + isassigned(old_node.values, i) || continue + val = old_node.values[i] + orig_val = val + if isa(old_node_ssa, OldSSAValue) && !is_pending(compact, old_node_ssa) && !already_inserted(compact, old_node_ssa) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) + val = simple_walk(compact, val) + end + if val in keys(lifted_leaves) + push!(new_node.edges, edge) + lifted_val = lifted_leaves[val] + if lifted_val === nothing + resize!(new_node.values, length(new_node.values)+1) + continue + end + lifted_val = lifted_val.x + if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue}) + lifted_val = simple_walk(compact, lifted_val) + end + push!(new_node.values, lifted_val) + elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping) + push!(new_node.edges, edge) + push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa) + else + # Probably ignored by path condition, skip this + end + end + end + + for lf in lifted_phis + count_added_node!(compact, lf.node) + end + + # Fixup the stmt itself + if isa(stmt_val, Union{SSAValue, OldSSAValue}) + stmt_val = simple_walk(compact, stmt_val) + end + if stmt_val in keys(lifted_leaves) + stmt_val = lifted_leaves[stmt_val] + @assert stmt_val !== nothing + stmt_val = stmt_val.x + else + isa(stmt_val, Union{SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping) + stmt_val = lifted_phis[reverse_mapping[stmt_val]].ssa + end + + return stmt_val +end + +assertion_counter = 0 +function getfield_elim_pass!(ir::IRCode, domtree) compact = IncrementalCompact(ir) insertions = Vector{Any}() defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}() + lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + revisit_worklist = Int[] + #ndone, nmax = 0, 200 for (idx, stmt) in compact isa(stmt, Expr) || continue - is_getfield = false + #ndone >= nmax && continue + #ndone += 1 + result_t = compact_exprtype(compact, SSAValue(idx)) + is_getfield = is_setfield = false is_ccall = false + is_unchecked = false # Step 1: Check whether the statement we're looking at is a getfield/setfield! if is_known_call(stmt, setfield!, compact) is_setfield = true elseif is_known_call(stmt, getfield, compact) is_getfield = true + elseif is_known_call(stmt, isa, compact) + # TODO + continue + elseif is_known_call(stmt, typeassert, compact) + # Canonicalize + # X = typeassert(Y, T)::S + # into + # typeassert(Y, T) + # X = PiNode(Y, S) + # N.B.: Inference may have a more precise type for `S`, than + # just T, but from here on out, there's no problem with + # using just using that. + # so subsequent analysis only has to deal with the latter + # form. TODO: This isn't the best place to put this. + # Also, we should probably have a version of typeassert + # that's defined not to return its value to make life easier + # for the backend. + pi = insert_node_here!(compact, + PiNode(stmt.args[2], compact.result_types[idx]), compact.result_types[idx], + compact.result_lines[idx], true) + compact.ssa_rename[compact.idx-1] = pi + continue + elseif is_known_call(stmt, (===), compact) + c1 = compact_exprtype(compact, stmt.args[2]) + c2 = compact_exprtype(compact, stmt.args[3]) + if !(isa(c1, Const) || isa(c2, Const)) + continue + end + (isa(c1, Const) && isa(c2, Const)) && continue + lift_comparison!(compact, idx, c1, c2, stmt, lifting_cache) + continue + elseif isexpr(stmt, :call) && stmt.args[1] == :unchecked_getfield + is_getfield = true + is_unchecked = true elseif isexpr(stmt, :foreigncall) nccallargs = stmt.args[5] new_preserves = Any[] @@ -203,9 +534,10 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) for (pidx, preserved_arg) in enumerate(old_preserves) intermediaries = IdSet() isa(preserved_arg, SSAValue) || continue - def = walk_to_def(compact, preserved_arg, intermediaries, false) - def !== nothing || continue - (def, defidx) = def + def = simple_walk(compact, preserved_arg, (pi, ssa)->push!(intermediaries, ssa.id)) + isa(def, SSAValue) || continue + defidx = def.id + def = compact[defidx] if is_tuple_call(compact, def) process_immutable_preserve(new_preserves, compact, def) old_preserves[pidx] = nothing @@ -244,61 +576,78 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) isa(field, QuoteNode) && (field = field.value) isa(field, Union{Int, Symbol}) || continue - intermediaries = IdSet() - phi_locs = Tuple{Int, Int}[] - def = walk_to_def(compact, stmt.args[2], intermediaries, is_getfield, phi_locs) - def !== nothing || continue - (def, defidx) = def + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2]))) + isa(struct_typ, DataType) || continue - if !is_getfield - (defidx == 0) && continue - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) - push!(defuse.defs, idx) + def, typeconstraint = stmt.args[2], struct_typ + + if struct_typ.mutable + isa(def, SSAValue) || continue + intermediaries = IdSet() + def = simple_walk(compact, def, (pi, ssa)->push!(intermediaries, ssa.id)) + # Mutable stuff here + isa(def, SSAValue) || continue + mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse())) + if is_setfield + push!(defuse.defs, idx) + else + push!(defuse.uses, idx) + end union!(mid, intermediaries) continue + elseif is_setfield + continue end - # Step 3: Check if the definition we eventually end up at is either - # a tuple(...) call or Expr(:new) and perform replacement. - if is_tuple_call(compact, def) && isa(field, Int) && 1 <= field < length(def.args) - forwarded = def.args[1+field] - elseif isexpr(def, :new) - typ = def.typ - if isa(typ, UnionAll) - typ = unwrap_unionall(typ) - end - isa(typ, DataType) || continue - if typ.mutable - @assert defidx != 0 - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) - push!(defuse.uses, idx) - union!(mid, intermediaries) - continue - end - field = try_compute_fieldidx_expr(typ, stmt) - field === nothing && continue - forwarded = def.args[1+field] - else - obj = compact_exprtype(compact, def) - isa(obj, Const) || continue - obj = obj.val - isimmutable(obj) || continue - field = try_compute_fieldidx_expr(typeof(obj), stmt) - field === nothing && continue - isdefined(obj, field) || continue - val = getfield(obj, field) - is_inlineable_constant(val) || continue - forwarded = quoted(val) + + if isa(def, Union{OldSSAValue, SSAValue}) + def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint) end - # Step 4: Remember any phinodes we need to insert - if !isempty(phi_locs) && isa(forwarded, SSAValue) - # TODO: We have have to use BB ids for phi_locs - # to avoid index invalidation. - push!(insertions, idx) - compact[idx] = FastForward(forwarded, phi_locs) + + visited_phinodes = Any[] + if isa(def, Union{OldSSAValue, SSAValue, NewSSAValue}) && isa(compact[def], PhiNode) + leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes) else - compact[idx] = forwarded + leaves = Any[def] end + + isempty(leaves) && continue + + field = try_compute_fieldidx_expr(struct_typ, stmt) + field === nothing && continue + + r = lift_leaves(compact, stmt, result_t, field, leaves) + r === nothing && continue + lifted_leaves, any_undef = r + + if any_undef + result_t = make_MaybeUndef(result_t) + end + +# @Base.show result_t +# @Base.show stmt +# for (k,v) in lifted_leaves +# @Base.show (k, v) +# if isa(k, AnySSAValue) +# @Base.show compact[k] +# end +# if isa(v, RefValue) && isa(v.x, AnySSAValue) +# @Base.show compact[v.x] +# end +# end + val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2]) + + # Insert the undef check if necessary + if any_undef && !is_unchecked + insert_node!(compact, SSAValue(idx), Nothing, Expr(:undefcheck, :getfield, val)) + end + + global assertion_counter + assertion_counter::Int += 1 + #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true) + #continue + compact[idx] = val end + ir = finish(compact) # Now go through any mutable structs and see which ones we can eliminate for (idx, (intermediaries, defuse)) in defuses @@ -389,25 +738,87 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) ir[SSAValue(use)] = new_expr end end - for idx in insertions - # For non-dominating load-store forward, we may have to insert extra phi nodes - # TODO: Can use the domtree to eliminate unnecessary phis, but ok for now - ff = ir.stmts[idx] - ff === nothing && continue # May have been DCE'd if there were no more uses - ff = ff::FastForward - forwarded = ff.to - if isa(forwarded, SSAValue) - forwarded_typ = ir.types[forwarded.id] - for (pred, pos) in reverse!(ff.phi_locs) - node = PhiNode() - push!(node.edges, pred) - push!(node.values, forwarded) - forwarded = insert_node!(ir, pos, forwarded_typ, node) + ir +end + +function adce_erase!(phi_uses, extra_worklist, compact, idx) + if isa(compact.result[idx], PhiNode) + maybe_erase_unused!(extra_worklist, compact, idx, val->phi_uses[val.id]-=1) + else + maybe_erase_unused!(extra_worklist, compact, idx) + end +end + +function count_uses(stmt, uses) + for ur in userefs(stmt) + if isa(ur[], SSAValue) + uses[ur[].id] += 1 + end + end +end + +function mark_phi_cycles(compact, safe_phis, phi) + worklist = Int[] + push!(worklist, phi) + while !isempty(worklist) + phi = pop!(worklist) + push!(safe_phis, phi) + for ur in userefs(compact.result[phi]) + val = ur[] + isa(val, SSAValue) || continue + isa(compact[val], PhiNode) || continue + (val.id in safe_phis) && continue + push!(worklist, val.id) + end + end +end + +function adce_pass!(ir) + phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) + all_phis = Int[] + compact = IncrementalCompact(ir) + for (idx, stmt) in compact + if isa(stmt, PhiNode) + push!(all_phis, idx) + end + end + non_dce_finish!(compact) + for phi in all_phis + count_uses(compact.result[phi], phi_uses) + end + # Perform simple DCE for unused values + extra_worklist = Int[] + for (idx, nused) in Iterators.enumerate(compact.used_ssas) + idx >= compact.result_idx && break + nused == 0 || continue + adce_erase!(phi_uses, extra_worklist, compact, idx) + end + while !isempty(extra_worklist) + adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + end + # Go back and erase any phi cycles + changed = true + while changed + changed = false + safe_phis = IdSet{Int}() + for phi in all_phis + # Save any phi cycles that have non-phi uses + if compact.used_ssas[phi] - phi_uses[phi] != 0 + mark_phi_cycles(compact, safe_phis, phi) + end + end + for phi in all_phis + if !(phi in safe_phis) + push!(extra_worklist, phi) + end + end + while !isempty(extra_worklist) + if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + changed = true end end - ir.stmts[idx] = forwarded end - ir + complete(compact) end function type_lift_pass!(ir::IRCode) @@ -415,7 +826,8 @@ function type_lift_pass!(ir::IRCode) has_non_type_ctx_uses = IdSet{Int}() lifted_undef = IdDict{Int, Any}() for (idx, stmt) in pairs(ir.stmts) - if stmt isa Expr && (stmt.head === :isdefined || stmt.head === :undefcheck) + stmt isa Expr || continue + if (stmt.head === :isdefined || stmt.head === :undefcheck) val = (stmt.head === :isdefined) ? stmt.args[1] : stmt.args[2] # undef can only show up by being introduced in a phi # node (or an UpsilonNode() argument to a PhiC node), @@ -427,11 +839,11 @@ function type_lift_pass!(ir::IRCode) end continue end - worklist = Tuple{Int, Int, SSAValue, Int}[(val.id, 0, SSAValue(0), 0)] stmt_id = val.id while isa(ir.stmts[stmt_id], PiNode) stmt_id = ir.stmts[stmt_id].val.id end + worklist = Tuple{Int, Int, SSAValue, Int}[(stmt_id, 0, SSAValue(0), 0)] def = ir.stmts[stmt_id] if !isa(def, PhiNode) && !isa(def, PhiCNode) if stmt.head === :isdefined diff --git a/base/compiler/ssair/queries.jl b/base/compiler/ssair/queries.jl index 3ec426b43b931..9b3d58522d8b0 100644 --- a/base/compiler/ssair/queries.jl +++ b/base/compiler/ssair/queries.jl @@ -63,7 +63,7 @@ function abstract_eval_ssavalue(s::SSAValue, src::IncrementalCompact) end function compact_exprtype(compact::IncrementalCompact, @nospecialize(value)) - if isa(value, Union{SSAValue, OldSSAValue}) + if isa(value, Union{SSAValue, OldSSAValue, NewSSAValue}) return types(compact)[value] elseif isa(value, Argument) return compact.ir.argtypes[value.n] diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index d5322dad9b259..067c0f9e7a0c5 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -11,8 +11,20 @@ function Base.show(io::IO, cfg::CFG) end end -print_ssa(io::IO, val) = isa(val, SSAValue) ? Base.print(io, "%$(val.id)") : - isa(val, Argument) ? Base.print(io, "%%$(val.n)") : Base.print(io, val) +function print_ssa(io::IO, val) + if isa(val, SSAValue) + Base.print(io, "%$(val.id)") + elseif isa(val, Argument) + Base.print(io, "%%$(val.n)") + else + try + Base.print(io, val) + catch + Base.print(io, "") + end + end +end + function print_node(io::IO, idx, stmt, used, maxsize; color = true, print_typ=true) if idx in used pad = " "^(maxsize-length(string(idx))) @@ -97,7 +109,6 @@ function Base.show(io::IO, code::IRCode) maxused = maximum(used) maxsize = length(string(maxused)) end - for idx in eachindex(code.stmts) if !isassigned(code.stmts, idx) # This is invalid, but do something useful rather @@ -164,10 +175,19 @@ function Base.show(io::IO, code::IRCode) end typ = code.types[idx] print_ssa_typ = !isa(stmt, PiNode) && idx in used - print_node(io, idx, stmt, used, maxsize, - print_typ=!print_ssa_typ || (isa(stmt, Expr) && typ != stmt.typ)) + try + print_node(io, idx, stmt, used, maxsize, + print_typ=!print_ssa_typ || (isa(stmt, Expr) && typ != stmt.typ)) + catch + print(io, "") + end if print_ssa_typ - Base.printstyled(io, "::$(typ)", color=:red) + typ_str = try + string(typ) + catch + "" + end + Base.printstyled(io, "::$(typ_str)", color=:red) end Base.println(io) end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 36238b3530879..54399c39919f2 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -28,6 +28,10 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, error() end end + elseif isa(op, Union{OldSSAValue, NewSSAValue}) + #@Base.show ir + @verify_error "Left over SSA marker" + error() elseif isa(op, Union{SlotNumber, TypedSlot}) @verify_error "Left over slot detected in converted IR" error() @@ -109,7 +113,7 @@ function verify_ir(ir::IRCode) if isa(stmt, Expr) || isa(stmt, ReturnNode) # TODO: make sure everything has line info if !(stmt isa ReturnNode && !isdefined(stmt, :val)) # not actually a return node, but an unreachable marker if ir.lines[idx] <= 0 - @verify_error "Missing line number information for statement $idx of $ir" + #@verify_error "Missing line number information for statement $idx of $ir" end end end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 910dbfa3953fb..1f55b32089574 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -207,25 +207,25 @@ add_tfunc(ifelse, 3, 3, end return tmerge(x, y) end, 1) -add_tfunc(===, 2, 2, - function (@nospecialize(x), @nospecialize(y)) - if isa(x, Const) && isa(y, Const) - return Const(x.val === y.val) - elseif typeintersect(widenconst(x), widenconst(y)) === Bottom - return Const(false) - elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || - (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) - return Const(true) - elseif isa(x, Conditional) && isa(y, Const) - y.val === false && return Conditional(x.var, x.elsetype, x.vtype) - y.val === true && return x - return x - elseif isa(y, Conditional) && isa(x, Const) - x.val === false && return Conditional(y.var, y.elsetype, y.vtype) - x.val === true && return y - end - return Bool - end, 1) +function egal_tfunc(@nospecialize(x), @nospecialize(y)) + if isa(x, Const) && isa(y, Const) + return Const(x.val === y.val) + elseif typeintersect(widenconst(x), widenconst(y)) === Bottom + return Const(false) + elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || + (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) + return Const(true) + elseif isa(x, Conditional) && isa(y, Const) + y.val === false && return Conditional(x.var, x.elsetype, x.vtype) + y.val === true && return x + return x + elseif isa(y, Conditional) && isa(x, Const) + x.val === false && return Conditional(y.var, y.elsetype, y.vtype) + x.val === true && return y + end + return Bool +end +add_tfunc(===, 2, 2, egal_tfunc, 1) function isdefined_tfunc(@nospecialize(args...)) arg1 = args[1] if isa(arg1, Const) @@ -436,12 +436,46 @@ function const_datatype_getfield_tfunc(sv, fld) return nothing end +function fieldcount_noerror(@nospecialize t) + if t isa UnionAll || t isa Union + t = ccall(:jl_argument_datatype, Any, (Any,), t) + if t === nothing + return nothing + end + t = t::DataType + elseif t == Union{} + return 0 + end + if !(t isa DataType) + return nothing + end + if t.name === NamedTuple.body.body.name + names, types = t.parameters + if names isa Tuple + return length(names) + end + if types isa DataType && types <: Tuple + return fieldcount_noerror(types) + end + abstr = true + else + abstr = t.abstract || (t.name === Tuple.name && isvatuple(t)) + end + if abstr + return nothing + end + return length(t.types) +end + + function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field)) if isa(field, Symbol) field = fieldindex(typ, field, false) field == 0 && return nothing elseif isa(field, Integer) - (1 <= field <= fieldcount(typ)) || return nothing + max_fields = fieldcount_noerror(typ) + max_fields === nothing && return nothing + (1 <= field <= max_fields) || return nothing else return nothing end diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 634d63a7a8083..b339a0e7b6614 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -3327,13 +3327,13 @@ f(x) = yt(x) (or (simple-atom? e) (symbol? e) (and (pair? e) (memq (car e) '(quote inert top core globalref outerref - slot static_parameter boundscheck copyast))))) + slot static_parameter boundscheck))))) (define (valid-ir-rvalue? lhs e) (or (ssavalue? lhs) (valid-ir-argument? e) (and (symbol? lhs) (pair? e) - (memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin))))) + (memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast))))) (define (valid-ir-return? e) ;; returning lambda directly is needed for @generated @@ -3439,7 +3439,7 @@ f(x) = yt(x) (cdr lst)))) (simple? (every (lambda (x) (or (simple-atom? x) (symbol? x) (and (pair? x) - (memq (car x) '(quote inert top core globalref outerref copyast boundscheck))))) + (memq (car x) '(quote inert top core globalref outerref boundscheck))))) lst))) (let loop ((lst lst) (vals '())) @@ -3454,7 +3454,7 @@ f(x) = yt(x) (not (simple-atom? arg)) (not (simple-atom? aval)) (not (and (pair? arg) - (memq (car arg) '(& quote inert top core globalref outerref copyast boundscheck)))) + (memq (car arg) '(& quote inert top core globalref outerref boundscheck)))) (not (and (symbol? aval) ;; function args are immutable and always assigned (memq aval (lam:args lam)))) (not (and (symbol? arg)