diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 08012ad..f07b7bd 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -55,6 +55,24 @@ function _is_kwarg(expr, kwarg::Symbol) return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg end +function _rewrite_elseif!(if_expr, expr::Any) + if expr isa Expr && Meta.isexpr(expr, :elseif) + new_ifelse_expr = Expr(:elseif, esc(expr.args[1])) + push!(if_expr.args, new_ifelse_expr) + @assert 2 <= length(expr.args) <= 3 + return mapreduce(&, 2:length(expr.args)) do i + return _rewrite_elseif!(new_ifelse_expr, expr.args[i]) + end + end + stack = quote end + root, is_mutable = _rewrite_generic(stack, expr) + push!(stack.args, root) + push!(if_expr.args, stack) + return is_mutable +end + +_rewrite_generic(stack::Expr, expr::LineNumberNode) = expr, false + """ _rewrite_generic(stack::Expr, expr::Expr) @@ -62,7 +80,28 @@ This method is the heart of the rewrite logic. It converts `expr` into a mutable equivalent. """ function _rewrite_generic(stack::Expr, expr::Expr) - if !Meta.isexpr(expr, :call) + if Meta.isexpr(expr, :block) + new_stack = quote end + for arg in expr.args + root, _ = _rewrite_generic(new_stack, arg) + push!(new_stack.args, root) + end + root = gensym() + push!(stack.args, :($root = $new_stack)) + return root, false + elseif Meta.isexpr(expr, :if) + # `if` blocks are special, because we can't lift the computation inside + # them into the stack; the values might be defined only if the branch is + # true. + if_expr = Expr(:if, esc(expr.args[1])) + @assert 2 <= length(expr.args) <= 3 + is_mutable = mapreduce(&, 2:length(expr.args)) do i + return _rewrite_elseif!(if_expr, expr.args[i]) + end + root = gensym() + push!(stack.args, :($root = $if_expr)) + return root, is_mutable + elseif !Meta.isexpr(expr, :call) # In situations like `x[i]`, we do not attempt to rewrite. Return `expr` # and don't let future callers mutate. return esc(expr), false diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index 0b752a7..05fe615 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -360,6 +360,79 @@ function test_rewrite_generic_sum_dims() return end +function test_rewrite_block() + @test_rewrite begin + x = 1 + y = x + 2 + z = 3 * y + end + @test_rewrite begin + x = [1] + y = x + [2] + z = 3 * y + end + return +end + +function test_rewrite_ifelse() + @test_rewrite begin + x = -1 + y = [3.0] + if x < 1 + y .+ x + else + 2 * x + end + end + @test_rewrite begin + x = 2 + y = [3.0] + if x < 1 + y .+ x + else + 2 * x + end + end + @test_rewrite begin + x = 2 + y = [3.0, 4.0] + if x < 1 + y .+ x + elseif length(y) == 2 + 0.0 + else + 2 * x + end + end + @test_rewrite begin + x = 2 + y = Float64[] + if x < 1 + y .+ x + elseif length(y) == 2 + 0.0 + elseif isempty(y) + -1.0 + else + 2 * x + end + end + @test_rewrite begin + x = 2 + y = Float64[1.0] + if x < 1 + 1.0 + elseif length(y) == 2 + 2.0 + elseif isempty(y) + 3.0 + else + 4.0 + end + end + return +end + end # module TestRewriteGeneric.runtests()