Skip to content

Commit

Permalink
remove checks for acc types. [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
amitmurthy committed Jan 18, 2017
1 parent 2f21f3a commit c720d73
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2083,13 +2083,13 @@ end
function pfor(f, R)
lenR = length(R)
chunks = splitrange(lenR, workers())
accums = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
if accums !== ()
accums = accums[1]
accums = isa(accums, ParallelAccumulator) ? [accums] : accums
for acc in accums
tls_acc = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
if tls_acc !== ()
acc_current = tls_acc[1]
acc_coll = isa(acc_current, ParallelAccumulator) ? [acc_current] : acc_current
for acc in acc_coll
lenR != acc.length && throw(AssertionError("loop length must equal ParallelAccumulator length"))
set_destf(acc, p->length(chunks[p]))
set_f_len_at_pid!(acc, p->length(chunks[p]))
end
end

Expand Down Expand Up @@ -2162,9 +2162,13 @@ type ParallelAccumulator{T}
value::Nullable{T}

# A function which returns a length value when input the destination pid.
# Used to serialize the same object with different length values depending
# on the destination pid.
destf::Nullable{Function}
# Each worker processes a subset of a paralle for-loop. During serialization
# f_len_at_pid is called to retrieve the length of the range that needs to be
# processed at pid. On the remote node, we write the locally accumulated value
# to the remote channel once len_at_pid values are processed.
# On the destination node, this field will be NULL and is used to loosely differentiate
# between the original instance on the caller and the deserialized instances on the workers.
f_len_at_pid::Nullable{Function}

chnl::RemoteChannel

Expand All @@ -2179,21 +2183,21 @@ type ParallelAccumulator{T}
ParallelAccumulator(f, len, initial, chnl) =
ParallelAccumulator{T}(f, len, initial, Nullable{Function}(), chnl)

ParallelAccumulator(f, len, initial, destf, chnl) =
new(f, len, len, initial, initial, destf, chnl)
ParallelAccumulator(f, len, initial, f_len_at_pid, chnl) =
new(f, len, len, initial, initial, f_len_at_pid, chnl)
end

set_destf(pacc::ParallelAccumulator, f::Function) = (pacc.destf = f; pacc)
set_f_len_at_pid!(pacc::ParallelAccumulator, f::Function) = (pacc.f_len_at_pid = f; pacc)

function serialize(s::AbstractSerializer, pacc::ParallelAccumulator)
serialize_cycle(s, pacc) && return
serialize_type(s, typeof(pacc))

if isnull(pacc.destf)
if isnull(pacc.f_len_at_pid)
error("Cannot serialize a ParallelAccumulator from a destination node.")
end

len = get(pacc.destf)(worker_id_from_socket(s.io))
len = get(pacc.f_len_at_pid)(worker_id_from_socket(s.io))

serialize(s, pacc.f)
serialize(s, len)
Expand All @@ -2214,7 +2218,7 @@ end

function push!(pacc::ParallelAccumulator, v)
if pacc.pending <= 0
throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(p::ParallelAccumulator)?"))
throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(acc)?"))
end

if !isnull(pacc.value)
Expand Down Expand Up @@ -2246,31 +2250,13 @@ end
function reset(pacc::ParallelAccumulator)
pacc.pending = pacc.length
pacc.value = pacc.initial
pacc.destf = Nullable{Function}()
pacc.f_len_at_pid = Nullable{Function}()
pacc
end

macro accumulate(acc, expr)
if !(isa(acc, Symbol) || (isa(acc, Expr) && acc.head == :vect))
throw(ArgumentError(string(
"@accumulate : ",
"First argument must be a variable name pointing to a ParallelAccumulator ",
"or a vector of variable names pointing to ParallelAccumulators. ",
"Found : ", typeof(acc))))
end

quote
esc_acc = $(esc(acc))
if !(isa(esc_acc, ParallelAccumulator) ||
isa(esc_acc, Array{ParallelAccumulator}) ||
(isa(esc_acc, Array) && all(x->isa(x, ParallelAccumulator), esc_acc)))

throw(ArgumentError(string(
"@accumulate : First argument must be a ParallelAccumulator ",
"or a vector of ParallelAccumulators. ",
"Found : ", typeof(esc_acc))))

end

old_list = get(task_local_storage(), :JULIA_ACCUMULATOR, ())
task_local_storage(:JULIA_ACCUMULATOR, ($(esc(acc)), old_list))
Expand Down

0 comments on commit c720d73

Please sign in to comment.