diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go new file mode 100644 index 0000000000..4c5257afda --- /dev/null +++ b/internal/workerpool/collector.go @@ -0,0 +1,34 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package workerpool + +type collector[R any] struct { + resultsCh <-chan R + results []R + + errorsCh <-chan error + errors []error + + requestsRead int64 +} + +func (c *collector[R]) Start() { + go func() { + for { + select { + case result := <-c.resultsCh: + c.results = append(c.results, result) + + case err := <-c.errorsCh: + c.errors = append(c.errors, err) + } + + c.requestsRead++ + } + }() +} + +func (c *collector[R]) RequestsRead() int64 { + return c.requestsRead +} diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index 3bf8d71bfc..d407543cff 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -4,68 +4,114 @@ package workerpool import ( + "sync/atomic" + "time" + "github.com/cockroachdb/errors" ) type Task[R any] func() (result R, err error) +// Pool is a generic pool of workers. type Pool[R any] struct { - queue chan Task[R] - results chan R - errors chan error - workerCount int - requestsSent int - requestsRead int - - err error + queueCh chan Task[R] + resultsCh chan R + errorsCh chan error + requestsSent int64 + + workers []*worker[R] + workerCount int + + collector[R] } +// New initializes a new Pool with the provided number of workers. The pool is generic and can +// accept any type of Task that returns the signature `func() (R, error)`. +// +// For example, a Pool[int] will accept Tasks similar to: +// +// task := func() (int, error) { +// return 42, nil +// } func New[R any](count int) *Pool[R] { + resultsCh := make(chan R) + errorsCh := make(chan error) return &Pool[R]{ - queue: make(chan Task[R]), - results: make(chan R), - errors: make(chan error), + queueCh: make(chan Task[R]), + resultsCh: resultsCh, + errorsCh: errorsCh, workerCount: count, + collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh}, } } +// Start the pool workers and collector. Make sure call `Close()` to clear the pool. +// +// pool := workerpool.New[int](10) +// pool.Start() +// defer pool.Close() func (p *Pool[R]) Start() { for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queue: p.queue, results: p.results, errors: p.errors} + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} w.Start() + p.workers = append(p.workers, &w) } - p.errorCollector() + p.collector.Start() } -func (p *Pool[R]) errorCollector() { - go func() { - for e := range p.errors { - p.err = errors.Join(p.err, e) - } - }() +// Submit sends a task to the workers +func (p *Pool[R]) Submit(t Task[R]) { + p.queueCh <- t + atomic.AddInt64(&p.requestsSent, 1) } -func (p *Pool[R]) GetError() error { - return p.err +// GetErrors returns any error from a processed task +func (p *Pool[R]) GetErrors() error { + return errors.Join(p.collector.errors...) } -func (p *Pool[R]) Submit(t Task[R]) { - p.queue <- t - p.requestsSent++ +// GetResults returns the tasks results. +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GetResults() []R { + return p.collector.results +} + +// Close waits for workers and collector to process all the requests, and then closes +// the task queue channel. After closing the pool, calling `Submit()` will panic. +func (p *Pool[R]) Close() { + p.Wait() + close(p.queueCh) } -func (p *Pool[R]) GetResult() R { - defer func() { - p.requestsRead++ - }() - return <-p.results +// Wait waits until all tasks have been processed. +func (p *Pool[R]) Wait() { + ticker := time.NewTicker(100 * time.Millisecond) + for { + if !p.Processing() { + return + } + <-ticker.C + } } -func (p *Pool[R]) HasPendingRequests() bool { - return p.requestsSent-p.requestsRead > 0 +// PendingRequests returns the number of pending requests. +func (p *Pool[R]) PendingRequests() int64 { + return p.requestsSent - p.collector.RequestsRead() } -func (p *Pool[R]) Close() { - close(p.queue) +// Processing return true if tasks are being processed. +func (p *Pool[R]) Processing() bool { + if !p.empty() { + return false + } + + return p.PendingRequests() != 0 +} + +func (p *Pool[R]) empty() bool { + return len(p.queueCh) == 0 && + len(p.resultsCh) == 0 && + len(p.errorsCh) == 0 } diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go index 6337ca28c0..3b3946df1e 100644 --- a/internal/workerpool/pool_test.go +++ b/internal/workerpool/pool_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "math/rand" + "github.com/stretchr/testify/assert" "go.mondoo.com/cnquery/v11/internal/workerpool" ) @@ -21,24 +23,23 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) { return 42, nil } - // no requests - assert.False(t, pool.HasPendingRequests()) + // no results + assert.Empty(t, pool.GetResults()) // submit a request pool.Submit(task) - // should have pending requests - assert.True(t, pool.HasPendingRequests()) - - // assert results comes back - result := pool.GetResult() - assert.Equal(t, 42, result) + // wait for the request to process + pool.Wait() - // no more requests pending - assert.False(t, pool.HasPendingRequests()) + // should have one result + results := pool.GetResults() + if assert.Len(t, results, 1) { + assert.Equal(t, 42, results[0]) + } // no errors - assert.Nil(t, pool.GetError()) + assert.Nil(t, pool.GetErrors()) } func TestPoolHandleErrors(t *testing.T) { @@ -53,9 +54,9 @@ func TestPoolHandleErrors(t *testing.T) { pool.Submit(task) // Wait for error collector to process - time.Sleep(100 * time.Millisecond) + pool.Wait() - err := pool.GetError() + err := pool.GetErrors() if assert.Error(t, err) { assert.Contains(t, err.Error(), "task error") } @@ -82,14 +83,15 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) { pool.Submit(task) } - var results []*test - for range tasks { - results = append(results, pool.GetResult()) - } - - assert.ElementsMatch(t, []*test{nil, &test{1}, &test{2}, &test{3}}, results) - assert.False(t, pool.HasPendingRequests()) + // Wait for error collector to process + pool.Wait() + results := pool.GetResults() + assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results) + err := pool.GetErrors() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "task error") + } } func TestPoolHandlesNilTasks(t *testing.T) { @@ -100,14 +102,13 @@ func TestPoolHandlesNilTasks(t *testing.T) { var nilTask workerpool.Task[int] pool.Submit(nilTask) - // Wait for worker to process the nil task - time.Sleep(100 * time.Millisecond) + pool.Wait() - err := pool.GetError() + err := pool.GetErrors() assert.NoError(t, err) } -func TestPoolHasPendingRequests(t *testing.T) { +func TestPoolProcessing(t *testing.T) { pool := workerpool.New[int](2) pool.Start() defer pool.Close() @@ -118,11 +119,19 @@ func TestPoolHasPendingRequests(t *testing.T) { } pool.Submit(task) - assert.True(t, pool.HasPendingRequests()) - result := pool.GetResult() - assert.Equal(t, 10, result) - assert.False(t, pool.HasPendingRequests()) + // should be processing + assert.True(t, pool.Processing()) + + // wait + pool.Wait() + + // read results + result := pool.GetResults() + assert.Equal(t, []int{10}, result) + + // should not longer be processing + assert.False(t, pool.Processing()) } func TestPoolClosesGracefully(t *testing.T) { @@ -143,3 +152,34 @@ func TestPoolClosesGracefully(t *testing.T) { pool.Submit(task) }) } + +func TestPoolWithManyTasks(t *testing.T) { + // 30k requests with a pool of 100 workers + // should be around 15 seconds + requestCount := 30000 + pool := workerpool.New[int](100) + pool.Start() + defer pool.Close() + + task := func() (int, error) { + random := rand.Intn(100) + time.Sleep(time.Duration(random) * time.Millisecond) + return random, nil + } + + for i := 0; i < requestCount; i++ { + pool.Submit(task) + } + + // should be processing + assert.True(t, pool.Processing()) + + // wait + pool.Wait() + + // read results + assert.Equal(t, requestCount, len(pool.GetResults())) + + // should not longer be processing + assert.False(t, pool.Processing()) +} diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 4a391d44b6..19b21de1e7 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -4,25 +4,27 @@ package workerpool type worker[R any] struct { - id int - queue <-chan Task[R] - results chan<- R - errors chan<- error + id int + queueCh <-chan Task[R] + resultsCh chan<- R + errorsCh chan<- error } func (w *worker[R]) Start() { go func() { - for task := range w.queue { + for task := range w.queueCh { if task == nil { + // let the collector know we processed the request + w.errorsCh <- nil continue } data, err := task() if err != nil { - w.errors <- err + w.errorsCh <- err + } else { + w.resultsCh <- data } - - w.results <- data } }() } diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go index ad783c57c8..ef39f97159 100644 --- a/providers/github/resources/github_org.go +++ b/providers/github/resources/github_org.go @@ -5,6 +5,7 @@ package resources import ( "errors" + "slices" "strconv" "strings" "time" @@ -284,47 +285,30 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { Str("organization", g.Name.Data). Msg("list repositories") - var allRepos []*github.Repository for { - // exit as soon as we collect all repositories - if len(allRepos) >= int(repoCount) { + reposLen := len(slices.Concat(workerPool.GetResults()...)) + if reposLen >= int(repoCount) { break } - // send as many request as workers we have - for i := 1; i <= workers; i++ { - opts := listOpts - workerPool.Submit(func() ([]*github.Repository, error) { - repos, _, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, &opts) - return repos, err - }) - - // check if we need to submit more requests - newRepoCount := len(allRepos) + i*paginationPerPage - if newRepoCount > int(repoCount) { - break - } + // send requests to workers + opts := listOpts + workerPool.Submit(func() ([]*github.Repository, error) { + repos, _, err := conn.Client().Repositories.ListByOrg(conn.Context(), orgLogin, &opts) + return repos, err + }) - // next page - listOpts.Page++ - } - - // wait for the results - for i := 0; i < workers; i++ { - if workerPool.HasPendingRequests() { - allRepos = append(allRepos, workerPool.GetResult()...) - } - } + // next page + listOpts.Page++ // check if any request failed - if err := workerPool.GetError(); err != nil { + if err := workerPool.GetErrors(); err != nil { if strings.Contains(err.Error(), "404") { return nil, nil } return nil, err } - } if g.repoCacheMap == nil { @@ -332,15 +316,17 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { } res := []interface{}{} - for i := range allRepos { - repo := allRepos[i] + for _, repos := range workerPool.GetResults() { + for i := range repos { + repo := repos[i] - r, err := newMqlGithubRepository(g.MqlRuntime, repo) - if err != nil { - return nil, err + r, err := newMqlGithubRepository(g.MqlRuntime, repo) + if err != nil { + return nil, err + } + res = append(res, r) + g.repoCacheMap[repo.GetName()] = r } - res = append(res, r) - g.repoCacheMap[repo.GetName()] = r } return res, nil