Skip to content

Commit

Permalink
handle nested try catches (#119)
Browse files Browse the repository at this point in the history
this is follow-up to the ssa conversion fix for try/catch blocks in the
case of nested try catches. we now add arguments to catch branches.
  • Loading branch information
Pangoraw authored Jan 1, 2024
1 parent 8f5f50e commit 499f197
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/passes/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,22 @@ function ssa!(ir::IR)
function reaching(b, slot)
haskey(defs[b.id], slot) && return defs[b.id][slot]
b.id == 1 && return undef
x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false)
x = add_slot_argument!(b, slot)
return x
end
function add_slot_argument!(b, slot)
x = argument!(b, type = slot.type, insert = false)
if !haskey(defs[b.id], slot)
defs[b.id][slot] = x
end

for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), slot)
push!(get!(Vector{Slot}, todo[pred.id], b.id), slot)
end
end

Expand All @@ -232,49 +240,41 @@ function ssa!(ir::IR)
# defs[block(ir, cbr.v).id] contains the defs at the end of
# the block, so we use the cached defs in catch_branches instead.
for cbr in catch_branches[b.id]
# Get the definition for a slot from a given catch branch.
cbr_v = cbr.v
stmt = ir[cbr_v]
if haskey(cbr.defs, slot)
# Slot v was defined at catch branch
push!(stmt.expr.args, cbr.defs[slot])
else
# Find slot v definition from instruction cbr_v
b = block(ir, cbr_v)
if b.id == 1
cbr_b = block(ir, cbr_v)
if cbr_b.id == 1
push!(stmt.expr.args, undef)
continue
end

# there is already a def for this slot as an argument to the block
# but which was added after the catch branch.
if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable
bdef = defs[b.id][slot]
if haskey(defs[cbr_b.id], slot) && defs[cbr_b.id][slot] isa Variable
bdef = defs[cbr_b.id][slot]
(def_b, loc) = ir.defs[bdef.id]
if def_b == b.id && loc < 0
if def_b == cbr_b.id && loc < 0
push!(stmt.expr.args, bdef)
continue
end
end

# get the slot definition from each predecessors of the block owning the catch 'branch'
new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false)
# add an argument to cbr_b for slot
new_arg = add_slot_argument!(cbr_b, slot)
push!(stmt.expr.args, new_arg)
for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), slot)
end
end
end
end
end

return x
end
function catchbranch!(v, slot = nothing)
function catchbranch!(v)
for handler in handlers
cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler)))
push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr)
Expand All @@ -287,7 +287,7 @@ function ssa!(ir::IR)
ex = st.expr
if isexpr(ex, :(=)) && ex.args[1] isa Slot
defs[b.id][ex.args[1]] = rename(ex.args[2])
catchbranch!(v, ex.args[1])
catchbranch!(v)
delete!(ir, v)
elseif isexpr(ex, :enter)
push!(handlers, ex.args[1])
Expand Down
18 changes: 18 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,19 @@ function f_try_catch7()
x
end

function f_try_catch8(x)
local z = x
try
z = z + sqrt(x)
z = 1 - log(z)
z = log(z)
catch
z = abs(z)
finally
return z
end
end

@testset "try/catch" begin
ir = @code_ir f_try_catch(1.)
fir = func(ir)
Expand Down Expand Up @@ -358,4 +371,9 @@ end

ir = @code_ir f_try_catch7()
@test func(ir)(nothing) === 1.

ir = @code_ir f_try_catch8(1.)
fir = func(ir)
@test fir(nothing, 1.) == log(1. - log(2.))
@test fir(nothing, -1.) == 1.
end

0 comments on commit 499f197

Please sign in to comment.