Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recurse rewrite_generic into :if and :block #303

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,57 @@
return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg
end

function _rewrite_elseif!(if_expr, expr::Any)
push!(if_expr.args, esc(expr))
return false

Check warning on line 60 in src/rewrite_generic.jl

View check run for this annotation

Codecov / codecov/patch

src/rewrite_generic.jl#L58-L60

Added lines #L58 - L60 were not covered by tests
end
odow marked this conversation as resolved.
Show resolved Hide resolved

function _rewrite_elseif!(if_expr, expr::Expr)
if Meta.isexpr(expr, :elseif)
new_ifelse_expr = Expr(:elseif, esc(expr.args[1]))
push!(if_expr.args, new_ifelse_expr)
@assert 2 <= length(expr.args) <= 3
return mapreduce(&, 2:length(expr.args)) do i
return _rewrite_elseif!(new_ifelse_expr, expr.args[i])
end
else
stack = quote end
root, is_mutable = _rewrite_generic(stack, expr)
push!(stack.args, root)
push!(if_expr.args, stack)
return is_mutable
end
end

"""
_rewrite_generic(stack::Expr, expr::Expr)

This method is the heart of the rewrite logic. It converts `expr` into a mutable
equivalent.
"""
function _rewrite_generic(stack::Expr, expr::Expr)
if !Meta.isexpr(expr, :call)
if Meta.isexpr(expr, :block)
new_stack = quote end
for arg in expr.args
root, _ = _rewrite_generic(new_stack, arg)
push!(new_stack.args, root)
end
root = gensym()
push!(stack.args, :($root = $new_stack))
return root, false
elseif Meta.isexpr(expr, :if)
# If blocks are special, because we can't lift the computation inside
odow marked this conversation as resolved.
Show resolved Hide resolved
# them into the stack; the values might be defined only if the branch is
# true.
if_expr = Expr(:if, esc(expr.args[1]))
@assert 2 <= length(expr.args) <= 3
is_mutable = mapreduce(&, 2:length(expr.args)) do i
return _rewrite_elseif!(if_expr, expr.args[i])
end
root = gensym()
push!(stack.args, :($root = $if_expr))
return root, is_mutable
elseif !Meta.isexpr(expr, :call)
# In situations like `x[i]`, we do not attempt to rewrite. Return `expr`
# and don't let future callers mutate.
return esc(expr), false
Expand Down
73 changes: 73 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,79 @@ function test_rewrite_generic_sum_dims()
return
end

function test_rewrite_block()
@test_rewrite begin
x = 1
y = x + 2
z = 3 * y
end
@test_rewrite begin
x = [1]
y = x + [2]
z = 3 * y
end
return
end

function test_rewrite_ifelse()
@test_rewrite begin
x = -1
y = [3.0]
if x < 1
y .+ x
else
2 * x
end
end
@test_rewrite begin
x = 2
y = [3.0]
if x < 1
y .+ x
else
2 * x
end
end
@test_rewrite begin
x = 2
y = [3.0, 4.0]
if x < 1
y .+ x
elseif length(y) == 2
0.0
else
2 * x
end
end
@test_rewrite begin
x = 2
y = Float64[]
if x < 1
y .+ x
elseif length(y) == 2
0.0
elseif isempty(y)
-1.0
else
2 * x
end
end
@test_rewrite begin
x = 2
y = Float64[1.0]
if x < 1
1.0
elseif length(y) == 2
2.0
elseif isempty(y)
3.0
else
4.0
end
end
return
end

end # module

TestRewriteGeneric.runtests()
Loading