Skip to content

Commit

Permalink
Use let model=model in variable macro to improve type stability (#3500)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Sep 14, 2023
1 parent ca73f21 commit fa3d6b8
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,17 @@ end

# TODO: update 3-argument @constraint macro to pass through names like @variable

function _wrap_let(model, code)
if Meta.isexpr(model, :escape) && model.args[1] isa Symbol
return quote
let $model = $model
$code
end
end
end
return code
end

"""
_constraint_macro(
args, macro_name::Symbol, parsefun::Function, source::LineNumberNode
Expand Down Expand Up @@ -1331,10 +1342,11 @@ function _constraint_macro(
$parsecode
$constraintcall
end

creation_code =
Containers.container_code(idxvars, indices, code, requested_container)

# Wrap the entire code block in a let statement to make the model act as
# a type stable local variable.
creation_code = _wrap_let(model, creation_code)
if anonvar
# Anonymous constraint, no need to register it in the model-level
# dictionary nor to assign it to a variable in the user scope.
Expand Down Expand Up @@ -1939,6 +1951,9 @@ macro expression(args...)
end
code =
Containers.container_code(idxvars, indices, code, requested_container)
# Wrap the entire code block in a let statement to make the model act as
# a type stable local variable.
code = _wrap_let(m, code)
# don't do anything with the model, but check that it's valid anyway
if anonvar
macro_code = code
Expand Down Expand Up @@ -3016,7 +3031,9 @@ macro variable(args...)
requested_container,
)
end

# Wrap the entire code block in a let statement to make the model act as
# a type stable local variable.
creation_code = _wrap_let(model, creation_code)
if anonvar
# Anonymous variable, no need to register it in the model-level
# dictionary nor to assign it to a variable in the user scope.
Expand Down
6 changes: 6 additions & 0 deletions test/test_hygiene.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,10 @@ Test.@test ex[3] == 6
Test.@test i == 10
Test.@test j == 10

# Test that `model` is inferred correctly inside macros, despite being a
# non-const global.
m = JuMP.Model()
JuMP.@variable(m, x[1:0])
Test.@test x == JuMP.VariableRef[]

end
49 changes: 49 additions & 0 deletions test/test_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2034,4 +2034,53 @@ function test_unsupported_ternary_operator()
return
end

# This code needs to be evaluated in a top-level scope to prevent inference from
# knowing the type of `model`.
function test_wrap_let_non_symbol_models()
module_name = @eval module $(gensym())
using JuMP, Test
data = (; model = Model())
end
@eval module_name begin
@variable(data.model, x)
@test x isa VariableRef
@objective(data.model, Min, x^2)
@test isequal_canonical(objective_function(data.model), x^2)
@expression(data.model, expr[i = 1:2], x + i)
@test expr == [x + 1, x + 2]
@constraint(data.model, c[i = 1:2], i * expr[i] <= i)
@test c isa Vector{<:ConstraintRef}
@variable(data.model, bad_var[1:0])
@test bad_var isa Vector{<:Any}
@test !(bad_var isa Vector{VariableRef}) # Cannot prove type
@expression(data.model, bad_expr[i = 1:0], x + i)
@test bad_expr isa Vector{Any}
end
return
end

# This code needs to be evaluated in a top-level scope to prevent inference from
# knowing the type of `model`.
function test_wrap_let_symbol_models()
module_name = @eval module $(gensym())
using JuMP, Test
model = Model()
end
@eval module_name begin
@variable(model, x)
@test x isa VariableRef
@objective(model, Min, x^2)
@test isequal_canonical(objective_function(model), x^2)
@expression(model, expr[i = 1:2], x + i)
@test expr == [x + 1, x + 2]
@constraint(model, c[i = 1:2], i * expr[i] <= i)
@test c isa Vector{<:ConstraintRef}
@variable(model, bad_var[1:0])
@test bad_var isa Vector{VariableRef}
@expression(model, bad_expr[i = 1:0], x + i)
@test bad_expr isa Vector{Any}
end
return
end

end # module

0 comments on commit fa3d6b8

Please sign in to comment.