From 44b1bcc3b6c6380bc4dbb4eb570d35153f0f3304 Mon Sep 17 00:00:00 2001 From: hongkuan Date: Sun, 1 Sep 2024 10:41:37 +0800 Subject: [PATCH] drain tasks in an unblocked way to avoid data race --- pond.go | 5 +++-- pond_blackbox_test.go | 22 ++++++++++++++++++++++ worker.go | 11 +++++++++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/pond.go b/pond.go index 4635dc6..8d29feb 100644 --- a/pond.go +++ b/pond.go @@ -370,13 +370,14 @@ func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) { // Terminate all workers & purger goroutine p.contextCancel() + // Wait for all workers & purger goroutine to exit + p.workersWaitGroup.Wait() + // close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made) p.tasksCloseOnce.Do(func() { close(p.tasks) }) - // Wait for all workers & purger goroutine to exit - p.workersWaitGroup.Wait() } // purge represents the work done by the purger goroutine diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go index b011f3c..d7c7465 100644 --- a/pond_blackbox_test.go +++ b/pond_blackbox_test.go @@ -406,6 +406,28 @@ func TestPoolWithCustomIdleTimeout(t *testing.T) { pool.StopAndWait() } +func TestStopWithPurging(t *testing.T) { + + pool := pond.New(5, 5, pond.IdleTimeout(100*time.Millisecond)) + + // Submit a task + for i := 0; i < 5; i++ { + pool.Submit(func() { + time.Sleep(10 * time.Millisecond) + }) + } + + assertEqual(t, 5, pool.RunningWorkers()) + + // Purge goroutine is clearing idle workers + time.Sleep(200 * time.Millisecond) + + // Stop the pool to make sure there is no data race with purge goroutine + pool.StopAndWait() + + assertEqual(t, 0, pool.RunningWorkers()) +} + func TestPoolWithCustomPanicHandler(t *testing.T) { var capturedPanic interface{} = nil diff --git a/worker.go b/worker.go index c312bde..c403d15 100644 --- a/worker.go +++ b/worker.go @@ -46,7 +46,14 @@ func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func() // drainPendingTasks discards queued tasks and decrements the corresponding wait group func drainTasks(tasks <-chan func(), tasksWaitGroup *sync.WaitGroup) { - for _ = range tasks { - tasksWaitGroup.Done() + for { + select { + case task, ok := <-tasks: + if task != nil && ok { + tasksWaitGroup.Done() + } + default: + return + } } }