From 461955ed2b1d7351a88ee93fa743e0182227c741 Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 22 Nov 2022 14:10:29 +1300 Subject: [PATCH] Support splatting in rewrite_generic.jl --- src/rewrite_generic.jl | 3 +++ test/rewrite_generic.jl | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 69f0151f..f1acd152 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -65,6 +65,9 @@ function _rewrite_generic(stack::Expr, expr::Expr) elseif Meta.isexpr(expr, :call, 1) # A zero-argument function return esc(expr), false + elseif Meta.isexpr(expr.args[2], :(...)) + # If the first argument is a splat. + return esc(expr), false elseif _is_generator(expr) || _is_flatten(expr) || _is_parameters(expr) if !(expr.args[1] in (:sum, :Σ, :∑)) # We don't know what this is. Return the expression and don't let diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index a1c1aaf7..06532713 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -266,6 +266,24 @@ function test_no_product() return end +function test_splatting() + x = [1, 2, 3] + @test MA.@rewrite(+(x...), move_factors_into_sums = false) == 6 + @test MA.@rewrite(+(4, x...), move_factors_into_sums = false) == 10 + @test MA.@rewrite(+(x..., 4), move_factors_into_sums = false) == 10 + @test MA.@rewrite(+(4, x..., 5), move_factors_into_sums = false) == 15 + @test MA.@rewrite(*(x...), move_factors_into_sums = false) == 6 + @test MA.@rewrite(*(4, x...), move_factors_into_sums = false) == 24 + @test MA.@rewrite(*(x..., 4), move_factors_into_sums = false) == 24 + @test MA.@rewrite(*(4, x..., 5), move_factors_into_sums = false) == 120 + @test MA.@rewrite( + +(4, x..., *(4, x..., 5)), + move_factors_into_sums = false, + ) == 130 + @test MA.@rewrite(vcat(x...), move_factors_into_sums = false) == x + return +end + end # module TestRewriteGeneric.runtests()