Skip to content

Commit

Permalink
Implement performance optimization of promote_operation for *(::Any, …
Browse files Browse the repository at this point in the history
…::Zero) (#284)
  • Loading branch information
odow authored Apr 25, 2024
1 parent a6ed0f5 commit 7344c94
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ function Base.:/(z::Zero, x::Any)
end
end

# These methods are used to provide an efficient implementation for the common
# case like `x^2 * sum(f for i in 1:0)`, which lowers to
# `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed
# arguments because MA.Zero is not mutable, and MA never queries the mutablility
# of arguments if the first is not mutable.
promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero

function promote_operation(
::typeof(*),
::Type{<:AbstractArray{T}},
::Type{Zero},
) where {T}
return Zero
end

# Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)`
# since we don't require mutable type to support Zero in
# `mutable_operate!`.
Expand Down
67 changes: 67 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,70 @@ end
b = @allocated MA.operate(LinearAlgebra.dot, x, y)
@test a == b
end

@testset "test_multiply_expr_MA_Zero" begin
x = DummyBigInt(1)
f = DummyBigInt(2)
@test MA.@rewrite(
f * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * f,
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
-f * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * -f,
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
(f + f) * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * (f + f),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
-[f] * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * -[f],
move_factors_into_sums = false
) == MA.Zero()
@test MA.isequal_canonical(
MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false),
f,
)
@test MA.isequal_canonical(
MA.@rewrite(sum(x for i in 1:0) + f, move_factors_into_sums = false),
f,
)
@test MA.isequal_canonical(
MA.@rewrite(-f + sum(x for i in 1:0), move_factors_into_sums = false),
-f,
)
@test MA.isequal_canonical(
MA.@rewrite(sum(x for i in 1:0) + -f, move_factors_into_sums = false),
-f,
)
@test MA.isequal_canonical(
MA.@rewrite(
(f + f) + sum(x for i in 1:0),
move_factors_into_sums = false
),
f + f,
)
@test MA.isequal_canonical(
MA.@rewrite(
sum(x for i in 1:0) + (f + f),
move_factors_into_sums = false
),
f + f,
)
end

0 comments on commit 7344c94

Please sign in to comment.