From 774bc08e55235a32ac74315b13da255f6bef1d73 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 6 Sep 2023 18:51:47 +1200 Subject: [PATCH] Fix rewrite of +(x, *(y...)) into add_mul (#225) * Fix rewrite of +(x, *(y...)) into add_mul * Add test * Add more tests --- src/rewrite_generic.jl | 15 +++++++++++++++ test/rewrite_generic.jl | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 2c1a06d0..55dcf6f5 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -121,6 +121,21 @@ function _rewrite_generic(stack::Expr, expr::Expr) @assert length(expr.args) > 1 if length(expr.args) == 2 # +(arg) return _rewrite_generic(stack, expr.args[2]) + elseif length(expr.args) == 3 && _is_call(expr.args[3], :*) + # +(x, *(y...)) => add_mul(x, y...) + x, is_mutable = _rewrite_generic(stack, expr.args[2]) + rhs = if is_mutable + Expr(:call, operate!!, add_mul, x) + else + Expr(:call, operate, add_mul, x) + end + for i in 2:length(expr.args[3].args) + yi, _ = _rewrite_generic(stack, expr.args[3].args[i]) + push!(rhs.args, yi) + end + root = gensym() + push!(stack.args, :($root = $rhs)) + return root, true end return _rewrite_generic_to_nested_op(stack, expr, add_mul) elseif expr.args[1] == :- diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index 078d1b4d..d6d32ebb 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -324,6 +324,26 @@ function test_rewrite_kw_in_ref() return end +function test_rewrite_expression() + x = [1.2] + @test MA.@rewrite(x + 2 * x, move_factors_into_sums = false) == 3x + @test MA.@rewrite(x + *(2, x, 3), move_factors_into_sums = false) == + x + *(2, x, 3) + y = 1.2 + @test MA.@rewrite( + sum(y for i in 1:2) + 2y, + move_factors_into_sums = false + ) == 4y + @test MA.@rewrite( + sum(y for i in 1:2) + y * y, + move_factors_into_sums = false + ) == 2y + y^2 + @test MA.@rewrite(y + 2 * y, move_factors_into_sums = false) == 3y + @test MA.@rewrite(y + *(2, y, 3), move_factors_into_sums = false) == + y + *(2, y, 3) + return +end + end # module TestRewriteGeneric.runtests()