Skip to content

Commit

Permalink
ReusableGoroutinesPool: Fix datarace on Close (#607)
Browse files Browse the repository at this point in the history
  • Loading branch information
julienduchesne authored Oct 11, 2024
1 parent 8e7752e commit cf9b0bc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
26 changes: 23 additions & 3 deletions concurrency/worker.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package concurrency

import (
"sync"
)

// NewReusableGoroutinesPool creates a new worker pool with the given size.
// These workers will run the workloads passed through Go() calls.
// If all workers are busy, Go() will spawn a new goroutine to run the workload.
Expand All @@ -18,12 +22,23 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool {
}

type ReusableGoroutinesPool struct {
jobs chan func()
jobsMu sync.RWMutex
closed bool
jobs chan func()
}

// Go will run the given function in a worker of the pool.
// If all workers are busy, Go() will spawn a new goroutine to run the workload.
func (p *ReusableGoroutinesPool) Go(f func()) {
p.jobsMu.RLock()
defer p.jobsMu.RUnlock()

// If the pool is closed, run the function in a new goroutine.
if p.closed {
go f()
return
}

select {
case p.jobs <- f:
default:
Expand All @@ -32,7 +47,12 @@ func (p *ReusableGoroutinesPool) Go(f func()) {
}

// Close stops the workers of the pool.
// No new Do() calls should be performed after calling Close().
// No new Go() calls should be performed after calling Close().
// Close does NOT wait for all jobs to finish, it is the caller's responsibility to ensure that in the provided workloads.
// Close is intended to be used in tests to ensure that no goroutines are leaked.
func (p *ReusableGoroutinesPool) Close() { close(p.jobs) }
func (p *ReusableGoroutinesPool) Close() {
p.jobsMu.Lock()
defer p.jobsMu.Unlock()
p.closed = true
close(p.jobs)
}
28 changes: 28 additions & 0 deletions concurrency/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

func TestReusableGoroutinesPool(t *testing.T) {
Expand Down Expand Up @@ -59,3 +61,29 @@ func TestReusableGoroutinesPool(t *testing.T) {
}
t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines())
}

// TestReusableGoroutinesPool_Race tests that Close() and Go() can be called concurrently.
func TestReusableGoroutinesPool_Race(t *testing.T) {
w := NewReusableGoroutinesPool(2)

var runCountAtomic atomic.Int32
const maxMsgCount = 10

var testWG sync.WaitGroup
testWG.Add(1)
go func() {
defer testWG.Done()
for i := 0; i < maxMsgCount; i++ {
w.Go(func() {
runCountAtomic.Add(1)
})
time.Sleep(10 * time.Millisecond)
}
}()
time.Sleep(10 * time.Millisecond)
w.Close() // close the pool
testWG.Wait() // wait for the test to finish

runCt := int(runCountAtomic.Load())
require.Equal(t, runCt, 10, "expected all functions to run")
}

0 comments on commit cf9b0bc

Please sign in to comment.