Skip to content

Commit

Permalink
Add return_is_mutable kwarg to MA.rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 14, 2024
1 parent d47000d commit d075a09
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
52 changes: 43 additions & 9 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ function _is_decomposable_with_factors(ex)
end

"""
rewrite(expr; move_factors_into_sums::Bool = true) -> Tuple{Symbol,Expr}
rewrite(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
) -> Tuple{Symbol,Expr[,Bool]}
Rewrites the expression `expr` to use mutable arithmetics.
Expand Down Expand Up @@ -318,34 +322,64 @@ variable = MA.operate!!(*, y, term)
The latter can produce an additional allocation if there is an efficient
fallback for `add_mul` and not for `*(y, term)`.
## `return_is_mutable`
If `return_is_mutable = true`, this function returns three arguments. The third
is a `Bool` indicating if the returned expression can be safely mutated without
changing the user's original expression.
`return_is_mutable` cannot be `true` if `move_factors_into_sums = true`.
"""
function rewrite(x; kwargs...)
function rewrite(x; return_is_mutable::Bool = false, kwargs...)
variable = gensym()
if return_is_mutable
code, is_mutable = rewrite_and_return(x; return_is_mutable, kwargs...)
return variable, :($variable = $code), is_mutable
end
code = rewrite_and_return(x; kwargs...)
return variable, :($variable = $code)
end

"""
rewrite_and_return(expr; move_factors_into_sums::Bool = true) -> Expr
rewrite_and_return(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
) -> Expr
Rewrite the expression `expr` using mutable arithmetics and return an expression
in which the last statement is equivalent to `expr`.
See [`rewrite`](@ref) for an explanation of the keyword argument.
See [`rewrite`](@ref) for an explanation of the keyword arguments.
"""
function rewrite_and_return(expr; move_factors_into_sums::Bool = true)
function rewrite_and_return(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
)
if move_factors_into_sums
@assert !return_is_mutable
root, stack = _rewrite(false, false, expr, nothing, [], [])
else
stack = quote end
root, _ = _rewrite_generic(stack, expr)
return quote
let
$stack
$root
end
end
end
return quote
stack = quote end
root, is_mutable = _rewrite_generic(stack, expr)
code = quote
let
$stack
$root
end
end
if return_is_mutable
return code, is_mutable
end
return code
end

function _is_comparison(ex::Expr)
Expand Down
32 changes: 32 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,38 @@ function test_rewrite_ifelse()
return
end

function test_return_is_mutable()
function _rewrite(expr)
return MA.rewrite(
expr;
move_factors_into_sums = false,
return_is_mutable = true,
)
end
x, expr, is_mutable = _rewrite(1)
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test is_mutable
y = 1
x, expr, is_mutable = _rewrite(:(y))
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test !is_mutable
x, expr, is_mutable = _rewrite(:(1 + 1))
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test is_mutable
@test_throws(
AssertionError,
MA.rewrite(
:(1 + 1);
move_factors_into_sums = true,
return_is_mutable = true,
),
)
return
end

end # module

TestRewriteGeneric.runtests()

0 comments on commit d075a09

Please sign in to comment.