diff --git a/perf/mselect/bucket.go b/perf/mselect/bucket.go index 8efb821..eb66f25 100644 --- a/perf/mselect/bucket.go +++ b/perf/mselect/bucket.go @@ -26,7 +26,8 @@ type taskBucket struct { cases []linkname.RuntimeSelect tasks []*Task - block bool + block bool + stopped bool } func newTaskBucket(msel *manySelect, userTask *Task) *taskBucket { @@ -40,10 +41,10 @@ func newTaskBucket(msel *manySelect, userTask *Task) *taskBucket { sigTask := msel.sigTask stopTask := NewTask(b.m.stop, nil, nil) delTask := NewTask(b.delCh, nil, nil) - b.addTask(sigTask, false) - b.addTask(stopTask, false) - b.addTask(delTask, false) - b.addTask(userTask, true) + b.addTask(sigTask) + b.addTask(stopTask) + b.addTask(delTask) + b.addTask(userTask) go b.loop() return b } @@ -62,40 +63,31 @@ func (b *taskBucket) loop() { // Got a signal or a new task submitted. switch i { case 0: // signal - b.processSignal(recv, ok) - if !ok { - return // stopped + if !ok { // stopped + b.stop() + return } + b.processSignal(recv) case 1: // stopped b.stop() return case 2: // delete task delTask := *(**Task)(recv) - if delTask.bIdx < 0 || delTask.tIdx < sigTaskNum { - panic("mselect: invalid task to delete") - } - // The task's channel may already been closed, - // then the task would be automatically removed. - // We check tIdx to make sure the task is still in the task list. - if delTask.tIdx < len(b.tasks) { - b.removeTask(delTask.tIdx) - b.m.decrCount(1) - } + b.deleteTask(delTask) default: - b.processTask(i, task, recv, ok) + b.executeTask(task, recv, ok) } } } -func (b *taskBucket) processSignal(recv unsafe.Pointer, ok bool) { - if !ok { // stopped - b.stop() +func (b *taskBucket) processSignal(recv unsafe.Pointer) { + if b.stopped { return } // Add a new task. newTask := *(**Task)(recv) - b.addTask(newTask, true) + b.addTask(newTask) // If the bucket is full, block the signal channel to avoid // accepting new tasks. @@ -106,7 +98,7 @@ func (b *taskBucket) processSignal(recv unsafe.Pointer, ok bool) { } } -func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool) { +func (b *taskBucket) executeTask(task *Task, recv unsafe.Pointer, ok bool) { // Execute the task first. // When the channel is closed, call callback functions with // a zero value and ok = false. @@ -116,8 +108,7 @@ func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool // Delete task if the channel was closed. if !ok { - b.removeTask(i) - b.m.decrCount(1) + b.deleteTask(task) // Reset signal task to accept new tasks. if b.block && len(b.cases) < bucketSize { @@ -130,15 +121,45 @@ func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool } } -func (b *taskBucket) addTask(task *Task, setIndex bool) { - if setIndex { - task.bIdx = b.idx - task.tIdx = len(b.tasks) +func (b *taskBucket) addTask(task *Task) { + task.mu.Lock() + defer task.mu.Unlock() + + if task.deleted { + return + } + if task.added { + panic("mselect: task already added") } + + // The fields added, bucket, tIdx are used to coordinate with deletion, + // don't write these fields for signal tasks. + if idx := len(b.tasks); idx >= sigTaskNum { + task.added = true + task.bucket = b + task.tIdx = idx + } + b.tasks = append(b.tasks, task) b.cases = append(b.cases, task.newRuntimeSelect()) } +func (b *taskBucket) deleteTask(task *Task) { + task.mu.Lock() + defer task.mu.Unlock() + + if !task.added { + return + } + if task.bucket == nil || task.tIdx < sigTaskNum { + panic("mselect: invalid task to delete") + } + + task.deleted = true + b.removeTask(task.tIdx) + b.m.decrCount(1) +} + func (b *taskBucket) removeTask(i int) { n := len(b.cases) if n > sigTaskNum+1 && i < n-1 { @@ -153,10 +174,15 @@ func (b *taskBucket) removeTask(i int) { } func (b *taskBucket) stop() { + if b.stopped { + return + } + n := len(b.tasks) - sigTaskNum // don't count the signal tasks b.m.decrCount(n) b.cases = nil b.tasks = nil + b.stopped = true // Drain tasks to un-block any goroutines blocked by sending tasks // to b.m.tasks. diff --git a/perf/mselect/mselect.go b/perf/mselect/mselect.go index 807e6a1..9cfac3e 100644 --- a/perf/mselect/mselect.go +++ b/perf/mselect/mselect.go @@ -61,9 +61,6 @@ func (p *manySelect) Add(task *Task) { if atomic.LoadInt32(&p.stopped) > 0 { return } - if !atomic.CompareAndSwapInt32(&task.added, 0, 1) { - panic("mselect: adding task more than once") - } p.mu.Lock() if atomic.AddInt32(&p.count, 1) < p.cap() { @@ -78,15 +75,7 @@ func (p *manySelect) Add(task *Task) { } func (p *manySelect) Delete(task *Task) { - if atomic.LoadInt32(&task.added) == 0 { - panic("mselect: the task is not added") - } - if !atomic.CompareAndSwapInt32(&task.deleted, 0, 1) { - panic("mselect: deleting task more than once") - } - p.mu.Lock() - p.buckets[task.bIdx].signalDelete(task) - p.mu.Unlock() + task.signalDelete() } func (p *manySelect) Count() int { diff --git a/perf/mselect/task.go b/perf/mselect/task.go index 8612312..2c46e7e 100644 --- a/perf/mselect/task.go +++ b/perf/mselect/task.go @@ -2,6 +2,7 @@ package mselect import ( "reflect" + "sync" "unsafe" "github.com/jxskiss/gopkg/v2/internal/linkname" @@ -26,12 +27,29 @@ func NewTask[T any]( newFunc: func() unsafe.Pointer { return unsafe.Pointer(new(T)) }, - bIdx: -1, tIdx: -1, } return task } +func buildTaskFunc[T any]( + syncCallback func(v T, ok bool), + asyncCallback func(v T, ok bool), +) func(v unsafe.Pointer, ok bool) { + if syncCallback == nil && asyncCallback == nil { + return nil + } + return func(v unsafe.Pointer, ok bool) { + tVal := *(*T)(v) + if syncCallback != nil { + syncCallback(tVal, ok) + } + if asyncCallback != nil { + go asyncCallback(tVal, ok) + } + } +} + // Task is a channel receiving task which can be submitted to ManySelect. // A zero Task is not ready to use, use NewTask to create a Task. // @@ -42,11 +60,11 @@ type Task struct { execFunc func(v unsafe.Pointer, ok bool) newFunc func() unsafe.Pointer - bIdx int // bucket index - tIdx int // task index - - added int32 - deleted int32 + mu sync.Mutex + bucket *taskBucket + tIdx int // task index in bucket + added bool + deleted bool } func (t *Task) newRuntimeSelect() linkname.RuntimeSelect { @@ -66,20 +84,18 @@ func (t *Task) getAndResetRecvValue(rsel *linkname.RuntimeSelect) unsafe.Pointer return recv } -func buildTaskFunc[T any]( - syncCallback func(v T, ok bool), - asyncCallback func(v T, ok bool), -) func(v unsafe.Pointer, ok bool) { - if syncCallback == nil && asyncCallback == nil { - return nil +func (t *Task) signalDelete() { + t.mu.Lock() + if t.deleted { + t.mu.Unlock() + return } - return func(v unsafe.Pointer, ok bool) { - tVal := *(*T)(v) - if syncCallback != nil { - syncCallback(tVal, ok) - } - if asyncCallback != nil { - go asyncCallback(tVal, ok) - } + t.deleted = true + bucket := t.bucket + t.mu.Unlock() + + // No need to hold lock to send the signal. + if bucket != nil { + bucket.signalDelete(t) } }