Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid using @sync_add on remotecalls #44671

Merged
merged 7 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,12 @@ isolating the asynchronous code from changes to the variable's value in the curr
Interpolating values via `\$` is available as of Julia 1.4.
"""
macro async(expr)
do_async_macro(expr)
end

# generate the code for @async, possibly wrapping the task in something before
# pushing it to the wait queue.
function do_async_macro(expr; wrap=identity)
letargs = Base._lift_one_interp!(expr)

thunk = esc(:(()->($expr)))
Expand All @@ -479,14 +485,43 @@ macro async(expr)
let $(letargs...)
local task = Task($thunk)
if $(Expr(:islocal, var))
put!($var, task)
put!($var, $(wrap(:task)))
end
schedule(task)
task
end
end
end

# task wrapper that doesn't create exceptions wrapped in TaskFailedException
struct UnwrapTaskFailedException
task::Task
end

# common code for wait&fetch for UnwrapTaskFailedException
function unwrap_task_failed(f::Function, t::UnwrapTaskFailedException)
try
f(t.task)
catch ex
if ex isa TaskFailedException
throw(ex.task.exception)
else
rethrow()
end
end
end

# the unwrapping for above task wrapper (gets triggered in sync_end())
wait(t::UnwrapTaskFailedException) = unwrap_task_failed(wait, t)

# same for fetching the tasks, for convenience
fetch(t::UnwrapTaskFailedException) = unwrap_task_failed(fetch, t)

# macro for running async code that doesn't throw wrapped exceptions
macro async_unwrap(expr)
do_async_macro(expr, wrap=task->:(Base.UnwrapTaskFailedException($task)))
end

"""
errormonitor(t::Task)

Expand Down
4 changes: 2 additions & 2 deletions stdlib/Distributed/src/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Base: getindex, wait, put!, take!, fetch, isready, push!, length,
hash, ==, kill, close, isopen, showerror

# imports for use
using Base: Process, Semaphore, JLOptions, buffer_writes, @sync_add,
using Base: Process, Semaphore, JLOptions, buffer_writes, @async_unwrap,
VERSION_STRING, binding_module, atexit, julia_exename,
julia_cmd, AsyncGenerator, acquire, release, invokelatest,
shell_escape_posixly, shell_escape_csh,
Expand Down Expand Up @@ -76,7 +76,7 @@ function _require_callback(mod::Base.PkgId)
# broadcast top-level (e.g. from Main) import/using from node 1 (only)
@sync for p in procs()
p == 1 && continue
@sync_add remotecall(p) do
@async_unwrap remotecall_wait(p) do
Base.require(mod)
nothing
end
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/clusterserialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ An exception is raised if a global constant is requested to be cleared.
"""
function clear!(syms, pids=workers(); mod=Main)
@sync for p in pids
@sync_add remotecall(clear_impl!, p, syms, mod)
@async_unwrap remotecall_wait(clear_impl!, p, syms, mod)
end
end
clear!(sym::Symbol, pid::Int; mod=Main) = clear!([sym], [pid]; mod=mod)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/Distributed/src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ function remotecall_eval(m::Module, procs, ex)
if pid == myid()
run_locally += 1
else
@sync_add remotecall(Core.eval, pid, m, ex)
@async_unwrap remotecall_wait(Core.eval, pid, m, ex)
end
end
yield() # ensure that the remotecall_fetch have had a chance to start
yield() # ensure that the remotecalls have had a chance to start

# execute locally last as we do not want local execution to block serialization
# of the request to remote nodes.
Expand Down