Skip to content

Commit

Permalink
Support splatting in rewrite_generic.jl (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 22, 2022
1 parent c715a0c commit dfebb83
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ function _rewrite_generic(stack::Expr, expr::Expr)
elseif Meta.isexpr(expr, :call, 1)
# A zero-argument function
return esc(expr), false
elseif Meta.isexpr(expr.args[2], :(...))
# If the first argument is a splat.
return esc(expr), false
elseif _is_generator(expr) || _is_flatten(expr) || _is_parameters(expr)
if !(expr.args[1] in (:sum, , :∑))
# We don't know what this is. Return the expression and don't let
Expand Down
18 changes: 18 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,24 @@ function test_no_product()
return
end

function test_splatting()
x = [1, 2, 3]
@test MA.@rewrite(+(x...), move_factors_into_sums = false) == 6
@test MA.@rewrite(+(4, x...), move_factors_into_sums = false) == 10
@test MA.@rewrite(+(x..., 4), move_factors_into_sums = false) == 10
@test MA.@rewrite(+(4, x..., 5), move_factors_into_sums = false) == 15
@test MA.@rewrite(*(x...), move_factors_into_sums = false) == 6
@test MA.@rewrite(*(4, x...), move_factors_into_sums = false) == 24
@test MA.@rewrite(*(x..., 4), move_factors_into_sums = false) == 24
@test MA.@rewrite(*(4, x..., 5), move_factors_into_sums = false) == 120
@test MA.@rewrite(
+(4, x..., *(4, x..., 5)),
move_factors_into_sums = false,
) == 130
@test MA.@rewrite(vcat(x...), move_factors_into_sums = false) == x
return
end

end # module

TestRewriteGeneric.runtests()

0 comments on commit dfebb83

Please sign in to comment.