Skip to content

Commit

Permalink
Refactor GenericNonlinearExpr and NonlinearOperator constructors (#3489)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Sep 8, 2023
1 parent 3097f56 commit 0df25a9
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 90 deletions.
87 changes: 40 additions & 47 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,11 @@ struct GenericNonlinearExpr{V<:AbstractVariableRef} <: AbstractJuMPScalar
head::Symbol
args::Vector{Any}

function GenericNonlinearExpr(head::Symbol, args::Vector{Any})
index = findfirst(Base.Fix2(isa, AbstractJuMPScalar), args)
if index === nothing
error(
"Unable to create a nonlinear expression because it did not " *
"contain any JuMP scalars. head = $head, args = $args.",
)
end
return new{variable_ref_type(args[index])}(head, args)
function GenericNonlinearExpr{V}(
head::Symbol,
args::Vararg{Any},
) where {V<:AbstractVariableRef}
return new{V}(head, Any[a for a in args])
end

function GenericNonlinearExpr{V}(
Expand All @@ -100,6 +96,35 @@ struct GenericNonlinearExpr{V<:AbstractVariableRef} <: AbstractJuMPScalar
end
end

variable_ref_type(::Type{GenericNonlinearExpr}, ::Any) = nothing

function variable_ref_type(::Type{GenericNonlinearExpr}, x::AbstractJuMPScalar)
return variable_ref_type(x)
end

function _has_variable_ref_type(a)
return variable_ref_type(GenericNonlinearExpr, a) !== nothing
end

function _variable_ref_type(head, args)
if (i = findfirst(_has_variable_ref_type, args)) !== nothing
V = variable_ref_type(GenericNonlinearExpr, args[i])
return V::Type{<:AbstractVariableRef}
end
return error(
"Unable to create a nonlinear expression because it did not contain " *
"any JuMP scalars. head = `:$head`, args = `$args`.",
)
end

function GenericNonlinearExpr(head::Symbol, args::Vector{Any})
return GenericNonlinearExpr{_variable_ref_type(head, args)}(head, args)
end

function GenericNonlinearExpr(head::Symbol, args::Vararg{Any,N}) where {N}
return GenericNonlinearExpr{_variable_ref_type(head, args)}(head, args...)
end

"""
NonlinearExpr
Expand All @@ -110,15 +135,6 @@ const NonlinearExpr = GenericNonlinearExpr{VariableRef}

variable_ref_type(::GenericNonlinearExpr{V}) where {V} = V

# We include this method so that we can refactor the internal representation of
# GenericNonlinearExpr without having to rewrite the method overloads.
function GenericNonlinearExpr{V}(
head::Symbol,
args...,
) where {V<:AbstractVariableRef}
return GenericNonlinearExpr{V}(head, Any[args...])
end

const _PREFIX_OPERATORS =
(:+, :-, :*, :/, :^, :||, :&&, :>, :<, :(<=), :(>=), :(==))

Expand Down Expand Up @@ -527,6 +543,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
return ret
end

jump_function(::GenericModel{T}, x::Number) where {T} = convert(T, x)

function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction)
V = variable_ref_type(typeof(model))
ret = GenericNonlinearExpr{V}(f.head, Any[])
Expand All @@ -542,8 +560,6 @@ function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction)
for child in reverse(arg.args)
push!(stack, (new_ret, child))
end
elseif arg isa Number
push!(parent.args, arg)
else
push!(parent.args, jump_function(model, arg))
end
Expand Down Expand Up @@ -833,33 +849,10 @@ function Base.show(io::IO, f::NonlinearOperator)
return print(io, "NonlinearOperator($(f.func), :$(f.head))")
end

# Fast overload for unary calls

(f::NonlinearOperator)(x) = f.func(x)

(f::NonlinearOperator)(x::AbstractJuMPScalar) = NonlinearExpr(f.head, Any[x])

# Fast overload for binary calls

(f::NonlinearOperator)(x, y) = f.func(x, y)

function (f::NonlinearOperator)(x::AbstractJuMPScalar, y)
return GenericNonlinearExpr(f.head, Any[x, y])
end

function (f::NonlinearOperator)(x, y::AbstractJuMPScalar)
return GenericNonlinearExpr(f.head, Any[x, y])
end

function (f::NonlinearOperator)(x::AbstractJuMPScalar, y::AbstractJuMPScalar)
return GenericNonlinearExpr(f.head, Any[x, y])
end

# Fallback for more arguments
function (f::NonlinearOperator)(x, y, z...)
args = (x, y, z...)
if any(Base.Fix2(isa, AbstractJuMPScalar), args)
return GenericNonlinearExpr(f.head, Any[a for a in args])
function (f::NonlinearOperator)(args::Vararg{Any,N}) where {N}
types = variable_ref_type.(GenericNonlinearExpr, args)
if (i = findfirst(!isnothing, types)) !== nothing
return GenericNonlinearExpr{types[i]}(f.head, args...)
end
return f.func(args...)
end
Expand Down
11 changes: 5 additions & 6 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,27 @@ function Base.:/(lhs::GenericAffExpr, rhs::_Constant)
return map_coefficients(c -> c / rhs, lhs)
end

function Base.:^(lhs::AbstractVariableRef, rhs::Integer)
T = value_type(typeof(lhs))
function Base.:^(lhs::V, rhs::Integer) where {V<:AbstractVariableRef}
if rhs == 0
return one(T)
return one(value_type(V))
elseif rhs == 1
return lhs
elseif rhs == 2
return lhs * lhs
else
return GenericNonlinearExpr(:^, Any[lhs, rhs])
return GenericNonlinearExpr{V}(:^, Any[lhs, rhs])
end
end

function Base.:^(lhs::GenericAffExpr{T}, rhs::Integer) where {T}
function Base.:^(lhs::GenericAffExpr{T,V}, rhs::Integer) where {T,V}
if rhs == 0
return one(T)
elseif rhs == 1
return lhs
elseif rhs == 2
return lhs * lhs
else
return GenericNonlinearExpr(:^, Any[lhs, rhs])
return GenericNonlinearExpr{V}(:^, Any[lhs, rhs])
end
end

Expand Down
12 changes: 6 additions & 6 deletions test/test_nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ function test_parse_expression_nonlinearexpr_call()
model = Model()
@variable(model, x)
@variable(model, y)
f = GenericNonlinearExpr(:ifelse, Any[x, 0, y])
f = NonlinearExpr(:ifelse, Any[x, 0, y])
@NLexpression(model, ref, f)
nlp = nonlinear_model(model)
expr = :(ifelse($x, 0, $y))
Expand All @@ -1617,7 +1617,7 @@ function test_parse_expression_nonlinearexpr_or()
model = Model()
@variable(model, x)
@variable(model, y)
f = GenericNonlinearExpr(:||, Any[x, y])
f = NonlinearExpr(:||, Any[x, y])
@NLexpression(model, ref, f)
nlp = nonlinear_model(model)
expr = :($x || $y)
Expand All @@ -1629,7 +1629,7 @@ function test_parse_expression_nonlinearexpr_and()
model = Model()
@variable(model, x)
@variable(model, y)
f = GenericNonlinearExpr(:&&, Any[x, y])
f = NonlinearExpr(:&&, Any[x, y])
@NLexpression(model, ref, f)
nlp = nonlinear_model(model)
expr = :($x && $y)
Expand All @@ -1641,7 +1641,7 @@ function test_parse_expression_nonlinearexpr_unsupported()
model = Model()
@variable(model, x)
@variable(model, y)
f = GenericNonlinearExpr(:foo, Any[x, y])
f = NonlinearExpr(:foo, Any[x, y])
@test_throws(
MOI.UnsupportedNonlinearOperator,
@NLexpression(model, ref, f),
Expand All @@ -1653,8 +1653,8 @@ function test_parse_expression_nonlinearexpr_nested_comparison()
model = Model()
@variable(model, x)
@variable(model, y)
f = GenericNonlinearExpr(:||, Any[x, y])
g = GenericNonlinearExpr(:&&, Any[f, x])
f = NonlinearExpr(:||, Any[x, y])
g = NonlinearExpr(:&&, Any[f, x])
@NLexpression(model, ref, g)
nlp = nonlinear_model(model)
expr = :(($x || $y) && $x)
Expand Down
Loading

0 comments on commit 0df25a9

Please sign in to comment.