Skip to content

Commit

Permalink
Fix MA.add_mul for GenericNonlinearExpr arguments (#3488)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Sep 6, 2023
1 parent f496535 commit a157126
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/mutable_arithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,21 @@ function _MA.sub_mul(
end
return _MA.operate!(_MA.sub_mul, expr, x, y, args...)
end

# The default implementation of `add_mul` falls back to Base.muladd, which ends
# up rewriting `x + y * z` as `y * z + x`. Intercept some methods, ensuring that
# we don't cause ambiguities.

for F in (:_Scalar, :AbstractJuMPScalar)
@eval begin
_MA.add_mul(x::$F, y::GenericNonlinearExpr, z::_Scalar) = x + y * z
_MA.add_mul(x::$F, y::_Scalar, z::GenericNonlinearExpr) = x + y * z
function _MA.add_mul(
x::$F,
y::GenericNonlinearExpr,
z::GenericNonlinearExpr,
)
return x + y * z
end
end
end
16 changes: 16 additions & 0 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ function test_extension_expression_addmul(
return
end

function test_extension_expression_explicit_add_mul(
ModelType = Model,
VariableRefType = VariableRef,
)
model = ModelType()
@variable(model, x)
f = sin(x)
@test string(MA.operate!!(MA.add_mul, 1, 2, f)) == "1.0 + (2.0 * $f)"
@test string(MA.operate!!(MA.add_mul, 1, f, 2)) == "1.0 + ($f * 2.0)"
@test string(MA.operate!!(MA.add_mul, 1, f, f)) == "1.0 + ($f * $f)"
@test string(MA.operate!!(MA.add_mul, f, 2, f)) == "$f + (2.0 * $f)"
@test string(MA.operate!!(MA.add_mul, f, f, 2)) == "$f + ($f * 2.0)"
@test string(MA.operate!!(MA.add_mul, f, f, f)) == "$f + ($f * $f)"
return
end

function test_extension_expression_submul(
ModelType = Model,
VariableRefType = VariableRef,
Expand Down

0 comments on commit a157126

Please sign in to comment.