resourcemanager: fix TaskController.Stop() can't make producer exit in spmcpool (#41016)

close #41015
hawkingrei authored Feb 3, 2023
1 parent eb11fb2 commit 3c4976b
Showing 5 changed files with 99 additions and 30 deletions.
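
Why this fix works: before this commit, Stop() closed a channel that only Wait() received from, so a producer spinning in its for-loop never observed the stop request and kept emitting tasks. The commit threads a dedicated productExitCh into each producer loop and polls it on every iteration. A minimal, self-contained sketch of that pattern (names follow the diff; the body of the isClose helper is an assumption, since the diff only shows its signature):

```go
package main

import (
	"fmt"
	"time"
)

// isClose reports whether exitCh has been closed, without blocking.
// Assumed shape of the helper in util/gpool/spmc/spmcpool.go.
func isClose(exitCh chan struct{}) bool {
	select {
	case <-exitCh:
		return true
	default:
		return false
	}
}

func main() {
	productExitCh := make(chan struct{})
	done := make(chan struct{})

	go func() { // the producer loop after the fix
		defer close(done)
		for {
			if isClose(productExitCh) {
				return // Stop() closed productExitCh: exit promptly
			}
			// produce one batch of tasks here
			time.Sleep(time.Millisecond)
		}
	}()

	close(productExitCh) // what TaskController.Stop() now does
	<-done
	fmt.Println("producer exited")
}
```
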
5 changes: 4 additions & 1 deletion resourcemanager/pooltask/BUILD.bazel
@@ -10,7 +10,10 @@ go_library(
],
importpath = "github.com/pingcap/tidb/resourcemanager/pooltask",
visibility = ["//visibility:public"],
deps = ["@org_uber_go_atomic//:atomic"],
deps = [
"//util/channel",
"@org_uber_go_atomic//:atomic",
],
)

go_test(
35 changes: 23 additions & 12 deletions resourcemanager/pooltask/task.go
@@ -17,6 +17,8 @@ package pooltask
import (
"sync"
"sync/atomic"

"github.com/pingcap/tidb/util/channel"
)

// Context is an interface that can be used to create a context.
@@ -127,35 +129,44 @@ type GPool[T any, U any, C any, CT any, TF Context[CT]] interface {

// TaskController is a controller that can control or watch the pool.
type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct {
pool GPool[T, U, C, CT, TF]
close chan struct{}
wg *sync.WaitGroup
taskID uint64
resultCh chan U
pool GPool[T, U, C, CT, TF]
productExitCh chan struct{}
wg *sync.WaitGroup
taskID uint64
resultCh chan U
inputCh chan Task[T]
}

// NewTaskController creates a controller to deal with the pooltask's status.
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, closeCh chan struct{}, wg *sync.WaitGroup, resultCh chan U) TaskController[T, U, C, CT, TF] {
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, productExitCh chan struct{}, wg *sync.WaitGroup, inputCh chan Task[T], resultCh chan U) TaskController[T, U, C, CT, TF] {
return TaskController[T, U, C, CT, TF]{
pool: p,
taskID: taskID,
close: closeCh,
wg: wg,
resultCh: resultCh,
pool: p,
taskID: taskID,
productExitCh: productExitCh,
wg: wg,
resultCh: resultCh,
inputCh: inputCh,
}
}

// Wait waits for the pool task to stop.
func (t *TaskController[T, U, C, CT, TF]) Wait() {
<-t.close
t.wg.Wait()
close(t.resultCh)
t.pool.DeleteTask(t.taskID)
}

// Stop sends a stop command to the task. You still need to call Wait to wait for the task to stop.
func (t *TaskController[T, U, C, CT, TF]) Stop() {
close(t.productExitCh)
// Drain all tasks in the task queue and mark them complete,
// so that `t.Wait` is able to close resultCh.
for range t.inputCh {
t.wg.Done()
}
t.pool.StopTask(t.TaskID())
// Drain resultCh so that consumers blocked on sending a result into the channel can exit.
channel.Clear(t.resultCh)
}

// TaskID returns the task ID.
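
The invariant behind the new Stop() is that the producer does one wg.Add per queued task (plus one for itself, added in spmcpool.go below), so draining inputCh must call wg.Done once per task before Wait()'s wg.Wait() can return and close resultCh. Two hedges: the source of channel.Clear is not part of this diff, so it is sketched below as a non-blocking generic drain; and the tests below call Stop() more than once, so the real implementation presumably guards the close of productExitCh (for example with sync.Once), which the sketch also does:

```go
package main

import (
	"fmt"
	"sync"
)

// drain empties ch without blocking, which is what a helper like
// util/channel.Clear is assumed to do (its source is not in this diff).
func drain[T any](ch chan T) {
	for {
		select {
		case <-ch:
		default:
			return
		}
	}
}

func main() {
	var wg sync.WaitGroup
	var stopOnce sync.Once
	productExitCh := make(chan struct{})
	inputCh := make(chan int, 4)
	resultCh := make(chan int, 4)

	wg.Add(1) // count the producer itself, as spmcpool.go now does
	go func() {
		defer func() {
			close(inputCh) // lets the drain loop in stop() terminate
			wg.Done()
		}()
		for i := 0; i < 3; i++ {
			select {
			case <-productExitCh:
				return
			default:
				wg.Add(1) // one Add per queued task
				inputCh <- i
			}
		}
	}()

	stop := func() {
		stopOnce.Do(func() { // guard: Stop may be called twice
			close(productExitCh)
			for range inputCh { // one Done per still-queued task
				wg.Done()
			}
			drain(resultCh) // unblock workers stuck sending results
		})
	}

	stop()
	stop()    // second call is a no-op, like the double Stop() in the tests
	wg.Wait() // Wait(): every Add is matched by a Done
	close(resultCh)
	fmt.Println("stopped cleanly")
}
```
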
6 changes: 4 additions & 2 deletions util/gpool/spmc/option.go
@@ -18,6 +18,8 @@ import (
"time"
)

const defaultTaskChanLen = 1

// Option represents the optional function.
type Option func(opts *Options)

@@ -103,8 +105,8 @@ func loadTaskOptions(options ...TaskOption) *TaskOptions {
if opts.ResultChanLen == 0 {
opts.ResultChanLen = uint64(opts.Concurrency)
}
if opts.ResultChanLen == 0 {
opts.ResultChanLen = uint64(opts.Concurrency)
if opts.TaskChanLen == 0 {
opts.TaskChanLen = defaultTaskChanLen
}
return opts
}
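
The removed branch was a copy-paste bug: ResultChanLen was defaulted twice and TaskChanLen never, so inputCh would have been created unbuffered. A minimal sketch of the corrected defaulting (field and constant names are from the diff; the struct is trimmed to just the fields this hunk touches):

```go
package spmc

const defaultTaskChanLen = 1

// TaskOptions is sketched with only the fields relevant here.
type TaskOptions struct {
	Concurrency   int
	ResultChanLen uint64
	TaskChanLen   uint64
}

func loadTaskOptions(opts *TaskOptions) *TaskOptions {
	if opts.ResultChanLen == 0 {
		// Default the result buffer to one slot per worker.
		opts.ResultChanLen = uint64(opts.Concurrency)
	}
	if opts.TaskChanLen == 0 {
		// Previously this branch re-checked ResultChanLen,
		// so TaskChanLen silently stayed 0.
		opts.TaskChanLen = defaultTaskChanLen
	}
	return opts
}
```
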
21 changes: 14 additions & 7 deletions util/gpool/spmc/spmcpool.go
@@ -219,7 +219,6 @@ func (p *Pool[T, U, C, CT, TF]) release() {
// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
// those callers from blocking indefinitely.
p.cond.Broadcast()
close(p.taskCh)
}

func isClose(exitCh chan struct{}) bool {
@@ -260,9 +259,9 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
taskID := p.NewTaskID()
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
closeCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
@@ -274,15 +273,19 @@
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(closeCh)
close(inputCh)
wg.Done()
}()
for {
if isClose(productExitCh) {
return
}
tasks, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
@@ -310,10 +313,10 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
taskID := p.NewTaskID()
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
closeCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
if err == gpool.ErrPoolClosed {
@@ -324,15 +327,19 @@
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(closeCh)
close(inputCh)
wg.Done()
}()
for {
if isClose(productExitCh) {
return
}
task, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
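
Two details in this file are easy to miss. First, release() no longer closes p.taskCh: producer goroutines may still be sending task boxes on it, and sending on a closed channel panics. Second, the producer goroutine is now counted in the WaitGroup (wg.Add(1)), and its deferred cleanup closes inputCh before calling wg.Done(), which guarantees that Stop()'s drain loop over inputCh terminates. A minimal demo of the panic the first change avoids (illustration only, not pool code):

```go
package main

import "fmt"

func main() {
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("recovered:", r) // "send on closed channel"
		}
	}()
	ch := make(chan int, 1)
	close(ch)
	ch <- 1 // panics: send on closed channel
}
```
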
62 changes: 54 additions & 8 deletions util/gpool/spmc/spmcpool_test.go
@@ -53,7 +53,7 @@ func TestPool(t *testing.T) {
}
}
// add new task
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(5))

var count atomic.Uint32
var wg sync.WaitGroup
@@ -112,12 +112,55 @@ func TestStopPool(t *testing.T) {
require.Greater(t, result, 10)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
control.Stop()
}()
// Wait for the task to finish.
control.Wait()
wg.Wait()
// close pool
pool.ReleaseAndWait()
}

func TestStopPoolWithSlice(t *testing.T) {
type ConstArgs struct {
a int
}
myArgs := ConstArgs{a: 10}
// init the pool
// input type, output type, constArgs type
pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestStopPoolWithSlice", 3, rmutil.UNKNOWN)
require.NoError(t, err)
pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int {
return task + constArgs.a
})

exit := make(chan struct{})

pfunc := func() ([]int, error) {
select {
case <-exit:
return nil, gpool.ErrProducerClosed
default:
return []int{1, 2, 3}, nil
}
}
// add new task
resultCh, control := pool.AddProduceBySlice(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for result := range resultCh {
require.Greater(t, result, 10)
control.Stop()
}
}()
// Wait for the task to finish.
control.Stop()
close(exit)
control.Wait()
// This should pass: calling Stop again after the task has stopped must not panic.
control.Stop()
wg.Wait()
// close pool
pool.ReleaseAndWait()
@@ -191,9 +234,12 @@ func testTunePool(t *testing.T, name string) {
for n := pool.Cap(); n > 1; n-- {
downclockPool(t, pool, tid)
}

// exit test
close(exit)
wg.Add(1)
go func() {
// exit test
control.Stop()
wg.Done()
}()
control.Wait()
wg.Wait()
// close pool
