From c720d7365afcc0631978c51d922f650910a4d2c4 Mon Sep 17 00:00:00 2001 From: Amit Murthy Date: Wed, 18 Jan 2017 13:27:18 +0530 Subject: [PATCH] remove checks for acc types. [ci skip] --- base/multi.jl | 54 +++++++++++++++++++-------------------------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/base/multi.jl b/base/multi.jl index e61db7a09b506..349b8598a9e2f 100644 --- a/base/multi.jl +++ b/base/multi.jl @@ -2083,13 +2083,13 @@ end function pfor(f, R) lenR = length(R) chunks = splitrange(lenR, workers()) - accums = get(task_local_storage(), :JULIA_ACCUMULATOR, ()) - if accums !== () - accums = accums[1] - accums = isa(accums, ParallelAccumulator) ? [accums] : accums - for acc in accums + tls_acc = get(task_local_storage(), :JULIA_ACCUMULATOR, ()) + if tls_acc !== () + acc_current = tls_acc[1] + acc_coll = isa(acc_current, ParallelAccumulator) ? [acc_current] : acc_current + for acc in acc_coll lenR != acc.length && throw(AssertionError("loop length must equal ParallelAccumulator length")) - set_destf(acc, p->length(chunks[p])) + set_f_len_at_pid!(acc, p->length(chunks[p])) end end @@ -2162,9 +2162,13 @@ type ParallelAccumulator{T} value::Nullable{T} # A function which returns a length value when input the destination pid. - # Used to serialize the same object with different length values depending - # on the destination pid. - destf::Nullable{Function} + # Each worker processes a subset of a paralle for-loop. During serialization + # f_len_at_pid is called to retrieve the length of the range that needs to be + # processed at pid. On the remote node, we write the locally accumulated value + # to the remote channel once len_at_pid values are processed. + # On the destination node, this field will be NULL and is used to loosely differentiate + # between the original instance on the caller and the deserialized instances on the workers. + f_len_at_pid::Nullable{Function} chnl::RemoteChannel @@ -2179,21 +2183,21 @@ type ParallelAccumulator{T} ParallelAccumulator(f, len, initial, chnl) = ParallelAccumulator{T}(f, len, initial, Nullable{Function}(), chnl) - ParallelAccumulator(f, len, initial, destf, chnl) = - new(f, len, len, initial, initial, destf, chnl) + ParallelAccumulator(f, len, initial, f_len_at_pid, chnl) = + new(f, len, len, initial, initial, f_len_at_pid, chnl) end -set_destf(pacc::ParallelAccumulator, f::Function) = (pacc.destf = f; pacc) +set_f_len_at_pid!(pacc::ParallelAccumulator, f::Function) = (pacc.f_len_at_pid = f; pacc) function serialize(s::AbstractSerializer, pacc::ParallelAccumulator) serialize_cycle(s, pacc) && return serialize_type(s, typeof(pacc)) - if isnull(pacc.destf) + if isnull(pacc.f_len_at_pid) error("Cannot serialize a ParallelAccumulator from a destination node.") end - len = get(pacc.destf)(worker_id_from_socket(s.io)) + len = get(pacc.f_len_at_pid)(worker_id_from_socket(s.io)) serialize(s, pacc.f) serialize(s, len) @@ -2214,7 +2218,7 @@ end function push!(pacc::ParallelAccumulator, v) if pacc.pending <= 0 - throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(p::ParallelAccumulator)?")) + throw(AssertionError("Reusing a ParallelAccumulator is not allowed. reset(acc)?")) end if !isnull(pacc.value) @@ -2246,31 +2250,13 @@ end function reset(pacc::ParallelAccumulator) pacc.pending = pacc.length pacc.value = pacc.initial - pacc.destf = Nullable{Function}() + pacc.f_len_at_pid = Nullable{Function}() pacc end macro accumulate(acc, expr) - if !(isa(acc, Symbol) || (isa(acc, Expr) && acc.head == :vect)) - throw(ArgumentError(string( - "@accumulate : ", - "First argument must be a variable name pointing to a ParallelAccumulator ", - "or a vector of variable names pointing to ParallelAccumulators. ", - "Found : ", typeof(acc)))) - end - quote esc_acc = $(esc(acc)) - if !(isa(esc_acc, ParallelAccumulator) || - isa(esc_acc, Array{ParallelAccumulator}) || - (isa(esc_acc, Array) && all(x->isa(x, ParallelAccumulator), esc_acc))) - - throw(ArgumentError(string( - "@accumulate : First argument must be a ParallelAccumulator ", - "or a vector of ParallelAccumulators. ", - "Found : ", typeof(esc_acc)))) - - end old_list = get(task_local_storage(), :JULIA_ACCUMULATOR, ()) task_local_storage(:JULIA_ACCUMULATOR, ($(esc(acc)), old_list))