Skip to content

Commit

Permalink
NFC: create an actual set of functions to manipulate GC thread ids (JuliaLang#54984)
Browse files Browse the repository at this point in the history

Also adds a bunch of integrity constraint checks to ensure we don't
repeat the bug from JuliaLang#54645.
  • Loading branch information
d-netto committed Jul 4, 2024
1 parent 5163fe7 commit 57ca8df
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 24 deletions.
52 changes: 34 additions & 18 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -1639,9 +1639,11 @@ void gc_sweep_wake_all(jl_ptls_t ptls, jl_gc_padded_page_stack_t *new_gc_allocd_
if (parallel_sweep_worthwhile && !page_profile_enabled) {
jl_atomic_store(&gc_allocd_scratch, new_gc_allocd_scratch);
uv_mutex_lock(&gc_threads_lock);
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
int first = gc_first_parallel_collector_thread_id();
int last = gc_last_parallel_collector_thread_id();
for (int i = first; i <= last; i++) {
jl_ptls_t ptls2 = gc_all_tls_states[i];
assert(ptls2 != NULL); // should be a GC thread
gc_check_ptls_of_parallel_collector_thread(ptls2);
jl_atomic_fetch_add(&ptls2->gc_sweeps_requested, 1);
}
uv_cond_broadcast(&gc_threads_cond);
Expand All @@ -1653,9 +1655,11 @@ void gc_sweep_wake_all(jl_ptls_t ptls, jl_gc_padded_page_stack_t *new_gc_allocd_
// collecting a page profile.
// wait for all to leave in order to ensure that a straggler doesn't
// try to enter sweeping after we set `gc_allocd_scratch` below.
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
int first = gc_first_parallel_collector_thread_id();
int last = gc_last_parallel_collector_thread_id();
for (int i = first; i <= last; i++) {
jl_ptls_t ptls2 = gc_all_tls_states[i];
assert(ptls2 != NULL); // should be a GC thread
gc_check_ptls_of_parallel_collector_thread(ptls2);
while (jl_atomic_load_acquire(&ptls2->gc_sweeps_requested) != 0) {
jl_cpu_pause();
}
Expand Down Expand Up @@ -3006,19 +3010,25 @@ void gc_mark_and_steal(jl_ptls_t ptls)
// since we know chunks will likely expand into a lot
// of work for the mark loop
steal : {
int first = gc_first_parallel_collector_thread_id();
int last = gc_last_parallel_collector_thread_id();
// Try to steal chunk from random GC thread
for (int i = 0; i < 4 * jl_n_markthreads; i++) {
uint32_t v = gc_first_tid + cong(UINT64_MAX, UINT64_MAX, &ptls->rngseed) % jl_n_markthreads;
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[v]->mark_queue;
int v = gc_random_parallel_collector_thread_id(ptls);
jl_ptls_t ptls2 = gc_all_tls_states[v];
gc_check_ptls_of_parallel_collector_thread(ptls2);
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
c = gc_chunkqueue_steal_from(mq2);
if (c.cid != GC_empty_chunk) {
gc_mark_chunk(ptls, mq, &c);
goto pop;
}
}
// Sequentially walk GC threads to try to steal chunk
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[i]->mark_queue;
for (int i = first; i <= last; i++) {
jl_ptls_t ptls2 = gc_all_tls_states[i];
gc_check_ptls_of_parallel_collector_thread(ptls2);
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
c = gc_chunkqueue_steal_from(mq2);
if (c.cid != GC_empty_chunk) {
gc_mark_chunk(ptls, mq, &c);
Expand All @@ -3035,15 +3045,19 @@ void gc_mark_and_steal(jl_ptls_t ptls)
}
// Try to steal pointer from random GC thread
for (int i = 0; i < 4 * jl_n_markthreads; i++) {
uint32_t v = gc_first_tid + cong(UINT64_MAX, UINT64_MAX, &ptls->rngseed) % jl_n_markthreads;
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[v]->mark_queue;
int v = gc_random_parallel_collector_thread_id(ptls);
jl_ptls_t ptls2 = gc_all_tls_states[v];
gc_check_ptls_of_parallel_collector_thread(ptls2);
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
new_obj = gc_ptr_queue_steal_from(mq2);
if (new_obj != NULL)
goto mark;
}
// Sequentially walk GC threads to try to steal pointer
for (int i = gc_first_tid; i < gc_first_tid + jl_n_markthreads; i++) {
jl_gc_markqueue_t *mq2 = &gc_all_tls_states[i]->mark_queue;
for (int i = first; i <= last; i++) {
jl_ptls_t ptls2 = gc_all_tls_states[i];
gc_check_ptls_of_parallel_collector_thread(ptls2);
jl_gc_markqueue_t *mq2 = &ptls2->mark_queue;
new_obj = gc_ptr_queue_steal_from(mq2);
if (new_obj != NULL)
goto mark;
Expand Down Expand Up @@ -3103,12 +3117,13 @@ int gc_should_mark(void)
}
int tid = jl_atomic_load_relaxed(&gc_master_tid);
assert(tid != -1);
assert(gc_all_tls_states != NULL);
size_t work = gc_count_work_in_queue(gc_all_tls_states[tid]);
for (tid = gc_first_tid; tid < gc_first_tid + jl_n_markthreads; tid++) {
jl_ptls_t ptls2 = gc_all_tls_states[tid];
if (ptls2 == NULL) {
continue;
}
int first = gc_first_parallel_collector_thread_id();
int last = gc_last_parallel_collector_thread_id();
for (int i = first; i <= last; i++) {
jl_ptls_t ptls2 = gc_all_tls_states[i];
gc_check_ptls_of_parallel_collector_thread(ptls2);
work += gc_count_work_in_queue(ptls2);
}
// if there is a lot of work left, enter the mark loop
Expand Down Expand Up @@ -3486,7 +3501,8 @@ static int _jl_gc_collect(jl_ptls_t ptls, jl_gc_collection_t collection)
jl_ptls_t ptls_dest = ptls;
jl_gc_markqueue_t *mq_dest = mq;
if (!single_threaded_mark) {
ptls_dest = gc_all_tls_states[gc_first_tid + t_i % jl_n_markthreads];
int dest_tid = gc_ith_parallel_collector_thread_id(t_i % jl_n_markthreads);
ptls_dest = gc_all_tls_states[dest_tid];
mq_dest = &ptls_dest->mark_queue;
}
if (ptls2 != NULL) {
Expand Down
48 changes: 48 additions & 0 deletions src/gc.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,54 @@ extern int gc_first_tid;
extern int gc_n_threads;
extern jl_ptls_t* gc_all_tls_states;

STATIC_INLINE int gc_first_parallel_collector_thread_id(void) JL_NOTSAFEPOINT
{
    // First thread id in the parallel-GC-thread range. When no mark threads
    // were configured, return 0 so that, paired with the -1 returned by
    // gc_last_parallel_collector_thread_id(), `first <= last` loops run zero times.
    return (jl_n_markthreads == 0) ? 0 : gc_first_tid;
}

STATIC_INLINE int gc_last_parallel_collector_thread_id(void) JL_NOTSAFEPOINT
{
    // Last thread id in the parallel-GC-thread range, or -1 as an
    // empty-range sentinel when no mark threads were configured.
    return (jl_n_markthreads == 0) ? -1 : gc_first_tid + jl_n_markthreads - 1;
}

STATIC_INLINE int gc_ith_parallel_collector_thread_id(int i) JL_NOTSAFEPOINT
{
    // Map a zero-based parallel-collector index into the global thread-id space.
    assert(0 <= i && i < jl_n_markthreads);
    return i + gc_first_tid;
}

STATIC_INLINE int gc_is_parallel_collector_thread(int tid) JL_NOTSAFEPOINT
{
    // Nonzero iff `tid` falls inside the parallel-collector thread-id range.
    // With zero mark threads the last id is -1, so this is always false.
    int last = gc_last_parallel_collector_thread_id();
    return gc_first_tid <= tid && tid <= last;
}

STATIC_INLINE int gc_random_parallel_collector_thread_id(jl_ptls_t ptls) JL_NOTSAFEPOINT
{
    // Pick a uniformly random parallel-collector thread id, using the
    // caller's per-thread RNG state. Requires at least one mark thread.
    assert(jl_n_markthreads > 0);
    // NOTE(review): assumes `cong(max, ...)` yields a value in [0, max] — verify against cong's definition
    int tid = (int)cong(jl_n_markthreads - 1, UINT64_MAX, &ptls->rngseed) + gc_first_tid;
    assert(gc_first_tid <= tid && tid <= gc_last_parallel_collector_thread_id());
    return tid;
}

STATIC_INLINE int gc_parallel_collector_threads_enabled(void) JL_NOTSAFEPOINT
{
    // True when at least one parallel GC mark thread was configured.
    if (jl_n_markthreads > 0) {
        return 1;
    }
    return 0;
}

// Integrity check: assert that `ptls` belongs to a live parallel GC
// collector thread. Compiles to nothing in NDEBUG builds.
STATIC_INLINE void gc_check_ptls_of_parallel_collector_thread(jl_ptls_t ptls) JL_NOTSAFEPOINT
{
(void)ptls; // silence unused-parameter warning when asserts are compiled out
assert(gc_parallel_collector_threads_enabled());
assert(ptls != NULL);
// Parallel collector threads park in this dedicated gc_state and never run Julia code.
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
}

STATIC_INLINE bigval_t *bigval_header(jl_taggedvalue_t *o) JL_NOTSAFEPOINT
{
return container_of(o, bigval_t, header);
Expand Down
6 changes: 6 additions & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ typedef struct _jl_tls_states_t {
#define JL_GC_STATE_SAFE 2
// gc_state = 2 means the thread is running unmanaged code that can
// execute at the same time as the GC.
#define JL_GC_PARALLEL_COLLECTOR_THREAD 3
// gc_state = 3 means the thread is a parallel collector thread (i.e. never runs Julia code)
#define JL_GC_CONCURRENT_COLLECTOR_THREAD 4
// gc_state = 4 means the thread is a concurrent collector thread (background sweeper thread that never runs Julia code)
_Atomic(int8_t) gc_state; // read from foreign threads
// execution of certain impure
// statements is prohibited from certain
Expand Down Expand Up @@ -343,6 +347,8 @@ void jl_sigint_safepoint(jl_ptls_t tls);
STATIC_INLINE int8_t jl_gc_state_set(jl_ptls_t ptls, int8_t state,
int8_t old_state)
{
assert(old_state != JL_GC_PARALLEL_COLLECTOR_THREAD);
assert(old_state != JL_GC_CONCURRENT_COLLECTOR_THREAD);
jl_atomic_store_release(&ptls->gc_state, state);
// A safe point is required if we transition from GC-safe region to
// non GC-safe region.
Expand Down
13 changes: 7 additions & 6 deletions src/partr.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void jl_parallel_gc_threadfun(void *arg)
jl_task_t *ct = jl_init_root_task(ptls, stack_lo, stack_hi);
JL_GC_PROMISE_ROOTED(ct);
// wait for all threads
jl_gc_state_set(ptls, JL_GC_STATE_WAITING, 0);
jl_gc_state_set(ptls, JL_GC_PARALLEL_COLLECTOR_THREAD, 0);
uv_barrier_wait(targ->barrier);

// free the thread argument here
Expand All @@ -143,10 +143,10 @@ void jl_parallel_gc_threadfun(void *arg)
uv_cond_wait(&gc_threads_cond, &gc_threads_lock);
}
uv_mutex_unlock(&gc_threads_lock);
if (may_mark()) {
gc_mark_loop_parallel(ptls, 0);
}
if (may_sweep(ptls)) { // not an else!
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
gc_mark_loop_parallel(ptls, 0);
if (may_sweep(ptls)) {
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_PARALLEL_COLLECTOR_THREAD);
gc_sweep_pool_parallel(ptls);
jl_atomic_fetch_add(&ptls->gc_sweeps_requested, -1);
}
Expand All @@ -166,13 +166,14 @@ void jl_concurrent_gc_threadfun(void *arg)
jl_task_t *ct = jl_init_root_task(ptls, stack_lo, stack_hi);
JL_GC_PROMISE_ROOTED(ct);
// wait for all threads
jl_gc_state_set(ptls, JL_GC_STATE_WAITING, 0);
jl_gc_state_set(ptls, JL_GC_CONCURRENT_COLLECTOR_THREAD, 0);
uv_barrier_wait(targ->barrier);

// free the thread argument here
free(targ);

while (1) {
assert(jl_atomic_load_relaxed(&ptls->gc_state) == JL_GC_CONCURRENT_COLLECTOR_THREAD);
uv_sem_wait(&gc_sweep_assists_needed);
gc_free_pages();
}
Expand Down

0 comments on commit 57ca8df

Please sign in to comment.