diff --git a/base/task.jl b/base/task.jl index 54ae19e48a833..14600fa56b648 100644 --- a/base/task.jl +++ b/base/task.jl @@ -615,6 +615,8 @@ function yield() end end +@inline set_next_task(t::Task) = ccall(:jl_set_next_task, Cvoid, (Any,), t) + """ yield(t::Task, arg = nothing) @@ -624,7 +626,8 @@ immediately yields to `t` before calling the scheduler. function yield(t::Task, @nospecialize(x=nothing)) t.result = x enq_work(current_task()) - return try_yieldto(ensure_rescheduled, Ref(t)) + set_next_task(t) + return try_yieldto(ensure_rescheduled) end """ @@ -637,14 +640,15 @@ or scheduling in any way. Its use is discouraged. """ function yieldto(t::Task, @nospecialize(x=nothing)) t.result = x - return try_yieldto(identity, Ref(t)) + set_next_task(t) + return try_yieldto(identity) end -function try_yieldto(undo, reftask::Ref{Task}) +function try_yieldto(undo) try - ccall(:jl_switchto, Cvoid, (Any,), reftask) + ccall(:jl_switch, Cvoid, ()) catch - undo(reftask[]) + undo(ccall(:jl_get_next_task, Ref{Task}, ())) rethrow() end ct = current_task() @@ -696,18 +700,19 @@ function trypoptask(W::StickyWorkqueue) return t end -@noinline function poptaskref(W::StickyWorkqueue) +@noinline function poptask(W::StickyWorkqueue) task = trypoptask(W) if !(task isa Task) task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W) end - return Ref(task) + set_next_task(task) + nothing end function wait() W = Workqueues[Threads.threadid()] - reftask = poptaskref(W) - result = try_yieldto(ensure_rescheduled, reftask) + poptask(W) + result = try_yieldto(ensure_rescheduled) process_events() # return when we come out of the queue return result diff --git a/src/ccall.cpp b/src/ccall.cpp index 5cbf5ba02c369..3b76150207c02 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -1619,6 +1619,16 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs) tbaa_decorate(tbaa_const, ctx.builder.CreateLoad(pct)), retboxed, rt, unionall, static_rt); } + else if (is_libjulia_func(jl_set_next_task)) { + assert(lrt == T_void); + assert(!isVa && !llvmcall && nccallargs == 1); + JL_GC_POP(); + Value *ptls_pv = emit_bitcast(ctx, ctx.ptlsStates, T_ppjlvalue); + const int nt_offset = offsetof(jl_tls_states_t, next_task); + Value *pnt = ctx.builder.CreateGEP(ptls_pv, ConstantInt::get(T_size, nt_offset / sizeof(void*))); + ctx.builder.CreateStore(emit_pointer_from_objref(ctx, boxed(ctx, argv[0])), pnt); + return ghostValue(jl_nothing_type); + } else if (is_libjulia_func(jl_sigatomic_begin)) { assert(lrt == T_void); assert(!isVa && !llvmcall && nccallargs == 0); diff --git a/src/gc.c b/src/gc.c index 5b81785130ed4..d5aa9a9b5b719 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2645,6 +2645,8 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp { gc_mark_queue_obj(gc_cache, sp, ptls2->current_task); gc_mark_queue_obj(gc_cache, sp, ptls2->root_task); + if (ptls2->next_task) + gc_mark_queue_obj(gc_cache, sp, ptls2->next_task); if (ptls2->previous_exception) gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception); } diff --git a/src/julia_internal.h b/src/julia_internal.h index 39c5ad3dbbd86..6360e108de928 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1011,6 +1011,7 @@ JL_DLLEXPORT int jl_array_isassigned(jl_array_t *a, size_t i); JL_DLLEXPORT uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_get_current_task(void); +JL_DLLEXPORT void jl_set_next_task(jl_task_t *task); // -- synchronization utilities -- // diff --git a/src/julia_threads.h b/src/julia_threads.h index 9e3216c793894..6040a593300c1 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -185,6 +185,7 @@ struct _jl_tls_states_t { uv_cond_t wake_signal; volatile sig_atomic_t defer_signal; struct _jl_task_t *current_task; + struct _jl_task_t *next_task; #ifdef MIGRATE_TASKS struct _jl_task_t *previous_task; #endif diff --git a/src/task.c b/src/task.c index ff7c87d1dd1d2..9d88306dd4eee 100644 --- a/src/task.c +++ b/src/task.c @@ -125,7 +125,7 @@ static void NOINLINE save_stack(jl_ptls_t ptls, jl_task_t *lastt, jl_task_t **pt else { buf = lastt->stkbuf; } - *pt = lastt; // clear the gc-root for the target task before copying the stack for saving + *pt = NULL; // clear the gc-root for the target task before copying the stack for saving lastt->copy_stack = nb; lastt->sticky = 1; memcpy_a16((uint64_t*)buf, (uint64_t*)frame_addr, nb); @@ -248,10 +248,24 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel) _julia_init(rel); } +JL_DLLEXPORT void jl_set_next_task(jl_task_t *task) +{ + jl_get_ptls_states()->next_task = task; +} + +JL_DLLEXPORT jl_task_t *jl_get_next_task(void) +{ + jl_ptls_t ptls = jl_get_ptls_states(); + if (ptls->next_task) + return ptls->next_task; + return ptls->current_task; +} + void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task); -static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) +static void ctx_switch(jl_ptls_t ptls) { + jl_task_t **pt = &ptls->next_task; jl_task_t *t = *pt; assert(t != ptls->current_task); jl_task_t *lastt = ptls->current_task; @@ -283,7 +297,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) } if (killed) { - *pt = lastt; // can't fail after here: clear the gc-root for the target task now + *pt = NULL; // can't fail after here: clear the gc-root for the target task now lastt->gcstack = NULL; if (!lastt->copy_stack && lastt->stkbuf) { // early free of stkbuf back to the pool @@ -302,7 +316,7 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) } else #endif - *pt = lastt; // can't fail after here: clear the gc-root for the target task now + *pt = NULL; // can't fail after here: clear the gc-root for the target task now lastt->gcstack = ptls->pgcstack; } @@ -366,10 +380,10 @@ static jl_ptls_t NOINLINE refetch_ptls(void) return jl_get_ptls_states(); } -JL_DLLEXPORT void jl_switchto(jl_task_t **pt) +JL_DLLEXPORT void jl_switch(void) { jl_ptls_t ptls = jl_get_ptls_states(); - jl_task_t *t = *pt; + jl_task_t *t = ptls->next_task; jl_task_t *ct = ptls->current_task; if (t == ct) { return; @@ -401,7 +415,7 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt) jl_timing_block_stop(blk); #endif - ctx_switch(ptls, pt); + ctx_switch(ptls); #ifdef MIGRATE_TASKS ptls = refetch_ptls(); @@ -432,6 +446,12 @@ JL_DLLEXPORT void jl_switchto(jl_task_t **pt) jl_sigint_safepoint(ptls); } +JL_DLLEXPORT void jl_switchto(jl_task_t **pt) +{ + jl_set_next_task(*pt); + jl_switch(); +} + JL_DLLEXPORT JL_NORETURN void jl_no_exc_handler(jl_value_t *e) { jl_printf(JL_STDERR, "fatal: error thrown and no exception handler available.\n"); diff --git a/src/threading.c b/src/threading.c index d72663d2fdb4f..3eadcc48a5dae 100644 --- a/src/threading.c +++ b/src/threading.c @@ -285,6 +285,7 @@ void jl_init_threadtls(int16_t tid) ptls->bt_data = bt_data; ptls->sig_exception = NULL; ptls->previous_exception = NULL; + ptls->next_task = NULL; #ifdef _OS_WINDOWS_ ptls->needs_resetstkoflw = 0; #endif