diff --git a/base/Base.jl b/base/Base.jl index 0ca13265adc4f..0f0f720a12413 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -483,6 +483,8 @@ include("errorshow.jl") include("initdefs.jl") Filesystem.__postinit__() +include("compilerwrappers.jl") + # worker threads include("threadcall.jl") diff --git a/base/compiler/ssair/basicblock.jl b/base/compiler/ssair/basicblock.jl index 427aae707e664..938098d16487f 100644 --- a/base/compiler/ssair/basicblock.jl +++ b/base/compiler/ssair/basicblock.jl @@ -29,4 +29,9 @@ function BasicBlock(old_bb, stmts) return BasicBlock(stmts, old_bb.preds, old_bb.succs) end +==(a::BasicBlock, b::BasicBlock) = + a.stmts === b.stmts && a.preds == b.preds && a.succs == b.succs +# Note: comparing `.stmts` using `===` instead of `==` since the equivalence class for +# vectors is too coarse when `stmts.stop < stmts.start`. + copy(bb::BasicBlock) = BasicBlock(bb.stmts, copy(bb.preds), copy(bb.succs)) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 289f643a84f3a..dc25bfa608abf 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -11,6 +11,7 @@ struct CFG end copy(c::CFG) = CFG(BasicBlock[copy(b) for b in c.blocks], copy(c.index)) +==(a::CFG, b::CFG) = a.blocks == b.blocks && a.index == b.index function cfg_insert_edge!(cfg::CFG, from::Int, to::Int) # Assumes that this edge does not already exist @@ -28,6 +29,14 @@ function cfg_delete_edge!(cfg::CFG, from::Int, to::Int) nothing end +function cfg_reindex!(cfg::CFG) + resize!(cfg.index, length(cfg.blocks) - 1) + for ibb in 2:length(cfg.blocks) + cfg.index[ibb-1] = first(cfg.blocks[ibb].stmts) + end + return cfg +end + function bb_ordering() lt = (<=) by = x::BasicBlock -> first(x.stmts) @@ -251,6 +260,7 @@ function resize!(stmts::InstructionStream, len) end return stmts end +iterate(is::InstructionStream, st::Int=1) = (st <= length(is)) ? 
(is[st], st + 1) : nothing struct Instruction data::InstructionStream @@ -2024,3 +2034,501 @@ struct InsertBefore{T<:Union{IRCode, IncrementalCompact}} <: Inserter pos::SSAValue end (i::InsertBefore)(newinst::NewInstruction) = insert_node!(i.src, i.pos, newinst) + +### +### CFG manipulation tools +### + +""" + NewBlocksInfo + +Information on basic blocks newly allocated in `allocate_new_blocks!`. See +[`allocate_new_blocks!`](@ref) for explanation on the properties. +""" +struct NewBlocksInfo + positions_nblocks::Vector{Pair{Int,Int}} + block_to_range::IdDict{Int,UnitRange{Int}} + ssachangemap::Vector{Int} + bbchangemap::Vector{Int} +end + +""" + allocate_new_blocks!( + ir::IRCode, + positions_nblocks::Vector{Pair{Int,Int}}, + ) -> info::NewBlocksInfo + +For each `position => nblocks` in `positions_nblocks`, add `nblocks` new "singleton" blocks +(i.e., each BB contains a single dummy instruction) before the statement `position`. + +The caller must ensure that: + + statement_positions = map(first, positions_nblocks) + @assert all(>(0), diff(statement_positions)) + @assert all(1 <= p <= length(ir.stmts) for p in statement_positions) + @assert all(nblocks >= 0 for (_, nblocks) in positions_nblocks) + +Note that this function does not wire up the CFG for newly created BBs. It just +inserts the dummy `GotoNode(0)` at the end of the new singleton BBs and the BB +_before_ (in terms of `ir.cfg.blocks`) it. The predecessors of the BB just +before the newly added singleton BB and the successors of the BB just after the +newly added singleton BB are re-wired. See `allocate_goto_sequence!` for +an example for creating a valid CFG. + +For example, given an `ir` containing: + + #bb + %1 = instruction_1 + %2 = instruction_2 + +`allocate_new_blocks!(ir, [2 => 1])` produces + + #bb′ + %1 = instruction_1 + goto #0 # dummy + #new_bb_1 + goto #0 # dummy + #new_bb_2 + %2 = instruction_2 + +The predecessors of `#bb′` are equivalent to the predecessors of `#bb`. 
The successors of +`#new_bb_2` are equivalent to the successors of `#bb`. + +The returned object `info::NewBlocksInfo` has the following properties: + +* `positions_nblocks`: The second argument. +* `block_to_range`: A mapping from an old basic block index to the indices of + `positions_nblocks` that contains the statements of the old basic block. +* `ssachangemap`: Given an original statement position `iold`, the new statement position is + `ssachangemap[iold]`. +* `bbchangemap`: Given an original basic block id `iold`, the new basic block ID is + `bbchangemap[iold]`. + +Functions [`split_blocks`](@ref), [`split_positions`](@ref), and +[`inserted_blocks`](@ref) can be used to iterate over the newly added basic blocks. These +functions support the following access patterns: + +```julia +blocks::NewBlocksInfo +for sb in split_blocks(blocks) # iterate over split BBs + sb.oldbb::Int + sb.indices::UnitRange{Int} + + for sp in split_positions(sb) # iterate over split positions + sp.oldbb::Int; @assert sp.oldbb == sb.oldbb + sp.prebb::Int # BB before the split position (exclusive) + sp.postbb::Int # BB after the split position (inclusive) + sp.index::Int; @assert sp.index ∈ sb.indices + + position, nblocks = blocks.positions_nblocks[sp.index] + for ib in inserted_blocks(sp) # iterate over newly added BBs + ib.oldbb::Int; @assert ib.oldbb == sb.oldbb + ib.newbb::Int; @assert sp.prebb < ib.newbb < sp.postbb + ib.index::Int; @assert ib.index == sp.index + ib.nth::Int; @assert ib.nth ∈ 1:nblocks + end + end +end +``` +""" +function allocate_new_blocks!(ir::IRCode, positions_nblocks::Vector{Pair{Int,Int}}) + @assert isincreasing(positions_nblocks; by = first) + ssachangemap = Vector{Int}(undef, length(ir.stmts) + length(ir.new_nodes.stmts)) + let iold = 1, inew = 1 + for (pos, nblocks) in positions_nblocks + @assert 1 <= pos <= length(ir.stmts) + @assert nblocks >= 0 + for i in iold:pos-1 + ssachangemap[i] = inew + inew += 1 + end + inew += 1 + nblocks + iold = pos + end + 
for i in iold:length(ssachangemap) + ssachangemap[i] = inew + inew += 1 + end + end + + # For the original basic block index `ibb`, the pairs of position and the number of + # blocks to be inserted can be obtained by: + # indices = block_to_range[ibb]::UnitRange + # positions_nblocks[indices] + block_to_range = IdDict{Int,UnitRange{Int}}() + + # Two maps are used for relabeling BBs: + # * `bbchangemap` maps each old BB index to the index of the BB that includes the *last* + # statement in the old BB. + # * `gotolabelchangemap` maps each old BB index to the index of the BB that includes + # *first* statement in the old BB; i.e., it is used for fixing the labels in the + # goto-like nodes. + bbchangemap = ones(Int, length(ir.cfg.blocks)) + gotolabelchangemap = ones(Int, length(ir.cfg.blocks) + 1) # "+ 1" simplifies the code + newblocks = 0 + + let pre_index = 0, pre_ibb = 0 + for (i, (ipos, nblocks)) in pairs(positions_nblocks) + ibb = block_for_inst(ir.cfg, ipos) + + if pre_ibb != ibb + if pre_ibb != 0 + block_to_range[pre_ibb] = pre_index:i-1 + end + pre_index = i + pre_ibb = ibb + end + + bbchangemap[ibb] += 1 + nblocks + gotolabelchangemap[ibb+1] += 1 + nblocks + newblocks += 1 + nblocks + end + block_to_range[pre_ibb] = pre_index:length(positions_nblocks) + end + _cumsum!(bbchangemap) + _cumsum!(gotolabelchangemap) + + # Insert `newblocks` new blocks: + oldnblocks = length(ir.cfg.blocks) + resize!(ir.cfg.blocks, oldnblocks + newblocks) + # Copy pre-existing blocks: + for iold in oldnblocks:-1:1 + bb = ir.cfg.blocks[iold] + for labels in (bb.preds, bb.succs) + for (i, l) in pairs(labels) + labels[i] = bbchangemap[l] + end + # Note: Some labels that are referring to the split BBs are still incorrect + # at this point. These are copied to the new BBs in the next phase (and thus + # relabeling here is still required). 
+ end + start = ssachangemap[bb.stmts.start] + stop = ssachangemap[bb.stmts.stop] + ir.cfg.blocks[bbchangemap[iold]] = BasicBlock(bb, StmtRange(start, stop)) + end + # Insert new blocks: + for (iold, indices) in block_to_range + ilst = bbchangemap[iold] # using bbchangemap as it's already moved + bblst = ir.cfg.blocks[ilst] + + # Assign `StmtRange`s to the new BBs (edges are handled later) + prefirst = bblst.stmts.start # already moved + inew = get(bbchangemap, iold - 1, 0) + 1 + local oldpos + for i in indices + oldpos, nblocks = positions_nblocks[i] + p = get(ssachangemap, oldpos - 1, 0) + 1 + ir.cfg.blocks[inew] = BasicBlock(StmtRange(min(prefirst, p), p)) + inew += 1 + p += 1 + for _ in 1:nblocks + ir.cfg.blocks[inew] = BasicBlock(StmtRange(p, p)) + inew += 1 + p += 1 + end + @assert p == ssachangemap[oldpos] + prefirst = p + end + + # Handle edges of the "head" and "tail" BBs + ifst = get(bbchangemap, iold - 1, 0) + 1 + bbfst = ir.cfg.blocks[ifst] + for p in bblst.preds + k = findfirst(==(ilst), ir.cfg.blocks[p].succs) + @assert k !== nothing + ir.cfg.blocks[p].succs[k] = ifst + end + copy!(bbfst.preds, bblst.preds) + empty!(bblst.preds) + stmts = StmtRange(ssachangemap[oldpos], last(bblst.stmts)) + ir.cfg.blocks[bbchangemap[iold]] = BasicBlock(bblst, stmts) + @assert !isempty(stmts) + end + for bb in ir.cfg.blocks + @assert !isempty(bb.stmts) + end + cfg_reindex!(ir.cfg) + + on_ssavalue(v) = SSAValue(ssachangemap[v.id]) + on_phi_label(l) = bbchangemap[l] + on_goto_label(l) = gotolabelchangemap[l] + for stmts in (ir.stmts, ir.new_nodes.stmts) + for i in 1:length(stmts) + st = stmts[i] + inst = ssamap(on_ssavalue, st[:inst]) + if inst isa PhiNode + edges = inst.edges::Vector{Int32} + for i in 1:length(edges) + edges[i] = on_phi_label(edges[i]) + end + elseif inst isa GotoNode + inst = GotoNode(on_goto_label(inst.label)) + elseif inst isa GotoIfNot + inst = GotoIfNot(inst.cond, on_goto_label(inst.dest)) + elseif isexpr(inst, :enter) + inst.args[1] = 
on_goto_label(inst.args[1]::Int) + end + st[:inst] = inst + end + end + minpos, _ = positions_nblocks[1] # it's sorted + for (i, info) in pairs(ir.new_nodes.info) + if info.pos >= minpos + ir.new_nodes.info[i] = if info.attach_after + NewNodeInfo(ssachangemap[info.pos], info.attach_after) + else + NewNodeInfo(get(ssachangemap, info.pos - 1, 0) + 1, info.attach_after) + end + end + end + + # Fixup `ir.linetable` before mutating `ir.stmts.lines`: + linetablechangemap = Vector{Int32}(undef, length(ir.linetable)) + fill!(linetablechangemap, 1) + let lines = ir.stmts.line + # Allocate spaces for newly inserted statements + for (pos, nblocks) in positions_nblocks + linetablechangemap[lines[pos]] += 1 + nblocks + end + end + _cumsum!(linetablechangemap) + let newlength = linetablechangemap[end], ilast = newlength + 1 + @assert newlength == length(ir.linetable) + newblocks + resize!(ir.linetable, newlength) + for iold in length(linetablechangemap):-1:1 + inew = linetablechangemap[iold] + oldinfo = ir.linetable[iold] + inlined_at = oldinfo.inlined_at + if inlined_at != 0 + inlined_at = linetablechangemap[inlined_at] + end + newinfo = LineInfoNode( + oldinfo.module, + oldinfo.method, + oldinfo.file, + oldinfo.line, + inlined_at, + ) + for i in inew:ilast-1 + ir.linetable[i] = newinfo + end + ilast = inew + end + end + + # Fixup `ir.stmts.line` + let lines = ir.stmts.line, iold = length(lines), inew = iold + newblocks + + resize!(lines, inew) + for i in length(positions_nblocks):-1:1 + pos, nblocks = positions_nblocks[i] + while pos <= iold + lines[inew] = linetablechangemap[lines[iold]] + iold -= 1 + inew -= 1 + end + for _ in 1:1+nblocks + lines[inew] = linetablechangemap[lines[iold+1]] + inew -= 1 + end + end + @assert inew == iold + end + + # Fixup `ir.new_nodes.stmts.line` + let lines = ir.new_nodes.stmts.line + for i in 1:length(lines) + lines[i] = linetablechangemap[lines[i]] + end + end + + function allocate_stmts!(xs, filler) + n = length(xs) + resize!(xs, length(xs) 
+ newblocks) + for i in n:-1:1 + xs[ssachangemap[i]] = xs[i] + end + for i in 2:n + for j in ssachangemap[i-1]+1:ssachangemap[i]-1 + xs[j] = filler + end + end + for js in (1:ssachangemap[1]-1, ssachangemap[end]+1:length(xs)) + for j in js + xs[j] = filler + end + end + end + + allocate_stmts!(ir.stmts.stmt, GotoNode(0)) # dummy + allocate_stmts!(ir.stmts.type, Any) + allocate_stmts!(ir.stmts.info, NoCallInfo()) + allocate_stmts!(ir.stmts.flag, 0) + + return NewBlocksInfo(positions_nblocks, block_to_range, ssachangemap, bbchangemap) +end + +""" + split_blocks(blocks::NewBlocksInfo) + +Iterate over old basic blocks that are split. + +Each element `sb::SplitBlock` of the iterable returned from `split_blocks` has the following +properties: + +* `blocks::NewBlocksInfo` +* `oldbb::Int`: Old index of a BB that is split. +* `indices`: For each `i` in `indices`, `blocks.positions_nblocks[i]` is the pair + `position => nblocks` that specifies that `nblocks` new blocks are inserted at statement + `position`. + +Use `split_positions(sb::SplitBlock)` to iterate over the statement positions at which +the old basic blocks are split. +""" +function split_blocks end + +""" + split_positions(sb::SplitBlock) + split_positions(blocks::NewBlocksInfo) + +Iterate over the statement positions at which the old basic blocks are split. + +Each element `sp::SplitPosition` of the iterable returned from `split_positions` has the +following properties: + +* `blocks::NewBlocksInfo` +* `index`: `blocks.positions_nblocks[index]` is the pair `position => nblocks` that + specifies that `nblocks` new blocks are inserted at statement `position`. +* `oldbb::Int`: Old index of a BB that is split. +* `prebb::Int`: New index of the BB before the split position (exclusive). +* `postbb::Int`: New index of the BB after the split position (inclusive). + +Use `inserted_blocks(sp::SplitPosition)` to iterate over the newly added "singleton" +basic blocks. 
+""" +function split_positions end + +""" + inserted_blocks(sp::SplitPosition) + inserted_blocks(sb::SplitBlock) + inserted_blocks(blocks::NewBlocksInfo) + +Iterate over the newly added basic blocks. + +Each element `ib::InsertedBlock` of the iterable returned from `inserted_blocks` has the +following properties: + +* `blocks::NewBlocksInfo` +* `index`: `blocks.positions_nblocks[index]` is the pair `position => nblocks` that + specifies that `nblocks` new blocks are inserted at statement `position`. +* `oldbb::Int`: Old index of a BB that is split. +* `newbb::Int`: New index of this BB. +* `nth::Int`: This BB is the `nth` BB at this split position. +""" +function inserted_blocks end + +struct SplitBlock + blocks::NewBlocksInfo + oldbb::Int + indices::UnitRange{Int} +end + +struct SplitPosition + blocks::NewBlocksInfo + oldbb::Int + prebb::Int + postbb::Int + index::Int +end + +struct InsertedBlock + blocks::NewBlocksInfo + oldbb::Int + newbb::Int + index::Int + nth::Int +end + +new_head_bb(sb::SplitBlock) = get(sb.blocks.bbchangemap, sb.oldbb - 1, 0) + 1 +new_tail_bb(sb::SplitBlock) = sb.blocks.bbchangemap[sb.oldbb] + +function split_blocks(blocks::NewBlocksInfo) + oldblocks = sort!(collect(Int, keys(blocks.block_to_range))) + Iterators.map(oldblocks) do oldbb + indices = blocks.block_to_range[oldbb] + SplitBlock(blocks, oldbb, indices) + end +end + +struct SplitPositionIterator + split::SplitBlock +end + +function iterate( + it::SplitPositionIterator, + (index, prevbb) = (first(it.split.indices), new_head_bb(it.split)), +) + (; blocks, oldbb, indices) = it.split + index > last(indices) && return nothing + _pos, nblocks = blocks.positions_nblocks[index] + postbb = prevbb + 1 + nblocks + sp = SplitPosition(blocks, oldbb, prevbb, postbb, index) + return (sp, (index + 1, postbb)) +end + +split_positions(sb::SplitBlock) = SplitPositionIterator(sb) +split_positions(blocks::NewBlocksInfo) = + Iterators.flatmap(split_positions, split_blocks(blocks)) + +function 
inserted_blocks(sp::SplitPosition) + (; blocks, index) = sp + _pos, nblocks = blocks.positions_nblocks[index] + Iterators.map(1:nblocks) do nth + InsertedBlock(sp.blocks, sp.oldbb, sp.prebb + nth, sp.index, nth) + end +end + +inserted_blocks(x) = Iterators.flatmap(inserted_blocks, split_positions(x)) + +""" + allocate_goto_sequence!(ir::IRCode, positions_nblocks) -> info::NewBlocksInfo + +Add new BBs using `allocate_new_blocks!(ir, positions_nblocks)` and then connect them by +"no-op" `GotoNode` that jumps to the next BB. Unlike `allocate_new_blocks!`, this function +results in an IR with valid CFG. + +See `allocate_new_blocks!` for the preconditions on `positions_nblocks`. + +For example, given an `ir` containing: + + #bb + %1 = instruction_1 + %2 = instruction_2 + +`allocate_goto_sequence!(ir, [2 => 1])` produces + + #bb′ + %1 = instruction_1 + goto #new_bb_1 + #new_bb_1 + goto #new_bb_2 + #new_bb_2 + %2 = instruction_2 +""" +function allocate_goto_sequence!(ir::IRCode, positions_nblocks) + blocks = allocate_new_blocks!(ir, positions_nblocks) + function set_goto(ibb1::Int) + ibb2 = ibb1 + 1 + b1 = ir.cfg.blocks[ibb1] + @assert ir.stmts.stmt[last(b1.stmts)] === GotoNode(0) + ir.stmts.stmt[last(b1.stmts)] = GotoNode(ibb2) + cfg_insert_edge!(ir.cfg, ibb1, ibb2) + end + for sp in split_positions(blocks) + set_goto(sp.prebb) + for block in inserted_blocks(sp) + set_goto(block.newbb) + end + end + return blocks +end diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index d8ca4d9551656..2c976a71f22d9 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -42,6 +42,29 @@ end anymap(f::Function, a::Array{Any,1}) = Any[ f(a[i]) for i in 1:length(a) ] +function _cumsum!(ys) + isempty(ys) && return ys + acc = ys[1] + for i in 2:length(ys) + acc += ys[i] + ys[i] = acc + end + return ys +end + +function isincreasing(xs; by = identity) + y = iterate(xs) + y === nothing && return true + x1, state = y + while true + y = iterate(xs, state) + y 
=== nothing && return true + x2, state = y + isless(by(x1), by(x2)) || return false + x1 = x2 + end +end + ########### # scoping # ########### diff --git a/base/compilerwrappers.jl b/base/compilerwrappers.jl new file mode 100644 index 0000000000000..c74716804f51f --- /dev/null +++ b/base/compilerwrappers.jl @@ -0,0 +1,53 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +_qualifiedname(a, b::Symbol) = :($a.$b) +function _qualifiedname(a, b::Expr) + @assert isexpr(b, :., 2) + s = (b.args[2]::QuoteNode).value::Symbol + return _qualifiedname(_qualifiedname(a, b.args[1]), s) +end + +for T in [:BasicBlock, :CFG, :IRCode] + @eval copy(x::Core.Compiler.$T) = Core.Compiler.copy(x) +end + +for T in [:BasicBlock, :CFG] + @eval ==(a::Core.Compiler.$T, b::Core.Compiler.$T) = Core.Compiler.:(==)(a, b) +end + +for name in [ + :InstructionStream, + :UseRefIterator, + :DominatedBlocks, + :SplitPositionIterator, + :AbstractDict, + :AbstractSet, + :ValueIterator, + :Generator, + :(Iterators.Filter), + :(Iterators.Flatten), +] + T = _qualifiedname(:(Core.Compiler), name) + @eval iterate(xs::$T) = Core.Compiler.iterate(xs) + @eval iterate(xs::$T, state) = Core.Compiler.iterate(xs, state) + @eval IteratorSize(::Type{<:$T}) = SizeUnknown() +end + +for T in [ + :IRCode, + :InstructionStream, + :Instruction, + :StmtRange, + :UseRef, + :UnitRange, + :AbstractDict, +] + @eval getindex(xs::Core.Compiler.$T, args...) = Core.Compiler.getindex(xs, args...) + @eval setindex!(xs::Core.Compiler.$T, x, args...) = + Core.Compiler.setindex!(xs, x, args...) 
+ @eval size(xs::Core.Compiler.$T) = Core.Compiler.size(xs) + @eval length(xs::Core.Compiler.$T) = Core.Compiler.length(xs) +end + +first(r::Core.Compiler.StmtRange) = Core.Compiler.first(r) +last(r::Core.Compiler.StmtRange) = Core.Compiler.last(r) diff --git a/base/show.jl b/base/show.jl index eb9f7bcece49d..1a104546fbbea 100644 --- a/base/show.jl +++ b/base/show.jl @@ -2804,15 +2804,6 @@ module IRShow import .Compiler: IRCode, CFG, scan_ssa_use!, isexpr, compute_basic_blocks, block_for_inst, IncrementalCompact, Effects, ALWAYS_TRUE, ALWAYS_FALSE - Base.getindex(r::Compiler.StmtRange, ind::Integer) = Compiler.getindex(r, ind) - Base.size(r::Compiler.StmtRange) = Compiler.size(r) - Base.first(r::Compiler.StmtRange) = Compiler.first(r) - Base.last(r::Compiler.StmtRange) = Compiler.last(r) - Base.length(is::Compiler.InstructionStream) = Compiler.length(is) - Base.iterate(is::Compiler.InstructionStream, st::Int=1) = (st <= Compiler.length(is)) ? (is[st], st + 1) : nothing - Base.getindex(is::Compiler.InstructionStream, idx::Int) = Compiler.getindex(is, idx) - Base.getindex(node::Compiler.Instruction, fld::Symbol) = Compiler.getindex(node, fld) - Base.getindex(ir::IRCode, ssa::SSAValue) = Compiler.getindex(ir, ssa) include("compiler/ssair/show.jl") const __debuginfo = Dict{Symbol, Any}( diff --git a/test/compiler/ssair.jl b/test/compiler/ssair.jl index 491638f7596d9..f5efdd3be7450 100644 --- a/test/compiler/ssair.jl +++ b/test/compiler/ssair.jl @@ -19,6 +19,11 @@ function make_ci(code) return ci end +function verify_ircode(ir) + Compiler.verify_ir(ir) + Compiler.verify_linetable(ir.linetable) +end + # TODO: this test is broken #let code = Any[ # GotoIfNot(SlotNumber(2), 4), @@ -685,3 +690,316 @@ end end end end + +### +### CFG manipulation tools +### + +function allocate_branches!(ir::Compiler.IRCode, positions_nbranches) + blocks = Core.Compiler.allocate_goto_sequence!( + ir, + [p => 2n for (p, n) in positions_nbranches], + ) + for sp in 
Core.Compiler.split_positions(blocks) + for (n, block) in enumerate(Compiler.inserted_blocks(sp)) + if isodd(n) + ibb1 = block.newbb + ibb3 = ibb1 + 2 + b1 = ir.cfg.blocks[ibb1] + ir.stmts.stmt[last(b1.stmts)] = GotoIfNot(false, ibb3) + Core.Compiler.cfg_insert_edge!(ir.cfg, ibb1, ibb3) + end + end + end + return blocks +end + +inserted_block_ranges(info) = [sp.prebb:sp.postbb for sp in Compiler.split_positions(info)] + +""" + inlineinfo(ir::IRCode, line::Integer) + +Extract inlining information at `line` that does not depend on `ir.linetable`. +""" +inlineinfo(ir, line) = + map(Base.IRShow.compute_loc_stack(ir.linetable, Int32(line))) do i + node = ir.linetable[i] + (; node.method, node.file, node.line) + end + +""" + check_linetable(ir, ir0, info) + +Test `ir.linetable` invariances of `allocate_new_blocks!` where the arguments are used as in + +```julia +ir = copy(ir0) +info = Compiler.allocate_new_blocks!(ir, ...) +``` + +or some equivalent code. +""" +function check_linetable(ir, ir0, info) + (; positions_nblocks) = info + function splabel((; index)) + origpos, _ = positions_nblocks[index] + "Statement $origpos (= first(positions_nblocks[$index]))" + end + iblabel((; nth)) = "$nth-th inserted block at this split point" + @testset "Goto nodes reflect original statement lines" begin + @testset "$(splabel(sp))" for sp in Compiler.split_positions(info) + origpos, _ = positions_nblocks[sp.index] + moved = ir.stmts[first(ir.cfg.blocks[sp.postbb].stmts)] + + @testset "Moved statement has the same inline info stack" begin + orig = ir0.stmts[origpos][:line] + @test inlineinfo(ir, moved[:line]) == inlineinfo(ir0, orig) + end + + @testset "Pre-split block" begin + goto = ir.stmts[last(ir.cfg.blocks[sp.prebb].stmts)] + @test goto[:line] == moved[:line] + end + + @testset "$(iblabel(ib))" for ib in Compiler.inserted_blocks(sp) + goto = ir.stmts[last(ir.cfg.blocks[ib.newbb].stmts)] + @test goto[:line] == moved[:line] + end + end + end +end + +function single_block(x) + x+2x 
+end + +#= +Input: + + #1 + %1 = $inst1 _ + %2 = $inst2 `-- split before %2 + return %2 + +Output: + + #1 + %1 = $inst1 + goto #2 + #2 + %3 = $inst2 + return %3 +=# +@testset "Split a block in two" begin + ir0, _ = only(Base.code_ircode(single_block, (Float64,), optimize_until = "compact 1")) + @test length(ir0.stmts) == 3 + + ir = copy(ir0) + info = Compiler.allocate_goto_sequence!(ir, [2 => 0]) + verify_ircode(ir) + @test inserted_block_ranges(info) == [1:2] + @test ir.cfg == CFG( + [ + BasicBlock(Compiler.StmtRange(1, 2), Int[], [2]), + BasicBlock(Compiler.StmtRange(3, 4), [1], Int[]), + ], + [3], + ) + b1, _ = ir.cfg.blocks + @test ir.stmts[last(b1.stmts)][:inst] == GotoNode(2) + check_linetable(ir, ir0, info) +end + +#= +Input: + + #1 + %1 = $inst1 _ + %2 = $inst2 `-- split before %2 and insert two blocks + return %2 + +Output: + + #1 + %1 = $inst1 + goto #2 + #2 + goto #4 if not false + #3 + goto #4 + #4 + %5 = $inst2 + return %5 +=# +@testset "Add one branch (two new blocks) to a single-block IR" begin + ir0, _ = only(Base.code_ircode(single_block, (Float64,), optimize_until = "compact 1")) + @test length(ir0.stmts) == 3 + + ir = copy(ir0) + info = allocate_branches!(ir, [2 => 1]) + # TODO: Access to undef in linetable + # verify_ircode(ir) + @test inserted_block_ranges(info) == [1:4] + @test ir.cfg == CFG( + [ + BasicBlock(Compiler.StmtRange(1, 2), Int[], [2]) + BasicBlock(Compiler.StmtRange(3, 3), [1], [3, 4]) + BasicBlock(Compiler.StmtRange(4, 4), [2], [4]) + BasicBlock(Compiler.StmtRange(5, 6), [3, 2], Int[]) + ], + [3, 4, 5], + ) + (b1, b2, b3, _) = ir.cfg.blocks + @test ir.stmts.stmt[last(b1.stmts)] == GotoNode(2) + @test ir.stmts.stmt[last(b2.stmts)] == GotoIfNot(false, 4) + @test ir.stmts.stmt[last(b3.stmts)] == GotoNode(4) + check_linetable(ir, ir0, info) +end + +#= +Input: + + #1 _ + %1 = $inst1 `-- split before %1 and insert one block + goto #2 + #2 + %3 = $inst2 _ + return %3 `-- split before %4 (`return %3`) and insert one block + +Output: + + 
#1 + goto #2 + #2 + goto #3 + #3 + %1 = $inst1 + goto #4 + #4 + %5 = $inst2 + goto #5 + #5 + goto #6 + #6 + return %5 + +This transformation is testing inserting multiple basic blocks at once. It also tests that +inserting at boundary locations work. +=# +@testset "Insert two more blocks to a two-block IR" begin + ir0, _ = only(Base.code_ircode(single_block, (Float64,), optimize_until = "compact 1")) + @test length(ir0.stmts) == 3 + + @testset "Split a block in two" begin + info = Compiler.allocate_goto_sequence!(ir0, [2 => 0]) + verify_ircode(ir0) + @test inserted_block_ranges(info) == [1:2] + end + + ir = copy(ir0) + info = Compiler.allocate_goto_sequence!(ir, [1 => 1, 4 => 1]) + @test length(ir.stmts) == 8 + @test inserted_block_ranges(info) == [1:3, 4:6] + verify_ircode(ir) + @test ir.cfg == CFG( + [ + BasicBlock(Compiler.StmtRange(1, 1), Int[], [2]) + BasicBlock(Compiler.StmtRange(2, 2), [1], [3]) + BasicBlock(Compiler.StmtRange(3, 4), [2], [4]) + BasicBlock(Compiler.StmtRange(5, 6), [3], [5]) + BasicBlock(Compiler.StmtRange(7, 7), [4], [6]) + BasicBlock(Compiler.StmtRange(8, 8), [5], Int[]) + ], + [2, 3, 5, 7, 8], + ) + @test [ir.stmts.stmt[last(b.stmts)] for b in ir.cfg.blocks[1:end-1]] == GotoNode.(2:6) + check_linetable(ir, ir0, info) +end + +#= +Input: + + #1 + %1 = $inst1 + %3 = new_instruction() _ + %2 = $inst2 `-- split before %2 + return %2 + +Output: + + #1 + %1 = $inst1 + %2 = new_instruction() # in the pre-split-point BB + goto #2 + #2 + %4 = $inst2 + return %4 +=# +@testset "Split a block of a pre-compact IR (attach before)" begin + ir0, _ = only(Base.code_ircode(single_block, (Float64,), optimize_until = "compact 1")) + @test length(ir0.stmts) == 3 + + st = Expr(:call, :new_instruction) + Compiler.insert_node!(ir0, 2, Compiler.NewInstruction(st, Any)) + + ir = copy(ir0) + info = allocate_branches!(ir, [2 => 0]) + @test inserted_block_ranges(info) == [1:2] + verify_ircode(ir) + check_linetable(ir, ir0, info) + + ir = Core.Compiler.compact!(ir) + 
verify_ircode(ir) + @test ir.cfg == CFG( + [ + BasicBlock(Compiler.StmtRange(1, 3), Int[], [2]) + BasicBlock(Compiler.StmtRange(4, 5), [1], Int[]) + ], + [4], + ) + @test ir.stmts[2][:inst] == st +end + +#= +Input: + + #1 + %1 = $inst1 _ + %2 = $inst2 `-- split before %2 + %3 = new_instruction() + return %2 + +Output: + + #1 + %1 = $inst1 + goto #2 + #2 + %3 = $inst2 + %4 = new_instruction() # in the post-split-point BB + return %3 +=# +@testset "Split a block of a pre-compact IR (attach after)" begin + ir0, _ = only(Base.code_ircode(single_block, (Float64,), optimize_until = "compact 1")) + @test length(ir0.stmts) == 3 + + st = Expr(:call, :new_instruction) + Compiler.insert_node!(ir0, 2, Compiler.NewInstruction(st, Any), true) + + ir = copy(ir0) + info = allocate_branches!(ir, [2 => 0]) + @test inserted_block_ranges(info) == [1:2] + verify_ircode(ir) + check_linetable(ir, ir0, info) + + ir = Core.Compiler.compact!(ir) + verify_ircode(ir) + @test ir.cfg == CFG( + [ + BasicBlock(Compiler.StmtRange(1, 2), Int[], [2]) + BasicBlock(Compiler.StmtRange(3, 5), [1], Int[]) + ], + [3], + ) + @test ir.stmts[4][:inst] == st +end