diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 18dfb82161653..d0fbaa11d6bc2 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -17,7 +17,7 @@ include("compiler/ssair/passes.jl") include("compiler/ssair/inlining2.jl") include("compiler/ssair/verify.jl") include("compiler/ssair/legacy.jl") -@isdefined(Base) && include("compiler/ssair/show.jl") +#@isdefined(Base) && include("compiler/ssair/show.jl") function normalize_expr(stmt::Expr) if stmt.head === :gotoifnot @@ -159,14 +159,20 @@ end function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode}, sv::OptimizationState) ir = just_construct_ssa(ci, copy(ci.code), nargs, linetable) + #@Base.show ("after_construct", ir) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) #@timeit "verify 1" verify_ir(ir) @timeit "Inlining" ir = ssa_inlining_pass!(ir, linetable, sv) #@timeit "verify 2" verify_ir(ir) @timeit "domtree 2" domtree = construct_domtree(ir.cfg) + ir = compact!(ir) + #@Base.show ("before_sroa", ir) @timeit "SROA" ir = getfield_elim_pass!(ir, domtree) - @timeit "compact 2" ir = compact!(ir) + #@Base.show ir.new_nodes + #@Base.show ("after_sroa", ir) + ir = adce_pass!(ir) + #@Base.show ("after_adce", ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) #@Base.show ir diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 97a202855a1f6..0b92e5f0b615f 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -413,6 +413,12 @@ mutable struct IncrementalCompact # This could be Stateful, but bootstrapping doesn't like that perm::Vector{Int} new_nodes_idx::Int + # This supports insertion while compacting + new_new_nodes::Vector{NewNode} # New nodes that were before the compaction point at insertion time + # TODO: Switch these two to a min-heap of some sort + pending_nodes::Vector{NewNode} # New nodes that were after the compaction point at insertion time + pending_perm::Vector{Int} + # State idx::Int result_idx::Int active_result_bb::Int @@ -427,7 +433,12 @@ mutable struct IncrementalCompact used_ssas = fill(0, new_len) ssa_rename = Any[SSAValue(i) for i = 1:new_len] late_fixup = Vector{Int}() - return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, 1, 1, 1) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, 1, 1) end # For inlining @@ -437,8 +448,14 @@ mutable struct IncrementalCompact ssa_rename = Any[SSAValue(i) for i = 1:new_len] used_ssas = fill(0, new_len) late_fixup = Vector{Int}() - return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, parent.result_bbs, - ssa_rename, parent.used_ssas, late_fixup, perm, 1, 1, result_offset, parent.active_result_bb) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, + parent.result_bbs, ssa_rename, parent.used_ssas, + late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, result_offset, parent.active_result_bb) end end @@ -455,7 +472,30 @@ function getindex(compact::IncrementalCompact, idx::Int) end end +function getindex(compact::IncrementalCompact, ssa::SSAValue) + @assert ssa.id < compact.result_idx + return compact.result[ssa.id] +end + +function getindex(compact::IncrementalCompact, ssa::OldSSAValue) + id = ssa.id + if id <= length(compact.ir.stmts) + return compact.ir.stmts[id] + end + id -= length(compact.ir.stmts) + if id <= length(compact.ir.new_nodes) + return compact.ir.new_nodes[id].node + end + id -= length(compact.ir.new_nodes) + return compact.pending_nodes[id].node +end + +function getindex(compact::IncrementalCompact, ssa::NewSSAValue) + return compact.new_new_nodes[ssa.id].node +end + function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) + needs_late_fixup = isa(v, NewSSAValue) if isa(v, SSAValue) compact.used_ssas[v.id] += 1 else @@ -463,44 +503,116 @@ function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) val = ops[] if isa(val, SSAValue) compact.used_ssas[val.id] += 1 + elseif isa(val, NewSSAValue) + needs_late_fixup = true end end end + needs_late_fixup +end + +function resort_pending!(compact) + sort!(compact.pending_perm, DEFAULT_STABLE, Order.By(x->compact.pending_nodes[x].pos)) +end + +function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @nospecialize(val), attach_after::Bool=false) + if isa(before, SSAValue) + if before.id < compact.result_idx + count_added_node!(compact, val) + line = compact.result_lines[before.id] + push!(compact.new_new_nodes, NewNode(before.id, attach_after, typ, val, line)) + return NewSSAValue(length(compact.new_new_nodes)) + else + line = compact.ir.lines[before.id] + push!(compact.pending_nodes, NewNode(before.id, attach_after, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + end + elseif isa(before, OldSSAValue) + pos = before.id + if pos > length(compact.ir.stmts) + #@assert attach_after + entry = compact.pending_nodes[pos - length(compact.ir.stmts) - length(compact.ir.new_nodes)] + pos, attach_after = entry.pos, entry.attach_after + end + line = compact.ir.lines[pos] + push!(compact.pending_nodes, NewNode(pos, attach_after, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + elseif isa(before, NewSSAValue) + before_entry = compact.new_new_nodes[before.id] + push!(compact.new_new_nodes, NewNode(before_entry.pos, attach_after, typ, val, before_entry.line)) + return NewSSAValue(length(compact.new_new_nodes)) + else + error("Unsupported") + end end -function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int) +function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int, reverse_affinity=false) if compact.result_idx > length(compact.result) @assert compact.result_idx == length(compact.result) + 1 resize!(compact, compact.result_idx) end + refinish = false + if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts) && reverse_affinity + compact.active_result_bb -= 1 + refinish = true + end compact.result[compact.result_idx] = val compact.result_types[compact.result_idx] = typ compact.result_lines[compact.result_idx] = ltable_idx compact.result_flags[compact.result_idx] = 0x00 - count_added_node!(compact, val) + if count_added_node!(compact, val) + push!(compact.late_fixup, compact.result_idx) + end ret = SSAValue(compact.result_idx) compact.result_idx += 1 + refinish && finish_current_bb!(compact) ret end function getindex(view::TypesView, v::OldSSAValue) - return view.ir.ir.types[v.id] + id = v.id + if id <= length(view.ir.ir.types) + return view.ir.ir.types[id] + end + id -= length(view.ir.ir.types) + if id <= length(view.ir.ir.new_nodes) + return view.ir.ir.new_nodes[id].typ + end + id -= length(view.ir.ir.new_nodes) + return view.ir.pending_nodes[id].typ +end + +function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) + @assert idx.id < compact.result_idx + (compact.result[idx.id] === v) && return + # Kill count for current uses + for ops in userefs(compact.result[idx.id]) + val = ops[] + if isa(val, SSAValue) + @assert compact.used_ssas[val.id] >= 1 + compact.used_ssas[val.id] -= 1 + end + end + compact.result[idx.id] = v + # Add count for new use + if count_added_node!(compact, v) + push!(compact.late_fixup, idx.id) + end end function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int) if idx < compact.result_idx - (compact.result[idx] === v) && return - # Kill count for current uses - for ops in userefs(compact.result[idx]) - val = ops[] - if isa(val, SSAValue) - @assert compact.used_ssas[val.id] >= 1 - compact.used_ssas[val.id] -= 1 - end - end - compact.result[idx] = v - # Add count for new use - count_added_node!(compact, v) + compact[SSAValue(idx)] = v else compact.ir.stmts[idx] = v end @@ -509,10 +621,14 @@ end function getindex(view::TypesView, idx) isa(idx, SSAValue) && (idx = idx.id) - ir = view.ir - if isa(ir, IncrementalCompact) - if idx < ir.result_idx - return ir.result_types[idx] + if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx + return view.ir.result_types[idx] + else + ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir + if idx <= length(ir.types) + return ir.types[idx] + else + return ir.new_nodes[idx - length(ir.types)].typ end ir = ir.ir end @@ -528,47 +644,103 @@ function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int}) return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm)) end +function getindex(view::TypesView, idx::NewSSAValue) + @assert isa(view.ir, IncrementalCompact) + compact = view.ir + compact.new_new_nodes[idx.id].typ +end + function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}, processed_idx::Int, result_idx::Int, - ssa_rename::Vector{Any}, used_ssas::Vector{Int}) + ssa_rename::Vector{Any}, used_ssas::Vector{Int}, + do_rename_ssa::Bool) values = Vector{Any}(undef, length(old_values)) for i = 1:length(old_values) isassigned(old_values, i) || continue val = old_values[i] if isa(val, SSAValue) + if do_rename_ssa + if val.id > processed_idx + push!(late_fixup, result_idx) + val = OldSSAValue(val.id) + else + val = renumber_ssa2(val, ssa_rename, used_ssas, do_rename_ssa) + end + else + used_ssas[val.id] += 1 + end + elseif isa(val, OldSSAValue) if val.id > processed_idx push!(late_fixup, result_idx) - val = OldSSAValue(val.id) else - val = renumber_ssa!(val, ssa_rename, true, used_ssas) + # Always renumber these. do_rename_ssa applies only to actual SSAValues + val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, true) end + elseif isa(val, NewSSAValue) + push!(late_fixup, result_idx) end values[i] = val end return values end +function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssa::Vector{Int}, do_rename_ssa::Bool) + id = val.id + if id > length(ssanums) + return val + end + if do_rename_ssa + val = ssanums[id] + end + if isa(val, SSAValue) && used_ssa !== nothing + used_ssa[val.id] += 1 + end + return val +end + +function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssa::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool) + urs = userefs(stmt) + for op in urs + val = op[] + if isa(val, OldSSAValue) || isa(val, NewSSAValue) + push!(late_fixup, result_idx) + end + if isa(val, SSAValue) + val = renumber_ssa2(val, ssanums, used_ssa, do_rename_ssa) + end + if isa(val, OldSSAValue) || isa(val, NewSSAValue) + push!(late_fixup, result_idx) + end + op[] = val + end + return urs[] +end + function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{Any}, late_fixup::Vector{Int}, used_ssas::Vector{Int}, @nospecialize(stmt), - idx::Int, processed_idx::Int) + idx::Int, processed_idx::Int, do_rename_ssa::Bool) ssa_rename[idx] = SSAValue(result_idx) if stmt === nothing ssa_rename[idx] = stmt + elseif isa(stmt, OldSSAValue) + ssa_rename[idx] = ssa_rename[stmt.id] elseif isa(stmt, GotoNode) || isa(stmt, GlobalRef) result[result_idx] = stmt result_idx += 1 elseif isa(stmt, Expr) || isa(stmt, PiNode) || isa(stmt, GotoIfNot) || isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) - result[result_idx] = renumber_ssa!(stmt, ssa_rename, true, used_ssas) + result[result_idx] = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa) result_idx += 1 elseif isa(stmt, PhiNode) - result[result_idx] = PhiNode(stmt.edges, process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas)) + result[result_idx] = PhiNode(stmt.edges, process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)) result_idx += 1 elseif isa(stmt, PhiCNode) - result[result_idx] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas)) + result[result_idx] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)) result_idx += 1 elseif isa(stmt, SSAValue) # identity assign, replace uses of this ssa value with its result - stmt = ssa_rename[stmt.id] + if do_rename_ssa + stmt = ssa_rename[stmt.id] + end ssa_rename[idx] = stmt else # Constant assign, replace uses of this ssa value with its result @@ -576,9 +748,10 @@ function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{ end return result_idx end -function process_node!(compact::IncrementalCompact, result_idx::Int, @nospecialize(stmt), idx::Int, processed_idx::Int) +function process_node!(compact::IncrementalCompact, result_idx::Int, @nospecialize(stmt), idx::Int, processed_idx::Int, do_rename_ssa::Bool) return process_node!(compact.result, result_idx, compact.ssa_rename, - compact.late_fixup, compact.used_ssas, stmt, idx, processed_idx) + compact.late_fixup, compact.used_ssas, stmt, idx, processed_idx, + do_rename_ssa) end function resize!(compact::IncrementalCompact, nnewnodes) @@ -621,12 +794,12 @@ function attach_after_stmt_after(compact::IncrementalCompact, idx::Int) entry.pos == idx && entry.attach_after end -function process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) +function process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, do_rename_ssa) old_result_idx = compact.result_idx bb = compact.ir.cfg.blocks[active_bb] compact.result_types[old_result_idx] = new_node_entry.typ compact.result_lines[old_result_idx] = new_node_entry.line - result_idx = process_node!(compact, old_result_idx, new_node_entry.node, new_idx, idx) + result_idx = process_node!(compact, old_result_idx, new_node_entry.node, new_idx, idx, do_rename_ssa) compact.result_idx = result_idx # If this instruction has reverse affinity and we were at the end of a basic block, # finish it now. @@ -651,14 +824,21 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) compact.new_nodes_idx += 1 new_node_entry = compact.ir.new_nodes[new_idx] new_idx += length(compact.ir.stmts) - return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) + return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, true) + elseif !isempty(compact.pending_perm) && + (entry = compact.pending_nodes[compact.pending_perm[1]]; + entry.attach_after ? entry.pos == idx - 1 : entry.pos == idx) + new_idx = popfirst!(compact.pending_perm) + new_node_entry = compact.pending_nodes[new_idx] + new_idx += length(compact.ir.stmts) + length(compact.ir.new_nodes) + return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb, false) end # This will get overwritten in future iterations if # result_idx is not, incremented, but that's ok and expected compact.result_types[old_result_idx] = compact.ir.types[idx] compact.result_lines[old_result_idx] = compact.ir.lines[idx] compact.result_flags[old_result_idx] = compact.ir.flags[idx] - result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx) + result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, true) stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx] compact.result_idx = result_idx if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx) @@ -673,22 +853,29 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb) end -function maybe_erase_unused!(extra_worklist, compact, idx) - effect_free = stmt_effect_free(compact.result[idx], compact, compact.ir.mod) +function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing) + stmt = compact.result[idx] + stmt === nothing && return false + effect_free = stmt_effect_free(stmt, compact, compact.ir.mod) if effect_free - for ops in userefs(compact.result[idx]) + for ops in userefs(stmt) val = ops[] - if isa(val, SSAValue) + # If the pass we ran inserted new nodes, it's possible for those + # to be outside our used_ssas count. + if isa(val, SSAValue) && val.id <= length(compact.used_ssas) if compact.used_ssas[val.id] == 1 if val.id < idx push!(extra_worklist, val.id) end end compact.used_ssas[val.id] -= 1 + callback(val) end end compact.result[idx] = nothing + return true end + return false end function fixup_phinode_values!(compact, old_values) @@ -701,47 +888,85 @@ function fixup_phinode_values!(compact, old_values) if isa(val, SSAValue) compact.used_ssas[val.id] += 1 end + elseif isa(val, NewSSAValue) + val = SSAValue(length(compact.result) + val.id) end values[i] = val end values end +function fixup_node(compact, @nospecialize(stmt)) + if isa(stmt, PhiNode) + return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, PhiCNode) + return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, NewSSAValue) + return SSAValue(length(compact.result) + stmt.id) + else + urs = userefs(stmt) + urs === () && return stmt + for ur in urs + val = ur[] + if isa(val, NewSSAValue) + ur[] = SSAValue(length(compact.result) + val.id) + end + end + return urs[] + end +end + function just_fixup!(compact) for idx in compact.late_fixup stmt = compact.result[idx] - if isa(stmt, PhiNode) - compact.result[idx] = PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) - else - stmt = stmt::PhiCNode - compact.result[idx] = PhiCNode(fixup_phinode_values!(compact, stmt.values)) + new_stmt = fixup_node(compact, stmt) + (stmt !== new_stmt) && (compact.result[idx] = new_stmt) + end + for idx in 1:length(compact.new_new_nodes) + node = compact.new_new_nodes[idx] + new_stmt = fixup_node(compact, node.node) + if node.node !== new_stmt + compact.new_new_nodes[idx] = NewNode( + node.pos, node.attach_after, node.typ, + new_stmt, node.line) end end end -function finish(compact::IncrementalCompact) - just_fixup!(compact) - # Record this somewhere? - result_idx = compact.result_idx - resize!(compact.result, result_idx-1) - resize!(compact.result_types, result_idx-1) - resize!(compact.result_lines, result_idx-1) - resize!(compact.result_flags, result_idx-1) - bb = compact.result_bbs[end] - compact.result_bbs[end] = BasicBlock(bb, - StmtRange(first(bb.stmts), result_idx-1)) +function simple_dce!(compact) # Perform simple DCE for unused values extra_worklist = Int[] for (idx, nused) in Iterators.enumerate(compact.used_ssas) - idx >= result_idx && break + idx >= compact.result_idx && break nused == 0 || continue maybe_erase_unused!(extra_worklist, compact, idx) end while !isempty(extra_worklist) maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist)) end +end + +function non_dce_finish!(compact::IncrementalCompact) + result_idx = compact.result_idx + resize!(compact.result, result_idx-1) + resize!(compact.result_types, result_idx-1) + resize!(compact.result_lines, result_idx-1) + resize!(compact.result_flags, result_idx-1) + just_fixup!(compact) + bb = compact.result_bbs[end] + compact.result_bbs[end] = BasicBlock(bb, + StmtRange(first(bb.stmts), result_idx-1)) +end + +function finish(compact::IncrementalCompact) + non_dce_finish!(compact) + simple_dce!(compact) + complete(compact) +end + +function complete(compact) cfg = CFG(compact.result_bbs, Int[first(bb.stmts) for bb in compact.result_bbs[2:end]]) - return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, NewNode[]) + return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes) end function compact!(code::IRCode) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index a7e84e2a9486c..b3e32c436eeb2 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -101,73 +101,110 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks, du, phin end end -function walk_to_def(compact::IncrementalCompact, @nospecialize(def), intermediaries=IdSet{Int}(), allow_phinode::Bool=true, phi_locs=Tuple{Int, Int}[]) - if !isa(def, SSAValue) - return (def, 0) +function simple_walk(compact::IncrementalCompact, defssa::Union{SSAValue, NewSSAValue, OldSSAValue}, pi_callback=(pi,idx)->nothing) + while true + if isa(defssa, OldSSAValue) && already_inserted(compact, defssa) + rename = compact.ssa_rename[defssa.id] + if isa(rename, Union{SSAValue, OldSSAValue, NewSSAValue}) + defssa = rename + continue + end + return rename + end + def = compact[defssa] + if isa(def, PiNode) + pi_callback(def, defssa) + if isa(def.val, SSAValue) + if isa(defssa, OldSSAValue) && !already_inserted(compact, defssa) + defssa = OldSSAValue(def.val.id) + else + defssa = def.val + end + else + return def.val + end + elseif isa(def, Union{SSAValue, OldSSAValue, NewSSAValue}) + pi_callback(def, defssa) + defssa = def + elseif isa(def, Union{PhiNode, PhiCNode, Expr, GlobalRef}) + return defssa + else + return def + end end - orig_defidx = defidx = def.id +end + +function simple_walk_constraint(compact, defidx, typeconstraint = types(compact)[defidx]) + callback = (pi, _)->isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, pi.typ)) + def = simple_walk(compact, defidx, callback) + def, typeconstraint +end + +""" + walk_to_defs(compact, val, intermediaries) + +Starting at `val` walk use-def chains to get all the leaves feeding into +this val (pruning those leaves rules out by path conditions). +""" +function walk_to_defs(compact, defssa, typeconstraint, visited_phinodes=Any[]) # Step 2: Figure out what the struct is defined as - def = compact[defidx] - typeconstraint = types(compact)[defidx] + def = compact[defssa] ## Track definitions through PiNode/PhiNode found_def = false ## Track which PhiNodes, SSAValue intermediaries ## we forwarded through. - while true - if isa(def, PiNode) - push!(intermediaries, defidx) - typeconstraint = typeintersect(typeconstraint, def.typ) - if isa(def.val, SSAValue) - defidx = def.val.id - def = compact[defidx] - else - def = def.val - end - continue - elseif isa(def, FastForward) - append!(phi_locs, def.phi_locs) - def = def.to - elseif isa(def, PhiNode) - # For now, we don't track setfields structs through phi nodes - allow_phinode || break - push!(intermediaries, defidx) + visited = IdSet{Any}() + worklist = Tuple{Any, Any}[] + leaves = Any[] + push!(worklist, (defssa, typeconstraint)) + while !isempty(worklist) + defssa, typeconstraint = pop!(worklist) + push!(visited, defssa) + def = compact[defssa] + if isa(def, PhiNode) + push!(visited_phinodes, defssa) possible_predecessors = let def=def, typeconstraint=typeconstraint collect(Iterators.filter(1:length(def.edges)) do n isassigned(def.values, n) || return false - value = def.values[n] - edge_typ = widenconst(compact_exprtype(compact, value)) + val = def.values[n] + if isa(defssa, OldSSAValue) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + edge_typ = widenconst(compact_exprtype(compact, val)) return typeintersect(edge_typ, typeconstraint) !== Union{} end) end - # For now, only look at unique predecessors - if length(possible_predecessors) == 1 - n = possible_predecessors[1] + for n in possible_predecessors pred = def.edges[n] val = def.values[n] - if isa(val, SSAValue) - push!(phi_locs, (pred, defidx)) - defidx = val.id - def = compact[defidx] - elseif def == val + if isa(defssa, OldSSAValue) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + if isa(val, Union{SSAValue, OldSSAValue, NewSSAValue}) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + if isa(new_def, Union{SSAValue, OldSSAValue, NewSSAValue}) + if !(new_def in visited) + push!(worklist, (new_def, new_constraint)) + end + continue + end + val = new_def + end + if def == val # This shouldn't really ever happen, but # patterns like this can occur in dead code, # so bail out. break else - def = val + push!(leaves, val) end continue end - elseif isa(def, SSAValue) - push!(intermediaries, defidx) - defidx = def.id - def = compact[def.id] - continue + else + push!(leaves, defssa) end - found_def = true - break end - found_def ? (def, defidx) : nothing + leaves end function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) @@ -178,24 +215,318 @@ function process_immutable_preserve(new_preserves::Vector{Any}, compact::Increme end end -struct FastForward - to::SSAValue - phi_locs::Vector{Tuple{Int, Int}} +function already_inserted(compact::IncrementalCompact, old::OldSSAValue) + id = old.id + if id < length(compact.ir.stmts) + return id < compact.idx + end + id -= length(compact.ir.stmts) + if id < length(compact.ir.new_nodes) + error() + end + id -= length(compact.ir.new_nodes) + @assert id <= length(compact.pending_nodes) + return !(id in compact.pending_perm) +end + +function is_pending(compact::IncrementalCompact, old::OldSSAValue) + return old.id > length(compact.ir.stmts) + length(compact.ir.new_nodes) +end + +function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt), + @nospecialize(result_t), field::Int, leaves::Vector{Any}) + # For every leaf, the lifted value + lifted_leaves = IdDict{Any, Any}() + maybe_undef = false + for leaf in leaves + leaf_key = leaf + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + if isa(leaf, OldSSAValue) && already_inserted(compact, leaf) + leaf = compact.ssa_rename[leaf.id] + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + leaf = simple_walk(compact, leaf) + end + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + def = compact[leaf] + else + def = leaf + end + else + def = compact[leaf] + end + if is_tuple_call(compact.ir, def) && isa(field, Int) && 1 <= field < length(def.args) + lifted = def.args[1+field] + if isa(leaf, OldSSAValue) && isa(lifted, SSAValue) + lifted = OldSSAValue(lifted.id) + end + lifted_leaves[leaf_key] = RefValue{Any}(lifted) + continue + elseif isexpr(def, :new) + typ = def.typ + if isa(typ, UnionAll) + typ = unwrap_unionall(typ) + end + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + field = try_compute_fieldidx_expr(typ, stmt) + field === nothing && return nothing + if length(def.args) < 1 + field + ftyp = fieldtype(typ, field) + if !isbits(ftyp) + # On this branch, this will be a guaranteed UndefRefError. + # We use the regular undef mechanic to lift this to a boolean slot + maybe_undef = true + lifted_leaves[leaf_key] = nothing + continue + end + return nothing + # Expand the Expr(:new) to include it's element Expr(:new) nodes up until the one we want + compact[leaf] = nothing + for i = (length(def.args) + 1):(1+field) + ftyp = fieldtype(typ, i - 1) + isbits(ftyp) || return nothing + push!(def.args, insert_node!(compact, leaf, result_t, Expr(:new, ftyp))) + end + compact[leaf] = def + end + lifted = def.args[1+field] + if isa(leaf, OldSSAValue) && isa(lifted, SSAValue) + lifted = OldSSAValue(lifted.id) + end + lifted_leaves[leaf_key] = RefValue{Any}(lifted) + continue + else + typ = compact_exprtype(compact, leaf) + if !isa(typ, Const) + # If the leaf is an old ssa value, insert a getfield here + # We will revisit this getfield later when compaction gets + # to the appropriate point. + # N.B.: This can be a bit dangerous because it can lead to + # infinite loops if we accidentally insert a node just ahead + # of where we are + if isa(leaf, OldSSAValue) && (isa(field, Int) || isa(field, Symbol)) + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + # If there's the potential for an undefref error on access, we cannot insert a getfield + if field > typ.ninitialized && !isbits(fieldtype(typ, field)) + return nothing + lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, make_MaybeUndef(result_t), Expr(:call, :unchecked_getfield, SSAValue(leaf.id), field), true)) + maybe_undef = true + else + return nothing + lifted_leaves[leaf] = RefValue{Any}(insert_node!(compact, leaf, result_t, Expr(:call, getfield, SSAValue(leaf.id), field), true)) + end + continue + end + return nothing + end + leaf = typ.val + # Fall through to below + end + elseif isa(leaf, QuoteNode) + leaf = leaf.value + elseif isa(leaf, Union{Argument, Expr}) + return nothing + end + isimmutable(leaf) || return nothing + isdefined(leaf, field) || return nothing + val = getfield(leaf, field) + is_inlineable_constant(val) || return nothing + lifted_leaves[leaf_key] = RefValue{Any}(quoted(val)) + end + lifted_leaves, maybe_undef +end + +make_MaybeUndef(typ) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ) + +const AnySSAValue = Union{SSAValue, OldSSAValue, NewSSAValue} + +function lift_comparison!(compact::IncrementalCompact, idx::Int, + @nospecialize(c1), @nospecialize(c2), stmt::Expr, + lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}) + if isa(c1, Const) + cmp = c1 + typeconstraint = widenconst(c2) + val = stmt.args[3] + else + cmp = c2 + typeconstraint = widenconst(c1) + val = stmt.args[2] + end + + is_type_only = isdefined(typeof(cmp), :instance) + + if isa(val, Union{OldSSAValue, SSAValue}) + val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) + end + + visited_phinodes = Any[] + if isa(val, Union{OldSSAValue, SSAValue, NewSSAValue}) && isa(compact[val], PhiNode) + leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes) + else + leaves = [val] + end + + # Let's check if we evaluate the comparison for each one of the leaves + lifted_leaves = IdDict{Any, Any}() + for leaf in leaves + r = egal_tfunc(compact_exprtype(compact, leaf), cmp) + if isa(r, Const) + lifted_leaves[leaf] = RefValue{Any}(r.val) + else + # TODO: In some cases it might be profitable to hoist the === + # here. + return + end + end + + lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, lifted_leaves, val) + + #global assertion_counter + #assertion_counter::Int += 1 + #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true) + #return + compact[idx] = lifted_val +end + +struct LiftedPhi + ssa::AnySSAValue + node::Any + need_argupdate::Bool end -function getfield_elim_pass!(ir::IRCode, domtree::DomTree) +function perform_lifting!(compact::IncrementalCompact, + visited_phinodes::Vector{Any}, @nospecialize(cache_key), + lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, + @nospecialize(result_t), lifted_leaves::IdDict{Any, Any}, @nospecialize(stmt_val)) + reverse_mapping = IdDict{Any, Any}(ssa => id for (id, ssa) in enumerate(visited_phinodes)) + + # Insert PhiNodes + lifted_phis = LiftedPhi[] + for item in visited_phinodes + if (item, cache_key) in keys(lifting_cache) + ssa = lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)] + push!(lifted_phis, LiftedPhi(ssa, compact[ssa], false)) + continue + end + n = PhiNode() + ssa = insert_node!(compact, item, result_t, n) + lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)] = ssa + push!(lifted_phis, LiftedPhi(ssa, n, true)) + end + + # Fix up arguments + for (old_node_ssa, lf) in zip(visited_phinodes, lifted_phis) + old_node = compact[old_node_ssa] + new_node = lf.node + lf.need_argupdate || continue + for i = 1:length(old_node.edges) + edge = old_node.edges[i] + isassigned(old_node.values, i) || continue + val = old_node.values[i] + orig_val = val + if isa(old_node_ssa, OldSSAValue) && !is_pending(compact, old_node_ssa) && !already_inserted(compact, old_node_ssa) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) + val = simple_walk(compact, val) + end + if val in keys(lifted_leaves) + push!(new_node.edges, edge) + lifted_val = lifted_leaves[val] + if lifted_val === nothing + resize!(new_node.values, length(new_node.values)+1) + continue + end + lifted_val = lifted_val.x + if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue}) + lifted_val = simple_walk(compact, lifted_val) + end + push!(new_node.values, lifted_val) + elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping) + push!(new_node.edges, edge) + push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa) + else + # Probably ignored by path condition, skip this + end + end + end + + for lf in lifted_phis + count_added_node!(compact, lf.node) + end + + # Fixup the stmt itself + if isa(stmt_val, Union{SSAValue, OldSSAValue}) + stmt_val = simple_walk(compact, stmt_val) + end + if stmt_val in keys(lifted_leaves) + stmt_val = lifted_leaves[stmt_val] + @assert stmt_val !== nothing + stmt_val = stmt_val.x + else + isa(stmt_val, Union{SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping) + stmt_val = lifted_phis[reverse_mapping[stmt_val]].ssa + end + + return stmt_val +end + +assertion_counter = 0 +function getfield_elim_pass!(ir::IRCode, domtree) compact = IncrementalCompact(ir) insertions = Vector{Any}() defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}() + lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + revisit_worklist = Int[] + #ndone, nmax = 0, 200 for (idx, stmt) in compact isa(stmt, Expr) || continue - is_getfield = false + #ndone >= nmax && continue + #ndone += 1 + result_t = compact_exprtype(compact, SSAValue(idx)) + is_getfield = is_setfield = false is_ccall = false + is_unchecked = false # Step 1: Check whether the statement we're looking at is a getfield/setfield! if is_known_call(stmt, setfield!, compact) is_setfield = true elseif is_known_call(stmt, getfield, compact) is_getfield = true + elseif is_known_call(stmt, isa, compact) + # TODO + continue + elseif is_known_call(stmt, typeassert, compact) + # Canonicalize + # X = typeassert(Y, T)::S + # into + # typeassert(Y, T) + # X = PiNode(Y, S) + # N.B.: Inference may have a more precise type for `S`, than + # just T, but from here on out, there's no problem with + # using just using that. + # so subsequent analysis only has to deal with the latter + # form. TODO: This isn't the best place to put this. + # Also, we should probably have a version of typeassert + # that's defined not to return its value to make life easier + # for the backend. + pi = insert_node_here!(compact, + PiNode(stmt.args[2], compact.result_types[idx]), compact.result_types[idx], + compact.result_lines[idx], true) + compact.ssa_rename[compact.idx-1] = pi + continue + elseif is_known_call(stmt, (===), compact) + c1 = compact_exprtype(compact, stmt.args[2]) + c2 = compact_exprtype(compact, stmt.args[3]) + if !(isa(c1, Const) || isa(c2, Const)) + continue + end + (isa(c1, Const) && isa(c2, Const)) && continue + lift_comparison!(compact, idx, c1, c2, stmt, lifting_cache) + continue + elseif isexpr(stmt, :call) && stmt.args[1] == :unchecked_getfield + is_getfield = true + is_unchecked = true elseif isexpr(stmt, :foreigncall) nccallargs = stmt.args[5] new_preserves = Any[] @@ -203,9 +534,10 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) for (pidx, preserved_arg) in enumerate(old_preserves) intermediaries = IdSet() isa(preserved_arg, SSAValue) || continue - def = walk_to_def(compact, preserved_arg, intermediaries, false) - def !== nothing || continue - (def, defidx) = def + def = simple_walk(compact, preserved_arg, (pi, ssa)->push!(intermediaries, ssa.id)) + isa(def, SSAValue) || continue + defidx = def.id + def = compact[defidx] if is_tuple_call(compact, def) process_immutable_preserve(new_preserves, compact, def) old_preserves[pidx] = nothing @@ -244,61 +576,78 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) isa(field, QuoteNode) && (field = field.value) isa(field, Union{Int, Symbol}) || continue - intermediaries = IdSet() - phi_locs = Tuple{Int, Int}[] - def = walk_to_def(compact, stmt.args[2], intermediaries, is_getfield, phi_locs) - def !== nothing || continue - (def, defidx) = def + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2]))) + isa(struct_typ, DataType) || continue - if !is_getfield - (defidx == 0) && continue - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) - push!(defuse.defs, idx) + def, typeconstraint = stmt.args[2], struct_typ + + if struct_typ.mutable + isa(def, SSAValue) || continue + intermediaries = IdSet() + def = simple_walk(compact, def, (pi, ssa)->push!(intermediaries, ssa.id)) + # Mutable stuff here + isa(def, SSAValue) || continue + mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse())) + if is_setfield + push!(defuse.defs, idx) + else + push!(defuse.uses, idx) + end union!(mid, intermediaries) continue + elseif is_setfield + continue end - # Step 3: Check if the definition we eventually end up at is either - # a tuple(...) call or Expr(:new) and perform replacement. - if is_tuple_call(compact, def) && isa(field, Int) && 1 <= field < length(def.args) - forwarded = def.args[1+field] - elseif isexpr(def, :new) - typ = def.typ - if isa(typ, UnionAll) - typ = unwrap_unionall(typ) - end - isa(typ, DataType) || continue - if typ.mutable - @assert defidx != 0 - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) - push!(defuse.uses, idx) - union!(mid, intermediaries) - continue - end - field = try_compute_fieldidx_expr(typ, stmt) - field === nothing && continue - forwarded = def.args[1+field] - else - obj = compact_exprtype(compact, def) - isa(obj, Const) || continue - obj = obj.val - isimmutable(obj) || continue - field = try_compute_fieldidx_expr(typeof(obj), stmt) - field === nothing && continue - isdefined(obj, field) || continue - val = getfield(obj, field) - is_inlineable_constant(val) || continue - forwarded = quoted(val) + + if isa(def, Union{OldSSAValue, SSAValue}) + def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint) end - # Step 4: Remember any phinodes we need to insert - if !isempty(phi_locs) && isa(forwarded, SSAValue) - # TODO: We have have to use BB ids for phi_locs - # to avoid index invalidation. - push!(insertions, idx) - compact[idx] = FastForward(forwarded, phi_locs) + + visited_phinodes = Any[] + if isa(def, Union{OldSSAValue, SSAValue, NewSSAValue}) && isa(compact[def], PhiNode) + leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes) else - compact[idx] = forwarded + leaves = Any[def] end + + isempty(leaves) && continue + + field = try_compute_fieldidx_expr(struct_typ, stmt) + field === nothing && continue + + r = lift_leaves(compact, stmt, result_t, field, leaves) + r === nothing && continue + lifted_leaves, any_undef = r + + if any_undef + result_t = make_MaybeUndef(result_t) + end + +# @Base.show result_t +# @Base.show stmt +# for (k,v) in lifted_leaves +# @Base.show (k, v) +# if isa(k, AnySSAValue) +# @Base.show compact[k] +# end +# if isa(v, RefValue) && isa(v.x, AnySSAValue) +# @Base.show compact[v.x] +# end +# end + val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2]) + + # Insert the undef check if necessary + if any_undef && !is_unchecked + insert_node!(compact, SSAValue(idx), Nothing, Expr(:undefcheck, :getfield, val)) + end + + global assertion_counter + assertion_counter::Int += 1 + #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true) + #continue + compact[idx] = val end + ir = finish(compact) # Now go through any mutable structs and see which ones we can eliminate for (idx, (intermediaries, defuse)) in defuses @@ -389,25 +738,87 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) ir[SSAValue(use)] = new_expr end end - for idx in insertions - # For non-dominating load-store forward, we may have to insert extra phi nodes - # TODO: Can use the domtree to eliminate unnecessary phis, but ok for now - ff = ir.stmts[idx] - ff === nothing && continue # May have been DCE'd if there were no more uses - ff = ff::FastForward - forwarded = ff.to - if isa(forwarded, SSAValue) - forwarded_typ = ir.types[forwarded.id] - for (pred, pos) in reverse!(ff.phi_locs) - node = PhiNode() - push!(node.edges, pred) - push!(node.values, forwarded) - forwarded = insert_node!(ir, pos, forwarded_typ, node) + ir +end + +function adce_erase!(phi_uses, extra_worklist, compact, idx) + if isa(compact.result[idx], PhiNode) + maybe_erase_unused!(extra_worklist, compact, idx, val->phi_uses[val.id]-=1) + else + maybe_erase_unused!(extra_worklist, compact, idx) + end +end + +function count_uses(stmt, uses) + for ur in userefs(stmt) + if isa(ur[], SSAValue) + uses[ur[].id] += 1 + end + end +end + +function mark_phi_cycles(compact, safe_phis, phi) + worklist = Int[] + push!(worklist, phi) + while !isempty(worklist) + phi = pop!(worklist) + push!(safe_phis, phi) + for ur in userefs(compact.result[phi]) + val = ur[] + isa(val, SSAValue) || continue + isa(compact[val], PhiNode) || continue + (val.id in safe_phis) && continue + push!(worklist, val.id) + end + end +end + +function adce_pass!(ir) + phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) + all_phis = Int[] + compact = IncrementalCompact(ir) + for (idx, stmt) in compact + if isa(stmt, PhiNode) + push!(all_phis, idx) + end + end + non_dce_finish!(compact) + for phi in all_phis + count_uses(compact.result[phi], phi_uses) + end + # Perform simple DCE for unused values + extra_worklist = Int[] + for (idx, nused) in Iterators.enumerate(compact.used_ssas) + idx >= compact.result_idx && break + nused == 0 || continue + adce_erase!(phi_uses, extra_worklist, compact, idx) + end + while !isempty(extra_worklist) + adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + end + # Go back and erase any phi cycles + changed = true + while changed + changed = false + safe_phis = IdSet{Int}() + for phi in all_phis + # Save any phi cycles that have non-phi uses + if compact.used_ssas[phi] - phi_uses[phi] != 0 + mark_phi_cycles(compact, safe_phis, phi) + end + end + for phi in all_phis + if !(phi in safe_phis) + push!(extra_worklist, phi) + end + end + while !isempty(extra_worklist) + if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + changed = true end end - ir.stmts[idx] = forwarded end - ir + complete(compact) end function type_lift_pass!(ir::IRCode) @@ -415,7 +826,8 @@ function type_lift_pass!(ir::IRCode) has_non_type_ctx_uses = IdSet{Int}() lifted_undef = IdDict{Int, Any}() for (idx, stmt) in pairs(ir.stmts) - if stmt isa Expr && (stmt.head === :isdefined || stmt.head === :undefcheck) + stmt isa Expr || continue + if (stmt.head === :isdefined || stmt.head === :undefcheck) val = (stmt.head === :isdefined) ? stmt.args[1] : stmt.args[2] # undef can only show up by being introduced in a phi # node (or an UpsilonNode() argument to a PhiC node), @@ -427,11 +839,11 @@ function type_lift_pass!(ir::IRCode) end continue end - worklist = Tuple{Int, Int, SSAValue, Int}[(val.id, 0, SSAValue(0), 0)] stmt_id = val.id while isa(ir.stmts[stmt_id], PiNode) stmt_id = ir.stmts[stmt_id].val.id end + worklist = Tuple{Int, Int, SSAValue, Int}[(stmt_id, 0, SSAValue(0), 0)] def = ir.stmts[stmt_id] if !isa(def, PhiNode) && !isa(def, PhiCNode) if stmt.head === :isdefined diff --git a/base/compiler/ssair/queries.jl b/base/compiler/ssair/queries.jl index 3ec426b43b931..9b3d58522d8b0 100644 --- a/base/compiler/ssair/queries.jl +++ b/base/compiler/ssair/queries.jl @@ -63,7 +63,7 @@ function abstract_eval_ssavalue(s::SSAValue, src::IncrementalCompact) end function compact_exprtype(compact::IncrementalCompact, @nospecialize(value)) - if isa(value, Union{SSAValue, OldSSAValue}) + if isa(value, Union{SSAValue, OldSSAValue, NewSSAValue}) return types(compact)[value] elseif isa(value, Argument) return compact.ir.argtypes[value.n] diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index d5322dad9b259..067c0f9e7a0c5 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -11,8 +11,20 @@ function Base.show(io::IO, cfg::CFG) end end -print_ssa(io::IO, val) = isa(val, SSAValue) ? Base.print(io, "%$(val.id)") : - isa(val, Argument) ? Base.print(io, "%%$(val.n)") : Base.print(io, val) +function print_ssa(io::IO, val) + if isa(val, SSAValue) + Base.print(io, "%$(val.id)") + elseif isa(val, Argument) + Base.print(io, "%%$(val.n)") + else + try + Base.print(io, val) + catch + Base.print(io, "") + end + end +end + function print_node(io::IO, idx, stmt, used, maxsize; color = true, print_typ=true) if idx in used pad = " "^(maxsize-length(string(idx))) @@ -97,7 +109,6 @@ function Base.show(io::IO, code::IRCode) maxused = maximum(used) maxsize = length(string(maxused)) end - for idx in eachindex(code.stmts) if !isassigned(code.stmts, idx) # This is invalid, but do something useful rather @@ -164,10 +175,19 @@ function Base.show(io::IO, code::IRCode) end typ = code.types[idx] print_ssa_typ = !isa(stmt, PiNode) && idx in used - print_node(io, idx, stmt, used, maxsize, - print_typ=!print_ssa_typ || (isa(stmt, Expr) && typ != stmt.typ)) + try + print_node(io, idx, stmt, used, maxsize, + print_typ=!print_ssa_typ || (isa(stmt, Expr) && typ != stmt.typ)) + catch + print(io, "") + end if print_ssa_typ - Base.printstyled(io, "::$(typ)", color=:red) + typ_str = try + string(typ) + catch + "" + end + Base.printstyled(io, "::$(typ_str)", color=:red) end Base.println(io) end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 36238b3530879..54399c39919f2 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -28,6 +28,10 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, error() end end + elseif isa(op, Union{OldSSAValue, NewSSAValue}) + #@Base.show ir + @verify_error "Left over SSA marker" + error() elseif isa(op, Union{SlotNumber, TypedSlot}) @verify_error "Left over slot detected in converted IR" error() @@ -109,7 +113,7 @@ function verify_ir(ir::IRCode) if isa(stmt, Expr) || isa(stmt, ReturnNode) # TODO: make sure everything has line info if !(stmt isa ReturnNode && !isdefined(stmt, :val)) # not actually a return node, but an unreachable marker if ir.lines[idx] <= 0 - @verify_error "Missing line number information for statement $idx of $ir" + #@verify_error "Missing line number information for statement $idx of $ir" end end end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 910dbfa3953fb..1f55b32089574 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -207,25 +207,25 @@ add_tfunc(ifelse, 3, 3, end return tmerge(x, y) end, 1) -add_tfunc(===, 2, 2, - function (@nospecialize(x), @nospecialize(y)) - if isa(x, Const) && isa(y, Const) - return Const(x.val === y.val) - elseif typeintersect(widenconst(x), widenconst(y)) === Bottom - return Const(false) - elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || - (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) - return Const(true) - elseif isa(x, Conditional) && isa(y, Const) - y.val === false && return Conditional(x.var, x.elsetype, x.vtype) - y.val === true && return x - return x - elseif isa(y, Conditional) && isa(x, Const) - x.val === false && return Conditional(y.var, y.elsetype, y.vtype) - x.val === true && return y - end - return Bool - end, 1) +function egal_tfunc(@nospecialize(x), @nospecialize(y)) + if isa(x, Const) && isa(y, Const) + return Const(x.val === y.val) + elseif typeintersect(widenconst(x), widenconst(y)) === Bottom + return Const(false) + elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || + (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) + return Const(true) + elseif isa(x, Conditional) && isa(y, Const) + y.val === false && return Conditional(x.var, x.elsetype, x.vtype) + y.val === true && return x + return x + elseif isa(y, Conditional) && isa(x, Const) + x.val === false && return Conditional(y.var, y.elsetype, y.vtype) + x.val === true && return y + end + return Bool +end +add_tfunc(===, 2, 2, egal_tfunc, 1) function isdefined_tfunc(@nospecialize(args...)) arg1 = args[1] if isa(arg1, Const) @@ -436,12 +436,46 @@ function const_datatype_getfield_tfunc(sv, fld) return nothing end +function fieldcount_noerror(@nospecialize t) + if t isa UnionAll || t isa Union + t = ccall(:jl_argument_datatype, Any, (Any,), t) + if t === nothing + return nothing + end + t = t::DataType + elseif t == Union{} + return 0 + end + if !(t isa DataType) + return nothing + end + if t.name === NamedTuple.body.body.name + names, types = t.parameters + if names isa Tuple + return length(names) + end + if types isa DataType && types <: Tuple + return fieldcount_noerror(types) + end + abstr = true + else + abstr = t.abstract || (t.name === Tuple.name && isvatuple(t)) + end + if abstr + return nothing + end + return length(t.types) +end + + function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field)) if isa(field, Symbol) field = fieldindex(typ, field, false) field == 0 && return nothing elseif isa(field, Integer) - (1 <= field <= fieldcount(typ)) || return nothing + max_fields = fieldcount_noerror(typ) + max_fields === nothing && return nothing + (1 <= field <= max_fields) || return nothing else return nothing end diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 634d63a7a8083..b339a0e7b6614 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -3327,13 +3327,13 @@ f(x) = yt(x) (or (simple-atom? e) (symbol? e) (and (pair? e) (memq (car e) '(quote inert top core globalref outerref - slot static_parameter boundscheck copyast))))) + slot static_parameter boundscheck))))) (define (valid-ir-rvalue? lhs e) (or (ssavalue? lhs) (valid-ir-argument? e) (and (symbol? lhs) (pair? e) - (memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin))))) + (memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast))))) (define (valid-ir-return? e) ;; returning lambda directly is needed for @generated @@ -3439,7 +3439,7 @@ f(x) = yt(x) (cdr lst)))) (simple? (every (lambda (x) (or (simple-atom? x) (symbol? x) (and (pair? x) - (memq (car x) '(quote inert top core globalref outerref copyast boundscheck))))) + (memq (car x) '(quote inert top core globalref outerref boundscheck))))) lst))) (let loop ((lst lst) (vals '())) @@ -3454,7 +3454,7 @@ f(x) = yt(x) (not (simple-atom? arg)) (not (simple-atom? aval)) (not (and (pair? arg) - (memq (car arg) '(& quote inert top core globalref outerref copyast boundscheck)))) + (memq (car arg) '(& quote inert top core globalref outerref boundscheck)))) (not (and (symbol? aval) ;; function args are immutable and always assigned (memq aval (lam:args lam)))) (not (and (symbol? arg)