Skip to content

Commit

Permalink
Band aid to make threaded loop a little easier to work with
Browse files Browse the repository at this point in the history
* Print a warning if an error occurs in the threaded loop (Helps #17532)
* Make recursive threaded loops "work" (Fix #18335).

  The proper fix will be tracked by #21017
  • Loading branch information
yuyichao committed Apr 20, 2017
1 parent 8b2ee71 commit 914933e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
29 changes: 23 additions & 6 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@ on `threadid()`.
"""
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))

# Only read/write by t he main thread
const in_threaded_loop = Ref(false)

function _threadsfor(iter,lbody)
fun = gensym("_threadsfor")
lidx = iter.args[1] # index
range = iter.args[2]
quote
function $fun()
tid = threadid()
function threadsfor_fun(onethread=false)
r = $(esc(range))
lenr = length(r)
# divide loop iterations among threads
len, rem = divrem(length(r), nthreads())
if onethread
tid = 1
len, rem = lenr, 0
else
tid = threadid()
len, rem = divrem(lenr, nthreads())
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
Expand All @@ -54,7 +62,17 @@ function _threadsfor(iter,lbody)
$(esc(lbody))
end
end
ccall(:jl_threading_run, Ref{Void}, (Any,), $fun)
# Hack to make nested threaded loops kinda work
if threadid() != 1 || in_threaded_loop[]
# We are in a nested threaded loop
threadsfor_fun(true)
else
in_threaded_loop[] = true
# the ccall is not expected to throw
ccall(:jl_threading_run, Ref{Void}, (Any,), threadsfor_fun)
in_threaded_loop[] = false
end
nothing
end
end
"""
Expand All @@ -80,4 +98,3 @@ macro threads(args...)
throw(ArgumentError("unrecognized argument to @threads"))
end
end

15 changes: 14 additions & 1 deletion src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ static jl_value_t *ti_run_fun(const jl_generic_fptr_t *fptr, jl_method_instance_
jl_call_fptr_internal(fptr, mfunc, args, nargs);
}
JL_CATCH {
return ptls->exception_in_transit;
// Lock this output since we know it'll likely happen on multiple threads
static jl_mutex_t lock;
JL_LOCK_NOGC(&lock);
jl_jmp_buf *old_buf = ptls->safe_restore;
jl_jmp_buf buf;
if (!jl_setjmp(buf, 0)) {
// Set up the safe_restore context so that the printing uses the thread safe version
ptls->safe_restore = &buf;
jl_printf(JL_STDERR, "\nError thrown in threaded loop on thread %d: ",
(int)ptls->tid);
jl_static_show(JL_STDERR, ptls->exception_in_transit);
}
ptls->safe_restore = old_buf;
JL_UNLOCK_NOGC(&lock);
}
return jl_nothing;
}
Expand Down
17 changes: 17 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,20 @@ function test_load_and_lookup_18020(n)
end
end
test_load_and_lookup_18020(10000)

# Nested threaded loops
# This may not be efficient/fully supported but should work without crashing.....
function test_nested_loops()
a = zeros(Int, 100, 100)
@threads for i in 1:100
@threads for j in 1:100
a[j, i] = i + j
end
end
for i in 1:100
for j in 1:100
@test a[j, i] == i + j
end
end
end
test_nested_loops()

0 comments on commit 914933e

Please sign in to comment.