From 14ba441193651cd6b35a1ce53cf7601fe0b8fb2b Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Wed, 29 Apr 2020 11:06:39 -0400 Subject: [PATCH] make at-sync thread safe fixes #34666 --- base/experimental.jl | 55 ++++++++++++++++++-------------- base/task.jl | 24 ++++++++++---- base/threadingconstructs.jl | 4 ++- stdlib/Distributed/src/macros.jl | 12 +++++-- test/threads_exec.jl | 13 ++++++++ 5 files changed, 74 insertions(+), 34 deletions(-) diff --git a/base/experimental.jl b/base/experimental.jl index 910deb3f6fcac..bf413f18e068a 100644 --- a/base/experimental.jl +++ b/base/experimental.jl @@ -52,30 +52,37 @@ macro aliasscope(body) end -function sync_end(refs) - local c_ex - defined = false - t = current_task() +function sync_end(refs::Vector{Any}, l::Threads.SpinLock) cond = Threads.Condition() - lock(cond) - nremaining = length(refs) - for r in refs - schedule(Task(()->begin - try - wait(r) - lock(cond) - nremaining -= 1 - nremaining == 0 && notify(cond) - unlock(cond) - catch e - lock(cond) - notify(cond, e; error=true) - unlock(cond) - end - end)) + while true + lock(l) + if isempty(refs) + unlock(l) + return + end + localrefs = copy(refs) + empty!(refs) + unlock(l) + lock(cond) + nremaining::Int = length(localrefs) + for r in localrefs + schedule(Task(()->begin + try + wait(r) + lock(cond) + nremaining -= 1 + nremaining == 0 && notify(cond) + unlock(cond) + catch e + lock(cond) + notify(cond, e; error=true) + unlock(cond) + end + end)) + end + wait(cond) + unlock(cond) end - wait(cond) - unlock(cond) end """ @@ -92,9 +99,9 @@ during error handling. macro sync(block) var = esc(sync_varname) quote - let $var = Any[] + let $var = (Any[], Threads.SpinLock()) v = $(esc(block)) - sync_end($var) + sync_end($var[1], $var[2]) v end end diff --git a/base/task.jl b/base/task.jl index 14600fa56b648..7a8516e3fb1ce 100644 --- a/base/task.jl +++ b/base/task.jl @@ -286,10 +286,18 @@ end ## lexically-scoped waiting for multiple items -function sync_end(refs) +function sync_end(refs::Vector{Any}, l::Threads.SpinLock) local c_ex defined = false - for r in refs + while true + lock(l) + if isempty(refs) + unlock(l) + break + else + r = popfirst!(refs) + unlock(l) + end if isa(r, Task) _wait(r) if istaskfailed(r) @@ -330,9 +338,9 @@ a `CompositeException`. macro sync(block) var = esc(sync_varname) quote - let $var = Any[] + let $var = (Any[], Threads.SpinLock()) v = $(esc(block)) - sync_end($var) + sync_end($var[1], $var[2]) v end end @@ -361,7 +369,9 @@ macro async(expr) let $(letargs...) local task = Task($thunk) if $(Expr(:islocal, var)) - push!($var, task) + lock($var[2]) + push!($var[1], task) + unlock($var[2]) end schedule(task) task @@ -403,7 +413,9 @@ macro sync_add(expr) var = esc(sync_varname) quote local ref = $(esc(expr)) - push!($var, ref) + lock($var[2]) + push!($var[1], ref) + unlock($var[2]) ref end end diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index b346786b2c459..312d6ce73cc39 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -130,7 +130,9 @@ macro spawn(expr) local task = Task($thunk) task.sticky = false if $(Expr(:islocal, var)) - push!($var, task) + lock($var[2]) + push!($var[1], task) + unlock($var[2]) end schedule(task) task diff --git a/stdlib/Distributed/src/macros.jl b/stdlib/Distributed/src/macros.jl index fac51d21766e0..1582ca8e9d985 100644 --- a/stdlib/Distributed/src/macros.jl +++ b/stdlib/Distributed/src/macros.jl @@ -49,7 +49,9 @@ macro spawn(expr) quote local ref = spawn_somewhere($thunk) if $(Expr(:islocal, var)) - push!($var, ref) + lock($var[2]) + push!($var[1], ref) + unlock($var[2]) end ref end @@ -94,7 +96,9 @@ macro spawnat(p, expr) quote local ref = $spawncall if $(Expr(:islocal, var)) - push!($var, ref) + lock($var[2]) + push!($var[1], ref) + unlock($var[2]) end ref end @@ -345,7 +349,9 @@ macro distributed(args...) return quote local ref = pfor($(make_pfor_body(var, body)), $(esc(r))) if $(Expr(:islocal, syncvar)) - push!($syncvar, ref) + lock($syncvar[2]) + push!($syncvar[1], ref) + unlock($syncvar[2]) end ref end diff --git a/test/threads_exec.jl b/test/threads_exec.jl index 525d1face587f..ef3d85aec7a55 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -821,3 +821,16 @@ end x = 2 @test @eval(fetch(@async 2+$x)) == 4 end + +# issue #34666 +fib34666(x) = + @sync begin + function f(x) + x in (0, 1) && return x + a = Threads.@spawn f(x - 2) + b = Threads.@spawn f(x - 1) + return fetch(a) + fetch(b) + end + f(x) + end +@test fib34666(25) == 75025