⚡ fetch org repositories in parallel #4970

Merged: 4 commits, Dec 12, 2024
Changes from 3 commits
34 changes: 34 additions & 0 deletions internal/workerpool/collector.go
@@ -0,0 +1,34 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

type collector[R any] struct {
resultsCh <-chan R
results []R

errorsCh <-chan error
errors []error

requestsRead int64
}

func (c *collector[R]) Start() {
go func() {
for {
select {
case result := <-c.resultsCh:
c.results = append(c.results, result)

case err := <-c.errorsCh:
c.errors = append(c.errors, err)
}

c.requestsRead++
}
}()
}

func (c *collector[R]) RequestsRead() int64 {
return c.requestsRead
Contributor:
i think there might be a data race here on requestsRead

Contributor Author:
Fixed! Thanks for pointing it out, Jay.
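
The fix referenced above landed in a later commit and is not shown in this diff (the view covers changes from 3 of the 4 commits). A minimal sketch of what an atomic counter for the collector could look like, assuming the fix uses sync/atomic rather than a mutex:

// collector.go (sketch, assuming sync/atomic is used for the fix)
package workerpool

import "sync/atomic"

type collector[R any] struct {
	resultsCh <-chan R
	results   []R

	errorsCh <-chan error
	errors   []error

	requestsRead int64
}

func (c *collector[R]) Start() {
	go func() {
		for {
			select {
			case result := <-c.resultsCh:
				c.results = append(c.results, result)
			case err := <-c.errorsCh:
				c.errors = append(c.errors, err)
			}
			// increment atomically so concurrent readers see a consistent value
			atomic.AddInt64(&c.requestsRead, 1)
		}
	}()
}

func (c *collector[R]) RequestsRead() int64 {
	return atomic.LoadInt64(&c.requestsRead)
}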

}
117 changes: 117 additions & 0 deletions internal/workerpool/pool.go
@@ -0,0 +1,117 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

import (
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

type Task[R any] func() (result R, err error)

// Pool is a generic pool of workers.
type Pool[R any] struct {
queueCh chan Task[R]
resultsCh chan R
errorsCh chan error
requestsSent int64

workers []*worker[R]
workerCount int

collector[R]
}

// New initializes a new Pool with the provided number of workers. The pool is generic and can
// accept any type of Task that returns the signature `func() (R, error)`.
//
// For example, a Pool[int] will accept Tasks similar to:
//
// task := func() (int, error) {
// return 42, nil
// }
func New[R any](count int) *Pool[R] {
resultsCh := make(chan R)
errorsCh := make(chan error)
return &Pool[R]{
queueCh: make(chan Task[R]),
resultsCh: resultsCh,
errorsCh: errorsCh,
workerCount: count,
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
}
}

// Start the pool workers and collector. Make sure to call `Close()` to clean up the pool.
//
// pool := workerpool.New[int](10)
// pool.Start()
// defer pool.Close()
func (p *Pool[R]) Start() {
for i := 0; i < p.workerCount; i++ {
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh}
w.Start()
p.workers = append(p.workers, &w)
}

p.collector.Start()
}

// Submit sends a task to the workers.
func (p *Pool[R]) Submit(t Task[R]) {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}

// GetErrors returns the errors from all processed tasks, joined into a single error (nil if there were none).
func (p *Pool[R]) GetErrors() error {
return errors.Join(p.collector.errors...)
}

// GetResults returns the results of the processed tasks.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
return p.collector.results
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
func (p *Pool[R]) Close() {
p.Wait()
close(p.queueCh)
}

// Wait waits until all tasks have been processed.
func (p *Pool[R]) Wait() {
ticker := time.NewTicker(100 * time.Millisecond)
for {
if !p.Processing() {
return
}
<-ticker.C
}
}

// PendingRequests returns the number of pending requests.
func (p *Pool[R]) PendingRequests() int64 {
return p.requestsSent - p.collector.RequestsRead()
Contributor:
i'm not really familiar with the atomics api in go, but i would expect some sort of synchronized read on requestsSent

Contributor Author:
atomic helps to do thread safe operations, yes, I missed the read calls, they are now there! 🙌🏽
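
The read calls mentioned in the reply also land in a later commit, so they do not appear in this diff. A minimal sketch of a fully synchronized read, assuming both counters are accessed through sync/atomic (with RequestsRead doing its own atomic load, as in the collector sketch above):

// pool.go (sketch): load requestsSent atomically before subtracting
func (p *Pool[R]) PendingRequests() int64 {
	sent := atomic.LoadInt64(&p.requestsSent)
	return sent - p.collector.RequestsRead()
}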

}

// Processing returns true if tasks are still being processed.
func (p *Pool[R]) Processing() bool {
if !p.empty() {
return false
}

return p.PendingRequests() != 0
}

func (p *Pool[R]) empty() bool {
return len(p.queueCh) == 0 &&
len(p.resultsCh) == 0 &&
len(p.errorsCh) == 0
}
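
Putting the API together, here is a short usage sketch of the call sequence (illustrative only; the package is internal to cnquery, so real callers live inside the module, and the task below is a made-up example):

// usage sketch for Pool: New -> Start -> Submit -> Wait -> read results/errors
package main

import (
	"fmt"

	"go.mondoo.com/cnquery/v11/internal/workerpool"
)

func main() {
	pool := workerpool.New[int](10)
	pool.Start()
	defer pool.Close()

	// submit tasks matching the Task[int] signature `func() (int, error)`
	for i := 0; i < 100; i++ {
		i := i // capture the loop variable for the closure
		pool.Submit(func() (int, error) {
			return i * i, nil
		})
	}

	// block until every submitted task has been collected
	pool.Wait()

	if err := pool.GetErrors(); err != nil {
		fmt.Println("some tasks failed:", err)
	}
	fmt.Println("collected results:", len(pool.GetResults()))
}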
185 changes: 185 additions & 0 deletions internal/workerpool/pool_test.go
@@ -0,0 +1,185 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool_test

import (
"errors"
"testing"
"time"

"math/rand"

"github.com/stretchr/testify/assert"
"go.mondoo.com/cnquery/v11/internal/workerpool"
)

func TestPoolSubmitAndRetrieveResult(t *testing.T) {
pool := workerpool.New[int](2)
pool.Start()
defer pool.Close()

task := func() (int, error) {
return 42, nil
}

// no results
assert.Empty(t, pool.GetResults())

// submit a request
pool.Submit(task)

// wait for the request to process
pool.Wait()

// should have one result
results := pool.GetResults()
if assert.Len(t, results, 1) {
assert.Equal(t, 42, results[0])
}

// no errors
assert.Nil(t, pool.GetErrors())
}

func TestPoolHandleErrors(t *testing.T) {
pool := workerpool.New[int](5)
pool.Start()
defer pool.Close()

// submit a task that will return an error
task := func() (int, error) {
return 0, errors.New("task error")
}
pool.Submit(task)

// Wait for error collector to process
pool.Wait()

err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
}
}

func TestPoolMultipleTasksWithErrors(t *testing.T) {
type test struct {
data int
}
pool := workerpool.New[*test](5)
pool.Start()
defer pool.Close()

tasks := []workerpool.Task[*test]{
func() (*test, error) { return &test{1}, nil },
func() (*test, error) { return &test{2}, nil },
func() (*test, error) {
return nil, errors.New("task error")
},
func() (*test, error) { return &test{3}, nil },
}

for _, task := range tasks {
pool.Submit(task)
}

// Wait for error collector to process
pool.Wait()

results := pool.GetResults()
assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results)
err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
}
}

func TestPoolHandlesNilTasks(t *testing.T) {
pool := workerpool.New[int](2)
pool.Start()
defer pool.Close()

var nilTask workerpool.Task[int]
pool.Submit(nilTask)

pool.Wait()

err := pool.GetErrors()
assert.NoError(t, err)
}

func TestPoolProcessing(t *testing.T) {
pool := workerpool.New[int](2)
pool.Start()
defer pool.Close()

task := func() (int, error) {
time.Sleep(50 * time.Millisecond)
return 10, nil
}

pool.Submit(task)

// should be processing
assert.True(t, pool.Processing())

// wait
pool.Wait()

// read results
result := pool.GetResults()
assert.Equal(t, []int{10}, result)

// should no longer be processing
assert.False(t, pool.Processing())
}

func TestPoolClosesGracefully(t *testing.T) {
pool := workerpool.New[int](1)
pool.Start()

task := func() (int, error) {
time.Sleep(100 * time.Millisecond)
return 42, nil
}

pool.Submit(task)

pool.Close()

// Ensure no panic occurs and channels are closed
assert.PanicsWithError(t, "send on closed channel", func() {
pool.Submit(task)
})
}

func TestPoolWithManyTasks(t *testing.T) {
// 30k requests against a pool of 100 workers; each task sleeps ~50ms on average,
// so the run should take roughly 30000 * 50ms / 100 = 15 seconds
requestCount := 30000
pool := workerpool.New[int](100)
pool.Start()
defer pool.Close()

task := func() (int, error) {
random := rand.Intn(100)
time.Sleep(time.Duration(random) * time.Millisecond)
return random, nil
}

for i := 0; i < requestCount; i++ {
pool.Submit(task)
}

// should be processing
assert.True(t, pool.Processing())

// wait
pool.Wait()

// read results
assert.Equal(t, requestCount, len(pool.GetResults()))

// should no longer be processing
assert.False(t, pool.Processing())
}
30 changes: 30 additions & 0 deletions internal/workerpool/worker.go
@@ -0,0 +1,30 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package workerpool

type worker[R any] struct {
id int
queueCh <-chan Task[R]
resultsCh chan<- R
errorsCh chan<- error
}

func (w *worker[R]) Start() {
go func() {
for task := range w.queueCh {
if task == nil {
// forward a nil error so the collector still counts the request; errors.Join drops nil errors, so GetErrors stays nil
w.errorsCh <- nil
continue
}

data, err := task()
if err != nil {
w.errorsCh <- err
} else {
w.resultsCh <- data
}
}
}()
}
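
The PR title says these workers are used to fetch org repositories in parallel. A hedged sketch of that pattern, where fetchRepoPage and Repo are hypothetical placeholders for the provider's real types and API calls (not part of this diff):

// hypothetical application: fetch pages of org repositories concurrently
type Repo struct {
	Name string
}

// fetchRepoPage stands in for the real API call that returns one page of repositories.
func fetchRepoPage(org string, page int) ([]Repo, error) {
	// ... call the hosting provider's API here ...
	return nil, nil
}

func fetchAllRepos(org string, pages int) ([][]Repo, error) {
	pool := workerpool.New[[]Repo](10)
	pool.Start()
	defer pool.Close()

	for page := 1; page <= pages; page++ {
		page := page // capture the loop variable for the closure
		pool.Submit(func() ([]Repo, error) {
			return fetchRepoPage(org, page)
		})
	}

	pool.Wait()
	return pool.GetResults(), pool.GetErrors()
}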
4 changes: 2 additions & 2 deletions providers-sdk/v1/inventory/inventory.pb.go

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions providers-sdk/v1/plugin/plugin.pb.go

Some generated files are not rendered by default.
