From fa3d6b8d42005f268a6c2d878306752684bd9acf Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Thu, 14 Sep 2023 19:01:38 +1200 Subject: [PATCH] Use let model=model in variable macro to improve type stability (#3500) --- src/macros.jl | 23 ++++++++++++++++++--- test/test_hygiene.jl | 6 ++++++ test/test_macros.jl | 49 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/macros.jl b/src/macros.jl index 7660ab6576d..afbd777b460 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -1204,6 +1204,17 @@ end # TODO: update 3-argument @constraint macro to pass through names like @variable +function _wrap_let(model, code) + if Meta.isexpr(model, :escape) && model.args[1] isa Symbol + return quote + let $model = $model + $code + end + end + end + return code +end + """ _constraint_macro( args, macro_name::Symbol, parsefun::Function, source::LineNumberNode @@ -1331,10 +1342,11 @@ function _constraint_macro( $parsecode $constraintcall end - creation_code = Containers.container_code(idxvars, indices, code, requested_container) - + # Wrap the entire code block in a let statement to make the model act as + # a type stable local variable. + creation_code = _wrap_let(model, creation_code) if anonvar # Anonymous constraint, no need to register it in the model-level # dictionary nor to assign it to a variable in the user scope. @@ -1939,6 +1951,9 @@ macro expression(args...) end code = Containers.container_code(idxvars, indices, code, requested_container) + # Wrap the entire code block in a let statement to make the model act as + # a type stable local variable. + code = _wrap_let(m, code) # don't do anything with the model, but check that it's valid anyway if anonvar macro_code = code @@ -3016,7 +3031,9 @@ macro variable(args...) requested_container, ) end - + # Wrap the entire code block in a let statement to make the model act as + # a type stable local variable. + creation_code = _wrap_let(model, creation_code) if anonvar # Anonymous variable, no need to register it in the model-level # dictionary nor to assign it to a variable in the user scope. diff --git a/test/test_hygiene.jl b/test/test_hygiene.jl index a7bc87a8ed5..9ff08bb6edc 100644 --- a/test/test_hygiene.jl +++ b/test/test_hygiene.jl @@ -57,4 +57,10 @@ Test.@test ex[3] == 6 Test.@test i == 10 Test.@test j == 10 +# Test that `model` is inferred correctly inside macros, despite being a +# non-const global. +m = JuMP.Model() +JuMP.@variable(m, x[1:0]) +Test.@test x == JuMP.VariableRef[] + end diff --git a/test/test_macros.jl b/test/test_macros.jl index 3c58375340b..340a69f73df 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -2034,4 +2034,53 @@ function test_unsupported_ternary_operator() return end +# This code needs to be evaluated in a top-level scope to prevent inference from +# knowing the type of `model`. +function test_wrap_let_non_symbol_models() + module_name = @eval module $(gensym()) + using JuMP, Test + data = (; model = Model()) + end + @eval module_name begin + @variable(data.model, x) + @test x isa VariableRef + @objective(data.model, Min, x^2) + @test isequal_canonical(objective_function(data.model), x^2) + @expression(data.model, expr[i = 1:2], x + i) + @test expr == [x + 1, x + 2] + @constraint(data.model, c[i = 1:2], i * expr[i] <= i) + @test c isa Vector{<:ConstraintRef} + @variable(data.model, bad_var[1:0]) + @test bad_var isa Vector{<:Any} + @test !(bad_var isa Vector{VariableRef}) # Cannot prove type + @expression(data.model, bad_expr[i = 1:0], x + i) + @test bad_expr isa Vector{Any} + end + return +end + +# This code needs to be evaluated in a top-level scope to prevent inference from +# knowing the type of `model`. +function test_wrap_let_symbol_models() + module_name = @eval module $(gensym()) + using JuMP, Test + model = Model() + end + @eval module_name begin + @variable(model, x) + @test x isa VariableRef + @objective(model, Min, x^2) + @test isequal_canonical(objective_function(model), x^2) + @expression(model, expr[i = 1:2], x + i) + @test expr == [x + 1, x + 2] + @constraint(model, c[i = 1:2], i * expr[i] <= i) + @test c isa Vector{<:ConstraintRef} + @variable(model, bad_var[1:0]) + @test bad_var isa Vector{VariableRef} + @expression(model, bad_expr[i = 1:0], x + i) + @test bad_expr isa Vector{Any} + end + return +end + end # module