Skip to content

Commit

Permalink
fix: simplify concurrent queue using one less mutex
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Stewart <christian@aperture.us>
  • Loading branch information
paralin committed Jul 24, 2024
1 parent 599a8f3 commit 6cc51d1
Showing 1 changed file with 46 additions and 46 deletions.
92 changes: 46 additions & 46 deletions conc/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package conc

import (
"context"
"sync"

"github.com/aperturerobotics/util/broadcast"
"github.com/aperturerobotics/util/linkedlist"
Expand All @@ -11,9 +10,7 @@ import (
// ConcurrentQueue is a pool of goroutines processing a stream of jobs.
// Job callbacks are called in the order they are added.
type ConcurrentQueue struct {
// mtx guards below fields
mtx sync.Mutex
// bcast is broadcasted when fields change
// bcast guards below fields
bcast broadcast.Broadcast
// maxConcurrency is the concurrency limit or 0 if none
maxConcurrency int
Expand All @@ -35,9 +32,9 @@ func NewConcurrentQueue(maxConcurrency int, initialElems ...func()) *ConcurrentQ
maxConcurrency: maxConcurrency,
}
if len(initialElems) != 0 {
str.mtx.Lock()
str.updateLocked()
str.mtx.Unlock()
str.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
str.updateLocked(broadcast)
})
}
return str
}
Expand All @@ -46,39 +43,39 @@ func NewConcurrentQueue(maxConcurrency int, initialElems ...func()) *ConcurrentQ
// If possible, the job is started immediately and skips the queue.
// Returns the current number of queued and running jobs.
func (s *ConcurrentQueue) Enqueue(jobs ...func()) (queued, running int) {
s.mtx.Lock()
defer s.mtx.Unlock()

if len(jobs) == 0 {
return s.jobQueueSize, s.running
}

for _, job := range jobs {
if s.maxConcurrency <= 0 || s.running < s.maxConcurrency {
s.running++
go s.executeJob(job)
} else {
s.jobQueueSize++
s.jobQueue.Push(job)
s.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
if len(jobs) != 0 {
for _, job := range jobs {
if s.maxConcurrency <= 0 || s.running < s.maxConcurrency {
s.running++
go s.executeJob(job)
} else {
s.jobQueueSize++
s.jobQueue.Push(job)
}
}
broadcast()
}
}

s.bcast.Broadcast()
return s.jobQueueSize, s.running
queued, running = s.jobQueueSize, s.running
})

return queued, running
}

// WaitIdle waits for no jobs to be running.
// Returns context.Canceled if ctx is canceled.
// errCh is an optional error channel.
func (s *ConcurrentQueue) WaitIdle(ctx context.Context, errCh <-chan error) error {
var wait <-chan struct{}
for {
s.mtx.Lock()
idle := s.running == 0 && s.jobQueueSize == 0
if !idle {
wait = s.bcast.GetWaitCh()
}
s.mtx.Unlock()
var idle bool
var wait <-chan struct{}
s.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
idle = s.running == 0 && s.jobQueueSize == 0
if !idle {
wait = getWaitCh()
}
})
if idle {
return nil
}
Expand Down Expand Up @@ -114,10 +111,12 @@ func (s *ConcurrentQueue) WatchState(
}

for {
s.mtx.Lock()
queued, running := s.jobQueueSize, s.running
waitCh := s.bcast.GetWaitCh()
s.mtx.Unlock()
var queued, running int
var waitCh <-chan struct{}
s.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
queued, running = s.jobQueueSize, s.running
waitCh = getWaitCh()
})

cntu, err := cb(queued, running)
if err != nil || !cntu {
Expand All @@ -134,7 +133,7 @@ func (s *ConcurrentQueue) WatchState(

// updateLocked checks if we need to spawn any new routines.
// caller must hold mtx
func (s *ConcurrentQueue) updateLocked() {
func (s *ConcurrentQueue) updateLocked(broadcast func()) {
var dirty bool
for s.maxConcurrency <= 0 || s.running < s.maxConcurrency {
job, jobOk := s.jobQueue.Pop()
Expand All @@ -147,7 +146,7 @@ func (s *ConcurrentQueue) updateLocked() {
go s.executeJob(job)
}
if dirty {
s.bcast.Broadcast()
broadcast()
}
}

Expand All @@ -158,16 +157,17 @@ func (s *ConcurrentQueue) executeJob(job func()) {
if job != nil {
job()
}
s.mtx.Lock()

var jobOk bool
job, jobOk = s.jobQueue.Pop()
if !jobOk {
s.running--
s.bcast.Broadcast()
} else {
s.jobQueueSize--
}
s.mtx.Unlock()
s.bcast.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
job, jobOk = s.jobQueue.Pop()
if !jobOk {
s.running--
broadcast()
} else {
s.jobQueueSize--
}
})
if !jobOk {
return
}
Expand Down

0 comments on commit 6cc51d1

Please sign in to comment.