Skip to content

Commit

Permalink
[mselect] fix data race issues
Browse files Browse the repository at this point in the history
Change-Id: Ifad938444be3daae29df226e1c78c69465b49b41
  • Loading branch information
jxskiss committed Nov 30, 2023
1 parent 0942ea0 commit b2775ec
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 62 deletions.
86 changes: 56 additions & 30 deletions perf/mselect/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ type taskBucket struct {
cases []linkname.RuntimeSelect
tasks []*Task

block bool
block bool
stopped bool
}

func newTaskBucket(msel *manySelect, userTask *Task) *taskBucket {
Expand All @@ -40,10 +41,10 @@ func newTaskBucket(msel *manySelect, userTask *Task) *taskBucket {
sigTask := msel.sigTask
stopTask := NewTask(b.m.stop, nil, nil)
delTask := NewTask(b.delCh, nil, nil)
b.addTask(sigTask, false)
b.addTask(stopTask, false)
b.addTask(delTask, false)
b.addTask(userTask, true)
b.addTask(sigTask)
b.addTask(stopTask)
b.addTask(delTask)
b.addTask(userTask)
go b.loop()
return b
}
Expand All @@ -62,40 +63,31 @@ func (b *taskBucket) loop() {
// Got a signal or a new task submitted.
switch i {
case 0: // signal
b.processSignal(recv, ok)
if !ok {
return // stopped
if !ok { // stopped
b.stop()
return
}
b.processSignal(recv)
case 1: // stopped
b.stop()
return
case 2: // delete task
delTask := *(**Task)(recv)
if delTask.bIdx < 0 || delTask.tIdx < sigTaskNum {
panic("mselect: invalid task to delete")
}
// The task's channel may already been closed,
// then the task would be automatically removed.
// We check tIdx to make sure the task is still in the task list.
if delTask.tIdx < len(b.tasks) {
b.removeTask(delTask.tIdx)
b.m.decrCount(1)
}
b.deleteTask(delTask)
default:
b.processTask(i, task, recv, ok)
b.executeTask(task, recv, ok)
}
}
}

func (b *taskBucket) processSignal(recv unsafe.Pointer, ok bool) {
if !ok { // stopped
b.stop()
func (b *taskBucket) processSignal(recv unsafe.Pointer) {
if b.stopped {
return
}

// Add a new task.
newTask := *(**Task)(recv)
b.addTask(newTask, true)
b.addTask(newTask)

// If the bucket is full, block the signal channel to avoid
// accepting new tasks.
Expand All @@ -106,7 +98,7 @@ func (b *taskBucket) processSignal(recv unsafe.Pointer, ok bool) {
}
}

func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool) {
func (b *taskBucket) executeTask(task *Task, recv unsafe.Pointer, ok bool) {
// Execute the task first.
// When the channel is closed, call callback functions with
// a zero value and ok = false.
Expand All @@ -116,8 +108,7 @@ func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool

// Delete task if the channel was closed.
if !ok {
b.removeTask(i)
b.m.decrCount(1)
b.deleteTask(task)

// Reset signal task to accept new tasks.
if b.block && len(b.cases) < bucketSize {
Expand All @@ -130,15 +121,45 @@ func (b *taskBucket) processTask(i int, task *Task, recv unsafe.Pointer, ok bool
}
}

func (b *taskBucket) addTask(task *Task, setIndex bool) {
if setIndex {
task.bIdx = b.idx
task.tIdx = len(b.tasks)
func (b *taskBucket) addTask(task *Task) {
task.mu.Lock()
defer task.mu.Unlock()

if task.deleted {
return
}
if task.added {
panic("mselect: task already added")
}

// The fields added, bucket, tIdx are used to coordinate with deletion,
// don't write these fields for signal tasks.
if idx := len(b.tasks); idx >= sigTaskNum {
task.added = true
task.bucket = b
task.tIdx = idx
}

b.tasks = append(b.tasks, task)
b.cases = append(b.cases, task.newRuntimeSelect())
}

func (b *taskBucket) deleteTask(task *Task) {
task.mu.Lock()
defer task.mu.Unlock()

if !task.added {
return
}
if task.bucket == nil || task.tIdx < sigTaskNum {
panic("mselect: invalid task to delete")
}

task.deleted = true
b.removeTask(task.tIdx)
b.m.decrCount(1)
}

func (b *taskBucket) removeTask(i int) {
n := len(b.cases)
if n > sigTaskNum+1 && i < n-1 {
Expand All @@ -153,10 +174,15 @@ func (b *taskBucket) removeTask(i int) {
}

func (b *taskBucket) stop() {
if b.stopped {
return
}

n := len(b.tasks) - sigTaskNum // don't count the signal tasks
b.m.decrCount(n)
b.cases = nil
b.tasks = nil
b.stopped = true

// Drain tasks to un-block any goroutines blocked by sending tasks
// to b.m.tasks.
Expand Down
13 changes: 1 addition & 12 deletions perf/mselect/mselect.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ func (p *manySelect) Add(task *Task) {
if atomic.LoadInt32(&p.stopped) > 0 {
return
}
if !atomic.CompareAndSwapInt32(&task.added, 0, 1) {
panic("mselect: adding task more than once")
}

p.mu.Lock()
if atomic.AddInt32(&p.count, 1) < p.cap() {
Expand All @@ -78,15 +75,7 @@ func (p *manySelect) Add(task *Task) {
}

func (p *manySelect) Delete(task *Task) {
if atomic.LoadInt32(&task.added) == 0 {
panic("mselect: the task is not added")
}
if !atomic.CompareAndSwapInt32(&task.deleted, 0, 1) {
panic("mselect: deleting task more than once")
}
p.mu.Lock()
p.buckets[task.bIdx].signalDelete(task)
p.mu.Unlock()
task.signalDelete()
}

func (p *manySelect) Count() int {
Expand Down
56 changes: 36 additions & 20 deletions perf/mselect/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mselect

import (
"reflect"
"sync"
"unsafe"

"github.com/jxskiss/gopkg/v2/internal/linkname"
Expand All @@ -26,12 +27,29 @@ func NewTask[T any](
newFunc: func() unsafe.Pointer {
return unsafe.Pointer(new(T))
},
bIdx: -1,
tIdx: -1,
}
return task
}

func buildTaskFunc[T any](
syncCallback func(v T, ok bool),
asyncCallback func(v T, ok bool),
) func(v unsafe.Pointer, ok bool) {
if syncCallback == nil && asyncCallback == nil {
return nil
}
return func(v unsafe.Pointer, ok bool) {
tVal := *(*T)(v)
if syncCallback != nil {
syncCallback(tVal, ok)
}
if asyncCallback != nil {
go asyncCallback(tVal, ok)
}
}
}

// Task is a channel receiving task which can be submitted to ManySelect.
// A zero Task is not ready to use, use NewTask to create a Task.
//
Expand All @@ -42,11 +60,11 @@ type Task struct {
execFunc func(v unsafe.Pointer, ok bool)
newFunc func() unsafe.Pointer

bIdx int // bucket index
tIdx int // task index

added int32
deleted int32
mu sync.Mutex
bucket *taskBucket
tIdx int // task index in bucket
added bool
deleted bool
}

func (t *Task) newRuntimeSelect() linkname.RuntimeSelect {
Expand All @@ -66,20 +84,18 @@ func (t *Task) getAndResetRecvValue(rsel *linkname.RuntimeSelect) unsafe.Pointer
return recv
}

func buildTaskFunc[T any](
syncCallback func(v T, ok bool),
asyncCallback func(v T, ok bool),
) func(v unsafe.Pointer, ok bool) {
if syncCallback == nil && asyncCallback == nil {
return nil
func (t *Task) signalDelete() {
t.mu.Lock()
if t.deleted {
t.mu.Unlock()
return
}
return func(v unsafe.Pointer, ok bool) {
tVal := *(*T)(v)
if syncCallback != nil {
syncCallback(tVal, ok)
}
if asyncCallback != nil {
go asyncCallback(tVal, ok)
}
t.deleted = true
bucket := t.bucket
t.mu.Unlock()

// No need to hold lock to send the signal.
if bucket != nil {
bucket.signalDelete(t)
}
}

0 comments on commit b2775ec

Please sign in to comment.