Skip to content

Commit

Permalink
Fix rewrite of +(x, *(y...)) into add_mul (#225)
Browse files Browse the repository at this point in the history
* Fix rewrite of +(x, *(y...)) into add_mul

* Add test

* Add more tests
  • Loading branch information
odow authored Sep 6, 2023
1 parent 01cae64 commit 774bc08
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ function _rewrite_generic(stack::Expr, expr::Expr)
@assert length(expr.args) > 1
if length(expr.args) == 2 # +(arg)
return _rewrite_generic(stack, expr.args[2])
elseif length(expr.args) == 3 && _is_call(expr.args[3], :*)
# +(x, *(y...)) => add_mul(x, y...)
x, is_mutable = _rewrite_generic(stack, expr.args[2])
rhs = if is_mutable
Expr(:call, operate!!, add_mul, x)
else
Expr(:call, operate, add_mul, x)
end
for i in 2:length(expr.args[3].args)
yi, _ = _rewrite_generic(stack, expr.args[3].args[i])
push!(rhs.args, yi)
end
root = gensym()
push!(stack.args, :($root = $rhs))
return root, true
end
return _rewrite_generic_to_nested_op(stack, expr, add_mul)
elseif expr.args[1] == :-
Expand Down
20 changes: 20 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,26 @@ function test_rewrite_kw_in_ref()
return
end

function test_rewrite_expression()
x = [1.2]
@test MA.@rewrite(x + 2 * x, move_factors_into_sums = false) == 3x
@test MA.@rewrite(x + *(2, x, 3), move_factors_into_sums = false) ==
x + *(2, x, 3)
y = 1.2
@test MA.@rewrite(
sum(y for i in 1:2) + 2y,
move_factors_into_sums = false
) == 4y
@test MA.@rewrite(
sum(y for i in 1:2) + y * y,
move_factors_into_sums = false
) == 2y + y^2
@test MA.@rewrite(y + 2 * y, move_factors_into_sums = false) == 3y
@test MA.@rewrite(y + *(2, y, 3), move_factors_into_sums = false) ==
y + *(2, y, 3)
return
end

end # module

TestRewriteGeneric.runtests()

0 comments on commit 774bc08

Please sign in to comment.