Skip to content

Commit

Permalink
fix ssa conversion for catch blocks (#117)
Browse files Browse the repository at this point in the history
* fix ssa conversion for catch blocks

the slotused function was not enough to account for all slot used by the
catch block and its successors. With this change, ssa conversion keeps
a live list of all catch 'branch' instructions and fetches the reaching
definitions for slots at the location of these :catch instructions.

* add more tests

* Use right type in default

* fix indent && more tests

* Update passes.jl

* add test_broken for UndefVarError

* use test_broken for 1.6
  • Loading branch information
Pangoraw authored Dec 13, 2023
1 parent 55c315a commit 8f5f50e
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 23 deletions.
79 changes: 60 additions & 19 deletions src/passes/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,41 +201,83 @@ function prune!(ir::IR)
return ir
end

function slotsused(bl)
slots = []
walk(ex) = prewalk(x -> (x isa Slot && !(x in slots) && push!(slots, x); x), ex)
for (v, st) in bl
ex = st.expr
isexpr(ex, :(=)) ? walk(ex.args[2]) : walk(ex)
end
return slots
struct CatchBranch
defs::Dict{Slot,Any}
v::Variable
end

function ssa!(ir::IR)
current = 1
defs = Dict(b => Dict{Slot,Any}() for b in 1:length(ir.blocks))
todo = Dict(b => Dict{Int,Vector{Slot}}() for b in 1:length(ir.blocks))
catches = Dict()
handlers = []
function reaching(b, v)
haskey(defs[b.id], v) && return defs[b.id][v]
catch_branches = Dict{Int,Vector{CatchBranch}}()
handlers = Int[]
function reaching(b, slot)
haskey(defs[b.id], slot) && return defs[b.id][slot]
b.id == 1 && return undef
x = defs[b.id][v] = argument!(b, type = v.type, insert = false)
x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false)
for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, v))
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), v)
push!(get!(todo[pred.id], b.id, Slot[]), slot)
end
end

if haskey(catch_branches, b.id)
# for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`),
# we try to find the dominating definition for slot v.
# defs[block(ir, cbr.v).id] contains the defs at the end of
# the block, so we use the cached defs in catch_branches instead.
for cbr in catch_branches[b.id]
cbr_v = cbr.v
stmt = ir[cbr_v]
if haskey(cbr.defs, slot)
# Slot v was defined at catch branch
push!(stmt.expr.args, cbr.defs[slot])
else
# Find slot v definition from instruction cbr_v
b = block(ir, cbr_v)
if b.id == 1
push!(stmt.expr.args, undef)
continue
end

# there is already a def for this slot as an argument to the block
# but which was added after the catch branch.
if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable
bdef = defs[b.id][slot]
(def_b, loc) = ir.defs[bdef.id]
if def_b == b.id && loc < 0
push!(stmt.expr.args, bdef)
continue
end
end

# get the slot definition from each predecessors of the block owning the catch 'branch'
new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false)
push!(stmt.expr.args, new_arg)
for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), slot)
end
end
end
end
end

return x
end
function catchbranch!(v, slot = nothing)
for handler in handlers
args = reaching.((block(ir, v),), catches[handler])
insertafter!(ir, v, Expr(:catch, handler, args...))
cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler)))
push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr)
end
end
for b in blocks(ir)
Expand All @@ -248,10 +290,9 @@ function ssa!(ir::IR)
catchbranch!(v, ex.args[1])
delete!(ir, v)
elseif isexpr(ex, :enter)
catches[ex.args[1]] = slotsused(block(ir, ex.args[1]))
push!(handlers, ex.args[1])
catchbranch!(v)
elseif isexpr(ex, :leave) && !haskey(catches, current)
elseif isexpr(ex, :leave) && !haskey(catch_branches, current)
pop!(handlers)
else
ir[v] = rename(ex)
Expand Down
138 changes: 134 additions & 4 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function err3(f)
end

@test passthrough(err3, () -> 2+2) == 4
@test_broken passthrough(err3, () -> 0//0) == 1
@test passthrough(err3, () -> 0//0) == 1

@dynamo function mullify(a...)
ir = IR(a...)
Expand Down Expand Up @@ -222,10 +222,140 @@ function f_try_catch(x)
y
end

function f_try_catch2(x, cond)
local y
if cond
y = 2x
end

try
x = 3 * error()
catch
end

y
end

function f_try_catch3()
local x
try
error()
catch
x = 42
end
x
end

function f_try_catch4(x, cond)
local y
try
throw(x)
catch err
if cond
y = err + x
end
end
y
end

function f_try_catch5(x, cond)
local y
cond && (x = 2x)
try
y = x
cond && error()
catch
y = x + 1
end
y
end

function f_try_catch6(cond, y)
x = 1

if cond
y = 10y
else
y = 10y
end

try
cond && error()
catch
y = 2x
end

y+x
end

function f_try_catch7()
local x = 1.

for _ in 1:10

try
x = sqrt(x)
x -= 1.
catch
x = -x
end

x = x ^ 2
end

x
end

@testset "try/catch" begin
ir = @code_ir f_try_catch(1.)
@test true
fir = func(ir)
@test fir(nothing,1.) == 1.
@test_broken fir(nothing,-1.) == 1.
@test fir(nothing,1.) === 1.
@test fir(nothing,-1.) === 0.

ir = @code_ir f_try_catch2(1., false)
fir = func(ir)

# This should be @test_throws UndefVarError fir(nothing,42,false)
# See TODO in `IRTools.slots!`
@test_broken try
fir(nothing,42,false)
false
catch e
e isa UndefVarError
end
@test fir(nothing, 42, false) === IRTools.undef
@test fir(nothing, 42, true) === 84

ir = @code_ir f_try_catch3()
@test all(ir) do (_, stmt)
!IRTools.isexpr(stmt.expr, :catch) ||
length(stmt.expr.args) == 1
end
fir = func(ir)
@test fir(nothing) == 42

ir = @code_ir f_try_catch4(42, false)
fir = func(ir)
# This should be @test_throws UndefVarError fir(nothing,42,false)
@test_broken try
fir(nothing, 42, false)
false
catch e
e isa UndefVarError
end
@test fir(nothing, 42, false) === IRTools.undef
@test fir(nothing, 42, true) === 84

ir = @code_ir f_try_catch5(1, false)
fir = func(ir)
@test fir(nothing, 3, false) === 3
@test fir(nothing, 3, true) === 7

ir = @code_ir f_try_catch6(true, 1)
fir = func(ir)
@test fir(nothing, true, 1) === 3
@test fir(nothing, false, 1) === 11

ir = @code_ir f_try_catch7()
@test func(ir)(nothing) === 1.
end

0 comments on commit 8f5f50e

Please sign in to comment.