diff --git a/src/rewrite.jl b/src/rewrite.jl index 811570f..8cee94b 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -286,7 +286,11 @@ function _is_decomposable_with_factors(ex) end """ - rewrite(expr; move_factors_into_sums::Bool = true) -> Tuple{Symbol,Expr} + rewrite( + expr; + move_factors_into_sums::Bool = true, + return_is_mutable::Bool = false, + ) -> Tuple{Symbol,Expr[,Bool]} Rewrites the expression `expr` to use mutable arithmetics. @@ -318,34 +322,64 @@ variable = MA.operate!!(*, y, term) The latter can produce an additional allocation if there is an efficient fallback for `add_mul` and not for `*(y, term)`. + +## `return_is_mutable` + +If `return_is_mutable = true`, this function returns three arguments. The third +is a `Bool` indicating if the returned expression can be safely mutated without +changing the user's original expression. + +`return_is_mutable` cannot be `true` if `move_factors_into_sums = true`. """ -function rewrite(x; kwargs...) +function rewrite(x; return_is_mutable::Bool = false, kwargs...) variable = gensym() + if return_is_mutable + code, is_mutable = rewrite_and_return(x; return_is_mutable, kwargs...) + return variable, :($variable = $code), is_mutable + end code = rewrite_and_return(x; kwargs...) return variable, :($variable = $code) end """ - rewrite_and_return(expr; move_factors_into_sums::Bool = true) -> Expr + rewrite_and_return( + expr; + move_factors_into_sums::Bool = true, + return_is_mutable::Bool = false, + ) -> Expr Rewrite the expression `expr` using mutable arithmetics and return an expression in which the last statement is equivalent to `expr`. -See [`rewrite`](@ref) for an explanation of the keyword argument. +See [`rewrite`](@ref) for an explanation of the keyword arguments. """ -function rewrite_and_return(expr; move_factors_into_sums::Bool = true) +function rewrite_and_return( + expr; + move_factors_into_sums::Bool = true, + return_is_mutable::Bool = false, +) if move_factors_into_sums + @assert !return_is_mutable root, stack = _rewrite(false, false, expr, nothing, [], []) - else - stack = quote end - root, _ = _rewrite_generic(stack, expr) + return quote + let + $stack + $root + end + end end - return quote + stack = quote end + root, is_mutable = _rewrite_generic(stack, expr) + code = quote let $stack $root end end + if return_is_mutable + return code, is_mutable + end + return code end function _is_comparison(ex::Expr) diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index 05fe615..21763be 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -433,6 +433,38 @@ function test_rewrite_ifelse() return end +function test_return_is_mutable() + function _rewrite(expr) + return MA.rewrite( + expr; + move_factors_into_sums = false, + return_is_mutable = true, + ) + end + x, expr, is_mutable = _rewrite(1) + @test x isa Symbol + @test Meta.isexpr(expr, :(=), 2) + @test is_mutable + y = 1 + x, expr, is_mutable = _rewrite(:(y)) + @test x isa Symbol + @test Meta.isexpr(expr, :(=), 2) + @test !is_mutable + x, expr, is_mutable = _rewrite(:(1 + 1)) + @test x isa Symbol + @test Meta.isexpr(expr, :(=), 2) + @test is_mutable + @test_throws( + AssertionError, + MA.rewrite( + :(1 + 1); + move_factors_into_sums = true, + return_is_mutable = true, + ), + ) + return +end + end # module TestRewriteGeneric.runtests()