-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: parallel task runtime #22631
WIP: parallel task runtime #22631
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,10 @@ | |
|
||
abstract type AbstractChannel{T} end | ||
|
||
if JULIA_PARTR | ||
|
||
using Base.Threads | ||
|
||
""" | ||
Channel{T}(sz::Int) | ||
|
||
|
@@ -21,7 +25,54 @@ mutable struct Channel{T} <: AbstractChannel{T} | |
cond_take::Condition # waiting for data to become available | ||
cond_put::Condition # waiting for a writeable slot | ||
state::Symbol | ||
excp::Union{Exception, Nothing} # exception to be thrown when state != :open | ||
excp::Union{Exception,Nothing} # exception to be thrown when state != :open | ||
|
||
data::Vector{T} | ||
sz_max::Int # maximum size of channel | ||
lock::SpinLock | ||
|
||
# The following fields synchronize tasks that use unbuffered channels | ||
# (sz_max == 0). | ||
nwaiters::Atomic{Int} | ||
takers::Vector{Task} | ||
putters::Vector{Task} | ||
|
||
function Channel{T}(sz::Float64) where T | ||
Channel{T}(sz == Inf ? typemax(Int) : convert(Int, sz)) | ||
end | ||
function Channel{T}(sz::Integer) where T | ||
sz < 0 && throw(ArgumentError("Channel size must be 0, a positive integer, or Inf")) | ||
ch = new(Condition(), Condition(), :open, nothing, Vector{T}(), sz, SpinLock(), Atomic()) | ||
if sz == 0 | ||
ch.takers = Vector{Task}() | ||
ch.putters = Vector{Task}() | ||
end | ||
return ch | ||
end | ||
end | ||
|
||
else # !JULIA_PARTR | ||
|
||
""" | ||
Channel{T}(sz::Int) | ||
|
||
Constructs a `Channel` with an internal buffer that can hold a maximum of `sz` objects | ||
of type `T`. | ||
[`put!`](@ref) calls on a full channel block until an object is removed with [`take!`](@ref). | ||
|
||
`Channel(0)` constructs an unbuffered channel. `put!` blocks until a matching `take!` is called. | ||
And vice-versa. | ||
|
||
Other constructors: | ||
|
||
* `Channel(Inf)`: equivalent to `Channel{Any}(typemax(Int))` | ||
* `Channel(sz)`: equivalent to `Channel{Any}(sz)` | ||
""" | ||
mutable struct Channel{T} <: AbstractChannel{T} | ||
cond_take::Condition # waiting for data to become available | ||
cond_put::Condition # waiting for a writeable slot | ||
state::Symbol | ||
excp::Union{Exception, Nothing} # exception to be thrown when state != :open | ||
|
||
data::Vector{T} | ||
sz_max::Int # maximum size of channel | ||
|
@@ -51,6 +102,8 @@ mutable struct Channel{T} <: AbstractChannel{T} | |
end | ||
end | ||
|
||
end # !JULIA_PARTR | ||
|
||
Channel(sz) = Channel{Any}(sz) | ||
|
||
# special constructors | ||
|
@@ -88,13 +141,13 @@ Referencing the created task: | |
```jldoctest | ||
julia> taskref = Ref{Task}(); | ||
|
||
julia> chnl = Channel(c->(@show take!(c)); taskref=taskref); | ||
julia> chnl = Channel(c->println(take!(c)); taskref=taskref); | ||
|
||
julia> istaskdone(taskref[]) | ||
false | ||
|
||
julia> put!(chnl, "Hello"); | ||
take!(c) = "Hello" | ||
Hello | ||
|
||
julia> istaskdone(taskref[]) | ||
true | ||
|
@@ -110,7 +163,6 @@ function Channel(func::Function; ctype=Any, csize=0, taskref=nothing) | |
return chnl | ||
end | ||
|
||
|
||
closed_exception() = InvalidStateException("Channel is closed.", :closed) | ||
|
||
isbuffered(c::Channel) = c.sz_max==0 ? false : true | ||
|
@@ -121,6 +173,7 @@ function check_channel_state(c::Channel) | |
throw(closed_exception()) | ||
end | ||
end | ||
|
||
""" | ||
close(c::Channel) | ||
|
||
|
@@ -255,6 +308,25 @@ function put!(c::Channel{T}, v) where T | |
isbuffered(c) ? put_buffered(c,v) : put_unbuffered(c,v) | ||
end | ||
|
||
if JULIA_PARTR | ||
|
||
function put_buffered(c::Channel, v) | ||
while true | ||
lock(c.lock) | ||
if length(c.data) == c.sz_max | ||
unlock(c.lock) | ||
wait(c.cond_put) | ||
else | ||
push!(c.data, v) | ||
notify(c.cond_take, nothing, true, false) | ||
unlock(c.lock) | ||
return v | ||
end | ||
end | ||
end | ||
|
||
else # !JULIA_PARTR | ||
|
||
function put_buffered(c::Channel, v) | ||
while length(c.data) == c.sz_max | ||
wait(c.cond_put) | ||
|
@@ -266,6 +338,28 @@ function put_buffered(c::Channel, v) | |
v | ||
end | ||
|
||
end # !JULIA_PARTR | ||
|
||
if JULIA_PARTR | ||
|
||
function put_unbuffered(c::Channel, v) | ||
while true | ||
lock(c.lock) | ||
if length(c.takers) > 0 | ||
taker = popfirst!(c.takers) | ||
unlock(c.lock) | ||
yield(taker, v) | ||
return v | ||
else | ||
unlock(c.lock) | ||
c.nwaiters[] > 0 && notify(c.cond_take, nothing, false, false) | ||
wait(c.cond_put) | ||
end | ||
end | ||
end | ||
|
||
else # !JULIA_PARTR | ||
|
||
function put_unbuffered(c::Channel, v) | ||
if length(c.takers) == 0 | ||
push!(c.putters, current_task()) | ||
|
@@ -283,8 +377,37 @@ function put_unbuffered(c::Channel, v) | |
return v | ||
end | ||
|
||
end # !JULIA_PARTR | ||
|
||
push!(c::Channel, v) = put!(c, v) | ||
|
||
if JULIA_PARTR | ||
|
||
""" | ||
fetch(c::Channel) | ||
|
||
Wait for and get the first available item from the channel. Does not | ||
remove the item. `fetch` is unsupported on an unbuffered (0-size) channel. | ||
""" | ||
function fetch(c::Channel) | ||
c.sz_max == 0 && throw(ErrorException("`fetch` is not supported on an unbuffered Channel")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be good to change this to something other than |
||
while true | ||
check_channel_state(c) | ||
lock(c.lock) | ||
if length(c.data) < 1 | ||
unlock(c.lock) | ||
# TODO: fix the race here | ||
wait(c.cond_take) | ||
else | ||
v = c.data[1] | ||
unlock(c.lock) | ||
return v | ||
end | ||
end | ||
end | ||
|
||
else # !JULIA_PARTR | ||
|
||
""" | ||
fetch(c::Channel) | ||
|
||
|
@@ -298,6 +421,7 @@ function fetch_buffered(c::Channel) | |
end | ||
fetch_unbuffered(c::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel.")) | ||
|
||
end # !JULIA_PARTR | ||
|
||
""" | ||
take!(c::Channel) | ||
|
@@ -308,14 +432,56 @@ For unbuffered channels, blocks until a [`put!`](@ref) is performed by a differe | |
task. | ||
""" | ||
take!(c::Channel) = isbuffered(c) ? take_buffered(c) : take_unbuffered(c) | ||
|
||
if JULIA_PARTR | ||
|
||
function take_buffered(c::Channel) | ||
while true | ||
lock(c.lock) | ||
if length(c.data) > 0 | ||
v = popfirst!(c.data) | ||
unlock(c.lock) | ||
notify(c.cond_put, nothing, false, false) | ||
return v | ||
end | ||
unlock(c.lock) | ||
check_channel_state(c) | ||
wait(c.cond_take) | ||
end | ||
end | ||
|
||
else # !JULIA_PARTR | ||
|
||
function take_buffered(c::Channel) | ||
wait(c) | ||
v = popfirst!(c.data) | ||
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!. | ||
v | ||
end | ||
|
||
popfirst!(c::Channel) = take!(c) | ||
end # !JULIA_PARTR | ||
|
||
if JULIA_PARTR | ||
|
||
function take_unbuffered(c::Channel{T}) where T | ||
check_channel_state(c) | ||
lock(c.lock) | ||
push!(c.takers, current_task()) | ||
unlock(c.lock) | ||
notify(c.cond_put, nothing, false, false) | ||
try | ||
# We wait here for a putter which will reschedule us with the | ||
# value it is putting (which is returned by this wait call). | ||
return wait()::T | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it worth commenting this line, to explain that
|
||
catch ex | ||
lock(c.lock) | ||
filter!(x->x!=current_task(), c.takers) | ||
unlock(c.lock) | ||
rethrow(ex) | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The locks/unlocks in this section could be written in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a side point, I like that the PATR code for this one is a lot clearer than the nonPATR code :-D |
||
|
||
else # !JULIA_PARTR | ||
|
||
# 0-size channel | ||
function take_unbuffered(c::Channel{T}) where T | ||
|
@@ -338,6 +504,10 @@ function take_unbuffered(c::Channel{T}) where T | |
end | ||
end | ||
|
||
end # !JULIA_PARTR | ||
|
||
popfirst!(c::Channel) = take!(c) | ||
|
||
""" | ||
isready(c::Channel) | ||
|
||
|
@@ -348,7 +518,14 @@ For unbuffered channels returns `true` if there are tasks waiting | |
on a [`put!`](@ref). | ||
""" | ||
isready(c::Channel) = n_avail(c) > 0 | ||
|
||
if JULIA_PARTR | ||
n_avail(c::Channel) = lock(c.lock) do | ||
isbuffered(c) ? length(c.data) : isempty(c.cond_put) ? 0 : 1 | ||
end | ||
else # !JULIA_PARTR | ||
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.putters) | ||
end # !JULIA_PARTR | ||
|
||
wait(c::Channel) = isbuffered(c) ? wait_impl(c) : wait_unbuffered(c) | ||
function wait_impl(c::Channel) | ||
|
@@ -359,6 +536,17 @@ function wait_impl(c::Channel) | |
nothing | ||
end | ||
|
||
if JULIA_PARTR | ||
function wait_unbuffered(c::Channel) | ||
atomic_add!(c.nwaiters, 1) | ||
try | ||
wait_impl(c) | ||
finally | ||
atomic_sub!(c.nwaiters, 1) | ||
end | ||
nothing | ||
end | ||
else # !JULIA_PARTR | ||
function wait_unbuffered(c::Channel) | ||
c.waiters += 1 | ||
try | ||
|
@@ -368,6 +556,7 @@ function wait_unbuffered(c::Channel) | |
end | ||
nothing | ||
end | ||
end # !JULIA_PARTR | ||
|
||
function notify_error(c::Channel, err) | ||
notify_error(c.cond_take, err) | ||
|
@@ -379,6 +568,7 @@ function notify_error(c::Channel, err) | |
foreach(t->schedule(t, err; error=true), waiters) | ||
end | ||
end | ||
|
||
notify_error(c::Channel) = notify_error(c, c.excp) | ||
|
||
eltype(::Type{Channel{T}}) where {T} = T | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be
@static if
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they're equivalent at the top-level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to know, thanks.
There are a few places further down that are not at top-level that should be
@static if
then,I've now marked them