From 406f5b44725287f6a2211eb7369fe67ae089891a Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 12 Dec 2023 20:17:25 -0800 Subject: [PATCH] Make EnterNode save/restore dynamic scope (#52309) As discussed in #51352, this gives `EnterNode` the ability to set (and restore on leave or catch edge) jl_current_task->scope. Manual modifications of the task field after the task has started are considered undefined behavior. In addition, we gain a new intrinsic to access current_task->scope and both inference and the optimizer will forward scopes from EnterNodes to this intrinsic (non-interprocedurally). Together with #51993 this is sufficient to fully optimize ScopedValues (non-interprocedurally at least). --- base/boot.jl | 4 +- base/compiler/abstractinterpretation.jl | 13 +++++++ base/compiler/inferencestate.jl | 6 ++- base/compiler/ssair/ir.jl | 1 + base/compiler/ssair/passes.jl | 25 ++++++++++++ base/compiler/ssair/show.jl | 4 ++ base/compiler/ssair/verify.jl | 2 +- base/compiler/tfuncs.jl | 52 +++++++++++++++++++++++-- base/compiler/validation.jl | 9 ++++- base/scopedvalues.jl | 51 ++++++------------------ base/task.jl | 9 +++++ src/ast.c | 4 ++ src/builtin_proto.h | 1 + src/builtins.c | 7 ++++ src/codegen.cpp | 52 ++++++++++++++++++++++--- src/interpreter.c | 19 ++++++++- src/jltypes.c | 4 +- src/julia-syntax.scm | 13 +++++-- src/julia.h | 1 + src/method.c | 15 +++++++ src/staticdata.c | 1 + test/core.jl | 13 +++++++ test/scopedvalues.jl | 4 +- 23 files changed, 248 insertions(+), 62 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index 218e4e2e533b1..13bb6bcd7cd4b 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -460,6 +460,7 @@ eval(Core, quote ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest)) EnterNode(dest::Int) = $(Expr(:new, :EnterNode, :dest)) + EnterNode(dest::Int, @nospecialize(scope)) = $(Expr(:new, :EnterNode, :dest, :scope)) LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing)) function LineNumberNode(l::Int, @nospecialize(f)) isa(f, String) && (f = Symbol(f)) @@ -966,7 +967,8 @@ arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i) export arrayref, arrayset, arraysize, const_arrayref # For convenience -EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest) +EnterNode(old::EnterNode, new_dest::Int) = isdefined(old, :scope) ? + EnterNode(new_dest, old.scope) : EnterNode(new_dest) include(Core, "optimized_generics.jl") diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index dd156ae338058..0541aa9e64278 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -3270,6 +3270,19 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif isa(stmt, EnterNode) ssavaluetypes[currpc] = Any add_curr_ssaflag!(frame, IR_FLAG_NOTHROW) + if isdefined(stmt, :scope) + scopet = abstract_eval_value(interp, stmt.scope, currstate, frame) + handler = frame.handlers[frame.handler_at[frame.currpc+1][1]] + @assert handler.scopet !== nothing + if !โŠ‘(๐•ƒแตข, scopet, handler.scopet) + handler.scopet = tmerge(๐•ƒแตข, scopet, handler.scopet) + if isdefined(handler, :scope_uses) + for bb in handler.scope_uses + push!(W, bb) + end + end + end + end @goto fallthrough elseif isexpr(stmt, :leave) ssavaluetypes[currpc] = Any diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 09757adc54850..f8650c2e6cfde 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -205,8 +205,10 @@ const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed mutable struct TryCatchFrame exct + scopet const enter_idx::Int - TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx) + scope_uses::Vector{Int} + TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx) end mutable struct InferenceState @@ -364,7 +366,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet) stmt = code[pc] if isa(stmt, EnterNode) l = stmt.catch_dest - push!(handlers, TryCatchFrame(Bottom, pc)) + push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc)) handler_id = length(handlers) handler_at[pc + 1] = (handler_id, 0) push!(ip, pc + 1) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 6a1114a009519..c37562a06c2ca 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -1414,6 +1414,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr result_idx += 1 end elseif cfg_transforms_enabled && isa(stmt, EnterNode) + stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)::EnterNode label = bb_rename_succ[stmt.catch_dest] @assert label > 0 ssa_rename[idx] = SSAValue(result_idx) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 5abed1a182048..fdd7c0d3b48aa 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1069,6 +1069,29 @@ function fold_ifelse!(compact::IncrementalCompact, idx::Int, stmt::Expr) return false end +function fold_current_scope!(compact::IncrementalCompact, idx::Int, stmt::Expr, lazydomtree::LazyDomtree) + domtree = get!(lazydomtree) + + # The frontend enforces the invariant that any :enter dominates its active + # region, so all we have to do here is walk the domtree to find it. + dombb = block_for_inst(compact, SSAValue(idx)) + + local bbterminator + while true + dombb = domtree.idoms_bb[dombb] + + # Did not find any dominating :enter - scope is inherited from the outside + dombb == 0 && return nothing + + bbterminator = compact[SSAValue(last(compact.cfg_transform.result_bbs[dombb].stmts))][:stmt] + isa(bbterminator, EnterNode) || continue + isdefined(bbterminator, :scope) || continue + compact[idx] = bbterminator.scope + return nothing + end +end + + # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, # which can be very large sometimes, and program counters in question are often very sparse const SPCSet = IdSet{Int} @@ -1201,6 +1224,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact) 2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 2 : 1)) || continue lift_keyvalue_get!(compact, idx, stmt, ๐•ƒโ‚’) + elseif is_known_call(stmt, Core.current_scope, compact) + fold_current_scope!(compact, idx, stmt, lazydomtree) elseif isexpr(stmt, :new) refine_new_effects!(๐•ƒโ‚’, compact, idx, stmt) end diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index 7323c21ae56b4..82af7f9964045 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -69,6 +69,10 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng # given control flow information, we prefer to print these with the basic block #, instead of the ssa % elseif isa(stmt, EnterNode) print(io, "enter #", stmt.catch_dest, "") + if isdefined(stmt, :scope) + print(io, " with scope ") + show_unquoted(io, stmt.scope, indent) + end elseif stmt isa GotoNode print(io, "goto #", stmt.label) elseif stmt isa PhiNode diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 03e3ab0b0f03a..9eded81d9d84b 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -2,7 +2,7 @@ function maybe_show_ir(ir::IRCode) if isdefined(Core, :Main) - Core.Main.Base.display(ir) + invokelatest(Core.Main.Base.display, ir) end end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index c30f4a1a237e1..bb8712458e0be 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2488,6 +2488,19 @@ function builtin_effects(๐•ƒ::AbstractLattice, @nospecialize(f::Builtin), argty return Effects(EFFECTS_TOTAL; consistent = (isa(setting, Const) && setting.val === :conditional) ? ALWAYS_TRUE : ALWAYS_FALSE, nothrow = compilerbarrier_nothrow(setting, nothing)) + elseif f === Core.current_scope + nothrow = true + if length(argtypes) != 0 + if length(argtypes) != 1 || !isvarargtype(argtypes[1]) + return EFFECTS_THROWS + end + nothrow = false + end + return Effects(EFFECTS_TOTAL; + consistent = ALWAYS_FALSE, + notaskstate = false, + nothrow + ) else if contains_is(_CONSISTENT_BUILTINS, f) consistent = ALWAYS_TRUE @@ -2554,6 +2567,32 @@ function memoryop_noub(@nospecialize(f), argtypes::Vector{Any}) return false end +function current_scope_tfunc(interp::AbstractInterpreter, sv::InferenceState) + pc = sv.currpc + while true + handleridx = sv.handler_at[pc][1] + if handleridx == 0 + # No local scope available - inherited from the outside + return Any + end + pchandler = sv.handlers[handleridx] + # Remember that we looked at this handler, so we get re-scheduled + # if the scope information changes + isdefined(pchandler, :scope_uses) || (pchandler.scope_uses = Int[]) + pcbb = block_for_inst(sv.cfg, pc) + if findfirst(==(pcbb), pchandler.scope_uses) === nothing + push!(pchandler.scope_uses, pcbb) + end + scope = pchandler.scopet + if scope !== nothing + # Found the scope - forward it + return scope + end + pc = pchandler.enter_idx + end +end +current_scope_tfunc(interp::AbstractInterpreter, sv) = Any + """ builtin_nothrow(๐•ƒ::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Bool @@ -2568,9 +2607,6 @@ end function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, sv::Union{AbsIntState, Nothing}) ๐•ƒแตข = typeinf_lattice(interp) - if f === tuple - return tuple_tfunc(๐•ƒแตข, argtypes) - end if isa(f, IntrinsicFunction) if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes) argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes) @@ -2596,6 +2632,16 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp end tf = T_IFUNC[iidx] else + if f === tuple + return tuple_tfunc(๐•ƒแตข, argtypes) + elseif f === Core.current_scope + if length(argtypes) != 0 + if length(argtypes) != 1 || !isvarargtype(argtypes[1]) + return Bottom + end + end + return current_scope_tfunc(interp, sv) + end fidx = find_tfunc(f) if fidx === nothing # unknown/unhandled builtin function diff --git a/base/compiler/validation.jl b/base/compiler/validation.jl index ef6602b082797..2428ea8a38892 100644 --- a/base/compiler/validation.jl +++ b/base/compiler/validation.jl @@ -13,7 +13,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}( :new => 1:typemax(Int), :splatnew => 2:2, :the_exception => 0:0, - :enter => 1:1, + :enter => 1:2, :leave => 1:typemax(Int), :pop_exception => 1:1, :inbounds => 1:1, @@ -160,6 +160,13 @@ function validate_code!(errors::Vector{InvalidCodeError}, c::CodeInfo, is_top_le push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.cond)) end validate_val!(x.cond) + elseif isa(x, EnterNode) + if isdefined(x, :scope) + if !is_valid_argument(x.scope) + push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.scope)) + end + validate_val!(x.scope) + end elseif isa(x, ReturnNode) if isdefined(x, :val) if !is_valid_return(x.val) diff --git a/base/scopedvalues.jl b/base/scopedvalues.jl index fd18e932d6a46..6eb1004a1d30f 100644 --- a/base/scopedvalues.jl +++ b/base/scopedvalues.jl @@ -81,13 +81,6 @@ function Scope(scope, pair1::Pair{<:ScopedValue}, pair2::Pair{<:ScopedValue}, pa end Scope(::Nothing) = nothing -""" - current_scope()::Union{Nothing, Scope} - -Return the current dynamic scope. -""" -current_scope() = current_task().scope::Union{Nothing, Scope} - function Base.show(io::IO, scope::Scope) print(io, Scope, "(") first = true @@ -116,8 +109,7 @@ return `nothing`. Otherwise returns `Some{T}` with the current value. """ function get(val::ScopedValue{T}) where {T} - # Inline current_scope to avoid doing the type assertion twice. - scope = current_task().scope + scope = Core.current_scope()::Union{Scope, Nothing} if scope === nothing isassigned(val) && return Some{T}(val.default) return nothing @@ -151,25 +143,6 @@ function Base.show(io::IO, val::ScopedValue) print(io, ')') end -""" - with(f, (var::ScopedValue{T} => val::T)...) - -Execute `f` in a new scope with `var` set to `val`. -""" -function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...) - @nospecialize - ct = Base.current_task() - current_scope = ct.scope::Union{Nothing, Scope} - ct.scope = Scope(current_scope, pair, rest...) - try - return f() - finally - ct.scope = current_scope - end -end - -with(@nospecialize(f)) = f() - """ @with vars... expr @@ -187,18 +160,18 @@ macro with(exprs...) else error("@with expects at least one argument") end - for expr in exprs - if expr.head !== :call || first(expr.args) !== :(=>) - error("@with expects arguments of the form `A => 2` got $expr") - end - end exprs = map(esc, exprs) - quote - ct = $(Base.current_task)() - current_scope = ct.scope::$(Union{Nothing, Scope}) - ct.scope = $(Scope)(current_scope, $(exprs...)) - $(Expr(:tryfinally, esc(ex), :(ct.scope = current_scope))) - end + Expr(:tryfinally, esc(ex), :(), :(Scope(Core.current_scope()::Union{Nothing, Scope}, $(exprs...)))) end +""" + with(f, (var::ScopedValue{T} => val::T)...) + +Execute `f` in a new scope with `var` set to `val`. +""" +function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...) + @with(pair, rest..., f()) +end +with(@nospecialize(f)) = f() + end # module ScopedValues diff --git a/base/task.jl b/base/task.jl index 09b40f19f5913..79e85024ac0d0 100644 --- a/base/task.jl +++ b/base/task.jl @@ -180,11 +180,20 @@ end elseif field === :exception # TODO: this field name should be deprecated in 2.0 return t._isexception ? t.result : nothing + elseif field === :scope + error("Querying `scope` is disallowed. Use `current_scope` instead.") else return getfield(t, field) end end +@inline function setproperty!(t::Task, field::Symbol, @nospecialize(v)) + if field === :scope + istaskstarted(t) && error("Setting scope on a started task directly is disallowed.") + end + return @invoke setproperty!(t::Any, field, v) +end + """ istaskdone(t::Task) -> Bool diff --git a/src/ast.c b/src/ast.c index 8257309a07b2b..7e7e7fb445e00 100644 --- a/src/ast.c +++ b/src/ast.c @@ -606,7 +606,11 @@ static jl_value_t *scm_to_julia_(fl_context_t *fl_ctx, value_t e, jl_module_t *m else if (sym == jl_enter_sym) { ex = scm_to_julia_(fl_ctx, car_(e), mod); temp = jl_new_struct_uninit(jl_enternode_type); + jl_enternode_scope(temp) = NULL; jl_enternode_catch_dest(temp) = jl_unbox_long(ex); + if (n == 2) { + jl_enternode_scope(temp) = scm_to_julia(fl_ctx, car_(cdr_(e)), mod); + } } else if (sym == jl_newvar_sym) { ex = scm_to_julia_(fl_ctx, car_(e), mod); diff --git a/src/builtin_proto.h b/src/builtin_proto.h index e6e91207b2fdb..a009b535ac951 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -62,6 +62,7 @@ DECLARE_BUILTIN(setglobal); DECLARE_BUILTIN(finalizer); DECLARE_BUILTIN(_compute_sparams); DECLARE_BUILTIN(_svec_ref); +DECLARE_BUILTIN(current_scope); JL_CALLABLE(jl_f__structtype); JL_CALLABLE(jl_f__abstracttype); diff --git a/src/builtins.c b/src/builtins.c index 5238a9476c9a2..e6472457cb6a9 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -587,6 +587,12 @@ JL_CALLABLE(jl_f_ifelse) return (args[0] == jl_false ? args[2] : args[1]); } +JL_CALLABLE(jl_f_current_scope) +{ + JL_NARGS(current_scope, 0, 0); + return jl_current_task->scope; +} + // apply ---------------------------------------------------------------------- static NOINLINE jl_svec_t *_copy_to(size_t newalloc, jl_value_t **oldargs, size_t oldalloc) @@ -2158,6 +2164,7 @@ void jl_init_primitives(void) JL_GC_DISABLED add_builtin_func("finalizer", jl_f_finalizer); add_builtin_func("_compute_sparams", jl_f__compute_sparams); add_builtin_func("_svec_ref", jl_f__svec_ref); + add_builtin_func("current_scope", jl_f_current_scope); // builtin types add_builtin("Any", (jl_value_t*)jl_any_type); diff --git a/src/codegen.cpp b/src/codegen.cpp index a5615cbdb013d..2cbe7733ab708 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1833,6 +1833,7 @@ class jl_codectx_t { // local var info. globals are not in here. SmallVector slots; std::map phic_slots; + std::map > scope_restore; SmallVector SAvalues; SmallVector, 0> PhiNodes; SmallVector ssavalue_assigned; @@ -5601,18 +5602,28 @@ static void emit_stmtpos(jl_codectx_t &ctx, jl_value_t *expr, int ssaval_result) } else if (head == jl_leave_sym) { int hand_n_leave = 0; + Value *scope_to_restore = nullptr; + Value *scope_ptr = nullptr; for (size_t i = 0; i < jl_expr_nargs(ex); ++i) { jl_value_t *arg = args[i]; if (arg == jl_nothing) continue; assert(jl_is_ssavalue(arg)); - jl_value_t *enter_stmt = jl_array_ptr_ref(ctx.code, ((jl_ssavalue_t*)arg)->id - 1); + size_t enter_idx = ((jl_ssavalue_t*)arg)->id - 1; + jl_value_t *enter_stmt = jl_array_ptr_ref(ctx.code, enter_idx); if (enter_stmt == jl_nothing) continue; + if (ctx.scope_restore.count(enter_idx)) + std::tie(scope_to_restore, scope_ptr) = ctx.scope_restore[enter_idx]; hand_n_leave += 1; } ctx.builder.CreateCall(prepare_call(jlleave_func), ConstantInt::get(getInt32Ty(ctx.builder.getContext()), hand_n_leave)); + if (scope_to_restore) { + jl_aliasinfo_t scope_ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + scope_ai.decorateInst( + ctx.builder.CreateAlignedStore(scope_to_restore, scope_ptr, ctx.types().alignof_ptr)); + } } else if (head == jl_pop_exception_sym) { jl_cgval_t excstack_state = emit_expr(ctx, jl_exprarg(expr, 0)); @@ -6059,7 +6070,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_ return jl_cgval_t(); } else if (head == jl_leave_sym || head == jl_coverageeffect_sym - || head == jl_pop_exception_sym || head == jl_enter_sym || head == jl_inbounds_sym + || head == jl_pop_exception_sym || head == jl_inbounds_sym || head == jl_aliasscope_sym || head == jl_popaliasscope_sym || head == jl_inline_sym || head == jl_noinline_sym) { jl_errorf("Expr(:%s) in value position", jl_symbol_name(head)); } @@ -6150,6 +6161,16 @@ static Value *get_last_age_field(jl_codectx_t &ctx) "world_age"); } +static Value *get_scope_field(jl_codectx_t &ctx) +{ + Value *ct = get_current_task(ctx); + return ctx.builder.CreateInBoundsGEP( + ctx.types().T_prjlvalue, + ctx.builder.CreateBitCast(ct, ctx.types().T_prjlvalue->getPointerTo()), + ConstantInt::get(ctx.types().T_size, offsetof(jl_task_t, scope) / ctx.types().sizeof_ptr), + "current_scope"); +} + static Function *emit_tojlinvoke(jl_code_instance_t *codeinst, Module *M, jl_codegen_params_t ¶ms) { ++EmittedToJLInvokes; @@ -8730,6 +8751,20 @@ static jl_llvm_functions_t continue; } else if (jl_is_enternode(stmt)) { + // For the two-arg version of :enter, twiddle the scope + Value *scope_ptr = NULL; + Value *old_scope = NULL; + jl_aliasinfo_t scope_ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + if (jl_enternode_scope(stmt)) { + jl_cgval_t new_scope = emit_expr(ctx, jl_enternode_scope(stmt)); + Value *new_scope_boxed = boxed(ctx, new_scope); + scope_ptr = get_scope_field(ctx); + old_scope = scope_ai.decorateInst( + ctx.builder.CreateAlignedLoad(ctx.types().T_prjlvalue, scope_ptr, ctx.types().alignof_ptr)); + scope_ai.decorateInst( + ctx.builder.CreateAlignedStore(new_scope_boxed, scope_ptr, ctx.types().alignof_ptr)); + ctx.scope_restore[cursor] = std::make_pair(old_scope, scope_ptr); + } int lname = jl_enternode_catch_dest(stmt); // Save exception stack depth at enter for use in pop_exception Value *excstack_state = @@ -8737,6 +8772,7 @@ static jl_llvm_functions_t assert(!ctx.ssavalue_assigned[cursor]); ctx.SAvalues[cursor] = jl_cgval_t(excstack_state, (jl_value_t*)jl_ulong_type, NULL); ctx.ssavalue_assigned[cursor] = true; + // Actually enter the exception frame CallInst *sj = ctx.builder.CreateCall(prepare_call(except_enter_func)); // We need to mark this on the call site as well. See issue #6757 sj->setCanReturnTwice(); @@ -8749,9 +8785,15 @@ static jl_llvm_functions_t come_from_bb[cursor + 1] = ctx.builder.GetInsertBlock(); ctx.builder.CreateCondBr(isz, tryblk, catchpop); ctx.builder.SetInsertPoint(catchpop); - ctx.builder.CreateCall(prepare_call(jlleave_func), - ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 1)); - ctx.builder.CreateBr(handlr); + { + ctx.builder.CreateCall(prepare_call(jlleave_func), + ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 1)); + if (old_scope) { + scope_ai.decorateInst( + ctx.builder.CreateAlignedStore(old_scope, scope_ptr, ctx.types().alignof_ptr)); + } + ctx.builder.CreateBr(handlr); + } ctx.builder.SetInsertPoint(tryblk); } else { diff --git a/src/interpreter.c b/src/interpreter.c index 313f5d9423fcc..81096413c8e73 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -519,8 +519,23 @@ static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip, } // store current top of exception stack for restore in pop_exception. s->locals[jl_source_nslots(s->src) + ip] = jl_box_ulong(jl_excstack_state()); - if (!jl_setjmp(__eh.eh_ctx, 1)) { - return eval_body(stmts, s, next_ip, toplevel); + if (jl_enternode_scope(stmt)) { + jl_value_t *old_scope = ct->scope; + JL_GC_PUSH1(&old_scope); + jl_value_t *new_scope = eval_value(jl_enternode_scope(stmt), s); + ct->scope = new_scope; + if (!jl_setjmp(__eh.eh_ctx, 1)) { + eval_body(stmts, s, next_ip, toplevel); + jl_unreachable(); + } + ct->scope = old_scope; + JL_GC_POP(); + } + else { + if (!jl_setjmp(__eh.eh_ctx, 1)) { + eval_body(stmts, s, next_ip, toplevel); + jl_unreachable(); + } } jl_eh_restore_state(&__eh); if (s->continue_at) { // means we reached a :leave expression diff --git a/src/jltypes.c b/src/jltypes.c index 7fb7a6dc13eb5..f0f3b36951a2b 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3070,8 +3070,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_enternode_type = jl_new_datatype(jl_symbol("EnterNode"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(1, "catch_dest"), - jl_svec(1, jl_long_type), + jl_perm_symsvec(2, "catch_dest", "scope"), + jl_svec(2, jl_long_type, jl_any_type), jl_emptysvec, 0, 0, 1); jl_returnnode_type = diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 76a06c6d1a47b..e7899688453c7 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -4718,7 +4718,9 @@ f(x) = yt(x) #f)) ;; exception handlers are lowered using - ;; (= tok (enter L)) - push handler with catch block at label L, yielding token + ;; (= tok (enter L scope)) + ;; push handler with catch block at label L and scope `scope`, yielding token + ;; `scope` is only recognized for tryfinally and may be omitted in the lowering ;; (leave n) - pop N exception handlers ;; (pop_exception tok) - pop exception stack back to state of associated enter ((trycatch tryfinally trycatchelse) @@ -4728,9 +4730,10 @@ f(x) = yt(x) (endl (make-label)) (last-finally-handler finally-handler) (finally (if (eq? (car e) 'tryfinally) (new-mutable-var) #f)) + (scope (if (eq? (car e) 'tryfinally) (cdddr e) '())) (my-finally-handler #f)) ;; handler block entry - (emit `(= ,handler-token (enter ,catch))) + (emit `(= ,handler-token (enter ,catch ,@(compile-args scope break-labels)))) (set! handler-token-stack (cons handler-token handler-token-stack)) (if finally (begin (set! my-finally-handler (list finally endl '() handler-token-stack catch-token-stack)) (set! finally-handler my-finally-handler) @@ -5104,8 +5107,10 @@ f(x) = yt(x) (let ((idx (get ssavalue-table (cadr e) #f))) (if (not idx) (begin (prn e) (prn lam) (error "ssavalue with no def"))) `(ssavalue ,idx))) - ((memq (car e) '(goto enter)) - (list* (car e) (get label-table (cadr e)) (cddr e))) + ((eq? (car e) 'goto) + `(goto ,(get label-table (cadr e)))) + ((eq? (car e) 'enter) + `(enter ,(get label-table (cadr e)) ,@(map renumber-stuff (cddr e)))) ((eq? (car e) 'gotoifnot) `(gotoifnot ,(renumber-stuff (cadr e)) ,(get label-table (caddr e)))) ((eq? (car e) 'lambda) diff --git a/src/julia.h b/src/julia.h index 7b148ecbec55d..c5aa659670612 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1282,6 +1282,7 @@ STATIC_INLINE void jl_array_uint32_set(void *a, size_t i, uint32_t x) JL_NOTSAFE #define jl_gotoifnot_cond(x) (((jl_value_t**)(x))[0]) #define jl_gotoifnot_label(x) (((intptr_t*)(x))[1]) #define jl_enternode_catch_dest(x) (((intptr_t*)(x))[0]) +#define jl_enternode_scope(x) (((jl_value_t**)(x))[1]) #define jl_globalref_mod(s) (*(jl_module_t**)(s)) #define jl_globalref_name(s) (((jl_sym_t**)(s))[1]) #define jl_quotenode_value(x) (((jl_value_t**)x)[0]) diff --git a/src/method.c b/src/method.c index 73d2d256320ab..30bf9c5774f11 100644 --- a/src/method.c +++ b/src/method.c @@ -64,6 +64,21 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve } return expr; } + else if (jl_is_enternode(expr)) { + jl_value_t *scope = jl_enternode_scope(expr); + if (scope) { + jl_value_t *val = resolve_globals(scope, module, sparam_vals, binding_effects, eager_resolve); + if (val != scope) { + intptr_t catch_dest = jl_enternode_catch_dest(expr); + JL_GC_PUSH1(&val); + expr = jl_new_struct_uninit(jl_enternode_type); + jl_enternode_catch_dest(expr) = catch_dest; + jl_enternode_scope(expr) = val; + JL_GC_POP(); + } + } + return expr; + } else if (jl_is_gotoifnot(expr)) { jl_value_t *cond = resolve_globals(jl_gotoifnot_cond(expr), module, sparam_vals, binding_effects, eager_resolve); if (cond != jl_gotoifnot_cond(expr)) { diff --git a/src/staticdata.c b/src/staticdata.c index 8befca1d93414..8244c5a4373ce 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -470,6 +470,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type, &jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete, &jl_f_compilerbarrier, &jl_f_getglobal, &jl_f_setglobal, &jl_f_finalizer, &jl_f__compute_sparams, &jl_f__svec_ref, + &jl_f_current_scope, NULL }; typedef struct { diff --git a/test/core.jl b/test/core.jl index 7cd4be5427eb4..25264c689fcfb 100644 --- a/test/core.jl +++ b/test/core.jl @@ -8070,3 +8070,16 @@ let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base. @test Union{Tuple{T}, Tuple{T,Int}} where {T} === widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T}) @test Tuple === widen_diagonal(Union{Tuple{Vararg{S}}, Tuple{Vararg{T}}} where {S, T}) end + +# Test try/catch/else ordering +function test_try_catch_else() + local x + try + x = 1 + catch + rethrow() + else + return x + end +end +@test test_try_catch_else() == 1 diff --git a/test/scopedvalues.jl b/test/scopedvalues.jl index b2a0e574fe7aa..b1f3241af8fc6 100644 --- a/test/scopedvalues.jl +++ b/test/scopedvalues.jl @@ -69,11 +69,11 @@ end @testset "show" begin @test sprint(show, ScopedValue{Int}()) == "ScopedValue{$Int}(undefined)" @test sprint(show, sval) == "ScopedValue{$Int}(1)" - @test sprint(show, ScopedValues.current_scope()) == "nothing" + @test sprint(show, Core.current_scope()) == "nothing" with(sval => 2.0) do @test sprint(show, sval) == "ScopedValue{$Int}(2)" objid = sprint(show, Base.objectid(sval)) - @test sprint(show, ScopedValues.current_scope()) == "Base.ScopedValues.Scope(ScopedValue{$Int}@$objid => 2)" + @test sprint(show, Core.current_scope()) == "Base.ScopedValues.Scope(ScopedValue{$Int}@$objid => 2)" end end