From 499f1972e582c190fa007b9b2d005debd6b8d72c Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Mon, 1 Jan 2024 16:44:53 +0100 Subject: [PATCH] handle nested try catches (#119) this is follow-up to the ssa conversion fix for try/catch blocks in the case of nested try catches. we now add arguments to catch branches. --- src/passes/passes.jl | 40 ++++++++++++++++++++-------------------- test/compiler.jl | 18 ++++++++++++++++++ 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 1a8d236..87b3bcc 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -215,14 +215,22 @@ function ssa!(ir::IR) function reaching(b, slot) haskey(defs[b.id], slot) && return defs[b.id][slot] b.id == 1 && return undef - x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false) + x = add_slot_argument!(b, slot) + return x + end + function add_slot_argument!(b, slot) + x = argument!(b, type = slot.type, insert = false) + if !haskey(defs[b.id], slot) + defs[b.id][slot] = x + end + for pred in predecessors(b) if pred.id < current for br in branches(pred, b) push!(br.args, reaching(pred, slot)) end else - push!(get!(todo[pred.id], b.id, Slot[]), slot) + push!(get!(Vector{Slot}, todo[pred.id], b.id), slot) end end @@ -232,6 +240,7 @@ function ssa!(ir::IR) # defs[block(ir, cbr.v).id] contains the defs at the end of # the block, so we use the cached defs in catch_branches instead. for cbr in catch_branches[b.id] + # Get the definition for a slot from a given catch branch. cbr_v = cbr.v stmt = ir[cbr_v] if haskey(cbr.defs, slot) @@ -239,42 +248,33 @@ function ssa!(ir::IR) push!(stmt.expr.args, cbr.defs[slot]) else # Find slot v definition from instruction cbr_v - b = block(ir, cbr_v) - if b.id == 1 + cbr_b = block(ir, cbr_v) + if cbr_b.id == 1 push!(stmt.expr.args, undef) continue end # there is already a def for this slot as an argument to the block # but which was added after the catch branch. - if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable - bdef = defs[b.id][slot] + if haskey(defs[cbr_b.id], slot) && defs[cbr_b.id][slot] isa Variable + bdef = defs[cbr_b.id][slot] (def_b, loc) = ir.defs[bdef.id] - if def_b == b.id && loc < 0 + if def_b == cbr_b.id && loc < 0 push!(stmt.expr.args, bdef) continue end end - # get the slot definition from each predecessors of the block owning the catch 'branch' - new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false) + # add an argument to cbr_b for slot + new_arg = add_slot_argument!(cbr_b, slot) push!(stmt.expr.args, new_arg) - for pred in predecessors(b) - if pred.id < current - for br in branches(pred, b) - push!(br.args, reaching(pred, slot)) - end - else - push!(get!(todo[pred.id], b.id, Slot[]), slot) - end - end end end end return x end - function catchbranch!(v, slot = nothing) + function catchbranch!(v) for handler in handlers cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler))) push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr) @@ -287,7 +287,7 @@ function ssa!(ir::IR) ex = st.expr if isexpr(ex, :(=)) && ex.args[1] isa Slot defs[b.id][ex.args[1]] = rename(ex.args[2]) - catchbranch!(v, ex.args[1]) + catchbranch!(v) delete!(ir, v) elseif isexpr(ex, :enter) push!(handlers, ex.args[1]) diff --git a/test/compiler.jl b/test/compiler.jl index c89fa97..1fce761 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -306,6 +306,19 @@ function f_try_catch7() x end +function f_try_catch8(x) + local z = x + try + z = z + sqrt(x) + z = 1 - log(z) + z = log(z) + catch + z = abs(z) + finally + return z + end +end + @testset "try/catch" begin ir = @code_ir f_try_catch(1.) fir = func(ir) @@ -358,4 +371,9 @@ end ir = @code_ir f_try_catch7() @test func(ir)(nothing) === 1. + + ir = @code_ir f_try_catch8(1.) + fir = func(ir) + @test fir(nothing, 1.) == log(1. - log(2.)) + @test fir(nothing, -1.) == 1. end