Skip to content

Commit

Permalink
⚙️ add a collector to the workerpool
Browse files Browse the repository at this point in the history
This will help us submit as many requests as we want without knowing
about the workers.

Signed-off-by: Salim Afiune Maya <afiune@mondoo.com>
  • Loading branch information
afiune committed Dec 11, 2024
1 parent 5ae3d75 commit 93ca08d
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 104 deletions.
34 changes: 34 additions & 0 deletions internal/workerpool/collector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

type collector[R any] struct {
resultsCh <-chan R
results []R

errorsCh <-chan error
errors []error

requestsRead int64
}

func (c *collector[R]) Start() {
go func() {
for {
select {
case result := <-c.resultsCh:
c.results = append(c.results, result)

case err := <-c.errorsCh:
c.errors = append(c.errors, err)
}

c.requestsRead++
}
}()
}

func (c *collector[R]) RequestsRead() int64 {
return c.requestsRead
}
112 changes: 79 additions & 33 deletions internal/workerpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,68 +4,114 @@
package workerpool

import (
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

type Task[R any] func() (result R, err error)

// Pool is a generic pool of workers.
type Pool[R any] struct {
queue chan Task[R]
results chan R
errors chan error
workerCount int
requestsSent int
requestsRead int

err error
queueCh chan Task[R]
resultsCh chan R
errorsCh chan error
requestsSent int64

workers []*worker[R]
workerCount int

collector[R]
}

// New initializes a new Pool with the provided number of workers. The pool is generic and can
// accept any type of Task that returns the signature `func() (R, error)`.
//
// For example, a Pool[int] will accept Tasks similar to:
//
// task := func() (int, error) {
// return 42, nil
// }
func New[R any](count int) *Pool[R] {
resultsCh := make(chan R)
errorsCh := make(chan error)
return &Pool[R]{
queue: make(chan Task[R]),
results: make(chan R),
errors: make(chan error),
queueCh: make(chan Task[R]),
resultsCh: resultsCh,
errorsCh: errorsCh,
workerCount: count,
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
}
}

// Start the pool workers and collector. Make sure call `Close()` to clear the pool.
//
// pool := workerpool.New[int](10)
// pool.Start()
// defer pool.Close()
func (p *Pool[R]) Start() {
for i := 0; i < p.workerCount; i++ {
w := worker[R]{id: i, queue: p.queue, results: p.results, errors: p.errors}
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh}
w.Start()
p.workers = append(p.workers, &w)
}

p.errorCollector()
p.collector.Start()
}

func (p *Pool[R]) errorCollector() {
go func() {
for e := range p.errors {
p.err = errors.Join(p.err, e)
}
}()
// Submit sends a task to the workers
func (p *Pool[R]) Submit(t Task[R]) {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}

func (p *Pool[R]) GetError() error {
return p.err
// GetErrors returns any error from a processed task
func (p *Pool[R]) GetErrors() error {
return errors.Join(p.collector.errors...)
}

func (p *Pool[R]) Submit(t Task[R]) {
p.queue <- t
p.requestsSent++
// GetResults returns the tasks results.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
return p.collector.results
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
func (p *Pool[R]) Close() {
p.Wait()
close(p.queueCh)
}

func (p *Pool[R]) GetResult() R {
defer func() {
p.requestsRead++
}()
return <-p.results
// Wait waits until all tasks have been processed.
func (p *Pool[R]) Wait() {
ticker := time.NewTicker(100 * time.Millisecond)
for {
if !p.Processing() {
return
}
<-ticker.C
}
}

func (p *Pool[R]) HasPendingRequests() bool {
return p.requestsSent-p.requestsRead > 0
// PendingRequests returns the number of pending requests.
func (p *Pool[R]) PendingRequests() int64 {
return p.requestsSent - p.collector.RequestsRead()
}

func (p *Pool[R]) Close() {
close(p.queue)
// Processing return true if tasks are being processed.
func (p *Pool[R]) Processing() bool {
if !p.empty() {
return false
}

return p.PendingRequests() != 0
}

func (p *Pool[R]) empty() bool {
return len(p.queueCh) == 0 &&
len(p.resultsCh) == 0 &&
len(p.errorsCh) == 0
}
96 changes: 68 additions & 28 deletions internal/workerpool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"testing"
"time"

"math/rand"

"github.com/stretchr/testify/assert"
"go.mondoo.com/cnquery/v11/internal/workerpool"
)
Expand All @@ -21,24 +23,23 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) {
return 42, nil
}

// no requests
assert.False(t, pool.HasPendingRequests())
// no results
assert.Empty(t, pool.GetResults())

// submit a request
pool.Submit(task)

// should have pending requests
assert.True(t, pool.HasPendingRequests())

// assert results comes back
result := pool.GetResult()
assert.Equal(t, 42, result)
// wait for the request to process
pool.Wait()

// no more requests pending
assert.False(t, pool.HasPendingRequests())
// should have one result
results := pool.GetResults()
if assert.Len(t, results, 1) {
assert.Equal(t, 42, results[0])
}

// no errors
assert.Nil(t, pool.GetError())
assert.Nil(t, pool.GetErrors())
}

func TestPoolHandleErrors(t *testing.T) {
Expand All @@ -53,9 +54,9 @@ func TestPoolHandleErrors(t *testing.T) {
pool.Submit(task)

// Wait for error collector to process
time.Sleep(100 * time.Millisecond)
pool.Wait()

err := pool.GetError()
err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
}
Expand All @@ -82,14 +83,15 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) {
pool.Submit(task)
}

var results []*test
for range tasks {
results = append(results, pool.GetResult())
}

assert.ElementsMatch(t, []*test{nil, &test{1}, &test{2}, &test{3}}, results)
assert.False(t, pool.HasPendingRequests())
// Wait for error collector to process
pool.Wait()

results := pool.GetResults()
assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results)
err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
}
}

