Skip to content

Commit

Permalink
make at-sync thread safe
Browse files Browse the repository at this point in the history
fixes #34666
  • Loading branch information
JeffBezanson committed Apr 29, 2020
1 parent bff96e2 commit 14ba441
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 34 deletions.
55 changes: 31 additions & 24 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,37 @@ macro aliasscope(body)
end


function sync_end(refs)
local c_ex
defined = false
t = current_task()
function sync_end(refs::Vector{Any}, l::Threads.SpinLock)
cond = Threads.Condition()
lock(cond)
nremaining = length(refs)
for r in refs
schedule(Task(()->begin
try
wait(r)
lock(cond)
nremaining -= 1
nremaining == 0 && notify(cond)
unlock(cond)
catch e
lock(cond)
notify(cond, e; error=true)
unlock(cond)
end
end))
while true
lock(l)
if isempty(refs)
unlock(l)
return
end
localrefs = copy(refs)
empty!(refs)
unlock(l)
lock(cond)
nremaining::Int = length(localrefs)
for r in localrefs
schedule(Task(()->begin
try
wait(r)
lock(cond)
nremaining -= 1
nremaining == 0 && notify(cond)
unlock(cond)
catch e
lock(cond)
notify(cond, e; error=true)
unlock(cond)
end
end))
end
wait(cond)
unlock(cond)
end
wait(cond)
unlock(cond)
end

"""
Expand All @@ -92,9 +99,9 @@ during error handling.
macro sync(block)
var = esc(sync_varname)
quote
let $var = Any[]
let $var = (Any[], Threads.SpinLock())
v = $(esc(block))
sync_end($var)
sync_end($var[1], $var[2])
v
end
end
Expand Down
24 changes: 18 additions & 6 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,18 @@ end

## lexically-scoped waiting for multiple items

function sync_end(refs)
function sync_end(refs::Vector{Any}, l::Threads.SpinLock)
local c_ex
defined = false
for r in refs
while true
lock(l)
if isempty(refs)
unlock(l)
break
else
r = popfirst!(refs)
unlock(l)
end
if isa(r, Task)
_wait(r)
if istaskfailed(r)
Expand Down Expand Up @@ -330,9 +338,9 @@ a `CompositeException`.
macro sync(block)
var = esc(sync_varname)
quote
let $var = Any[]
let $var = (Any[], Threads.SpinLock())
v = $(esc(block))
sync_end($var)
sync_end($var[1], $var[2])
v
end
end
Expand Down Expand Up @@ -361,7 +369,9 @@ macro async(expr)
let $(letargs...)
local task = Task($thunk)
if $(Expr(:islocal, var))
push!($var, task)
lock($var[2])
push!($var[1], task)
unlock($var[2])
end
schedule(task)
task
Expand Down Expand Up @@ -403,7 +413,9 @@ macro sync_add(expr)
var = esc(sync_varname)
quote
local ref = $(esc(expr))
push!($var, ref)
lock($var[2])
push!($var[1], ref)
unlock($var[2])
ref
end
end
Expand Down
4 changes: 3 additions & 1 deletion base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ macro spawn(expr)
local task = Task($thunk)
task.sticky = false
if $(Expr(:islocal, var))
push!($var, task)
lock($var[2])
push!($var[1], task)
unlock($var[2])
end
schedule(task)
task
Expand Down
12 changes: 9 additions & 3 deletions stdlib/Distributed/src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ macro spawn(expr)
quote
local ref = spawn_somewhere($thunk)
if $(Expr(:islocal, var))
push!($var, ref)
lock($var[2])
push!($var[1], ref)
unlock($var[2])
end
ref
end
Expand Down Expand Up @@ -94,7 +96,9 @@ macro spawnat(p, expr)
quote
local ref = $spawncall
if $(Expr(:islocal, var))
push!($var, ref)
lock($var[2])
push!($var[1], ref)
unlock($var[2])
end
ref
end
Expand Down Expand Up @@ -345,7 +349,9 @@ macro distributed(args...)
return quote
local ref = pfor($(make_pfor_body(var, body)), $(esc(r)))
if $(Expr(:islocal, syncvar))
push!($syncvar, ref)
lock($syncvar[2])
push!($syncvar[1], ref)
unlock($syncvar[2])
end
ref
end
Expand Down
13 changes: 13 additions & 0 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -821,3 +821,16 @@ end
x = 2
@test @eval(fetch(@async 2+$x)) == 4
end

# issue #34666
fib34666(x) =
@sync begin
function f(x)
x in (0, 1) && return x
a = Threads.@spawn f(x - 2)
b = Threads.@spawn f(x - 1)
return fetch(a) + fetch(b)
end
f(x)
end
@test fib34666(25) == 75025

0 comments on commit 14ba441

Please sign in to comment.