func TestPoolHandlesNilTasks(t *testing.T) {
Expand All @@ -100,14 +102,13 @@ func TestPoolHandlesNilTasks(t *testing.T) {
var nilTask workerpool.Task[int]
pool.Submit(nilTask)

// Wait for worker to process the nil task
time.Sleep(100 * time.Millisecond)
pool.Wait()

err := pool.GetError()
err := pool.GetErrors()
assert.NoError(t, err)
}

func TestPoolHasPendingRequests(t *testing.T) {
func TestPoolProcessing(t *testing.T) {
pool := workerpool.New[int](2)
pool.Start()
defer pool.Close()
Expand All @@ -118,11 +119,19 @@ func TestPoolHasPendingRequests(t *testing.T) {
}

pool.Submit(task)
assert.True(t, pool.HasPendingRequests())

result := pool.GetResult()
assert.Equal(t, 10, result)
assert.False(t, pool.HasPendingRequests())
// should be processing
assert.True(t, pool.Processing())

// wait
pool.Wait()

// read results
result := pool.GetResults()
assert.Equal(t, []int{10}, result)

// should not longer be processing
assert.False(t, pool.Processing())
}

func TestPoolClosesGracefully(t *testing.T) {
Expand All @@ -143,3 +152,34 @@ func TestPoolClosesGracefully(t *testing.T) {
pool.Submit(task)
})
}

func TestPoolWithManyTasks(t *testing.T) {
// 30k requests with a pool of 100 workers
// should be around 15 seconds
requestCount := 30000
pool := workerpool.New[int](100)
pool.Start()
defer pool.Close()

task := func() (int, error) {
random := rand.Intn(100)
time.Sleep(time.Duration(random) * time.Millisecond)
return random, nil
}

for i := 0; i < requestCount; i++ {
pool.Submit(task)
}

// should be processing
assert.True(t, pool.Processing())

// wait
pool.Wait()

// read results
assert.Equal(t, requestCount, len(pool.GetResults()))

// should not longer be processing
assert.False(t, pool.Processing())
}
18 changes: 10 additions & 8 deletions internal/workerpool/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,27 @@
package workerpool

type worker[R any] struct {
id int
queue <-chan Task[R]
results chan<- R
errors chan<- error
id int
queueCh <-chan Task[R]
resultsCh chan<- R
errorsCh chan<- error
}

func (w *worker[R]) Start() {
go func() {
for task := range w.queue {
for task := range w.queueCh {
if task == nil {
// let the collector know we processed the request
w.errorsCh <- nil
continue
}

data, err := task()
if err != nil {
w.errors <- err
w.errorsCh <- err
} else {
w.resultsCh <- data
}

w.results <- data
}
}()
}
Loading

0 comments on commit 93ca08d

Please sign in to comment.