From 7c4afe088e3a6c1c5fca6c80900bf431dd865de9 Mon Sep 17 00:00:00 2001 From: Kaveh Shahbazian Date: Sat, 9 Oct 2021 22:04:03 +0200 Subject: [PATCH] refactor actor --- actor.go | 116 ++++++++++++ actor/actor.go | 116 ++++++++++++ actor/actor_test.go | 223 +++++++++++++++++++++++ actor/callbacks_spy_test.go | 109 +++++++++++ actor_test.go | 236 ++++++++++++++++++++++++ callbacks_spy_test.go | 109 +++++++++++ go.mod | 8 +- worker-pool.go | 285 ++++++++++++++++------------- worker-pool_test.go | 347 ++++++++++++++---------------------- 9 files changed, 1214 insertions(+), 335 deletions(-) create mode 100644 actor.go create mode 100644 actor/actor.go create mode 100644 actor/actor_test.go create mode 100644 actor/callbacks_spy_test.go create mode 100644 actor_test.go create mode 100644 callbacks_spy_test.go diff --git a/actor.go b/actor.go new file mode 100644 index 0000000..6c43d95 --- /dev/null +++ b/actor.go @@ -0,0 +1,116 @@ +// see LICENSE file + +package spool + +import ( + "context" + "time" +) + +func Start(ctx context.Context, mailbox Mailbox, callbacks Callbacks, options ...Option) { + opts := applyOptions(options...) + start(ctx, mailbox, callbacks, opts) +} + +func start(ctx context.Context, mailbox Mailbox, callbacks Callbacks, opts actorOptions) { + go func() { + if started != nil { + started(mailbox) + } + if stopped != nil { + defer stopped(mailbox) + } + + var ( + absoluteTimeout = opts.absoluteTimeout + idleTimeout = opts.idleTimeout + ) + + var absoluteTimeoutSignal, idleTimeoutSignal <-chan time.Time + if absoluteTimeout > 0 { + absoluteTimeoutSignal = time.After(absoluteTimeout) + } + + var requestCount RequestCount + for { + if requestCount > 0 && opts.respawnAfter > 0 && opts.respawnAfter <= requestCount { + start(ctx, mailbox, callbacks, opts) + return + } + + if idleTimeout > 0 { + idleTimeoutSignal = time.After(idleTimeout) + } + + select { + case <-absoluteTimeoutSignal: + callbacks.Stopped() + return + case <-idleTimeoutSignal: + if opts.respawnAfter > 0 { + start(ctx, mailbox, callbacks, opts) + return + } + callbacks.Stopped() + return + case <-ctx.Done(): + callbacks.Stopped() + return + case v, ok := <-mailbox: + if !ok { + callbacks.Stopped() + return + } + callbacks.Received(v) + requestCount++ + } + } + }() +} + +type ( + Callbacks interface { + Received(T) + Stopped() + } + + Mailbox <-chan T + + Option func(actorOptions) actorOptions + + actorOptions struct { + absoluteTimeout time.Duration + idleTimeout time.Duration + respawnAfter RequestCount + } + + RequestCount int + MailboxSize int + + T = func() +) + +func WithAbsoluteTimeout(timeout time.Duration) Option { + return func(opts actorOptions) actorOptions { opts.absoluteTimeout = timeout; return opts } +} + +func WithIdleTimeout(timeout time.Duration) Option { + return func(opts actorOptions) actorOptions { opts.idleTimeout = timeout; return opts } +} + +func WithRespawnAfter(respawnAfter RequestCount) Option { + return func(opts actorOptions) actorOptions { opts.respawnAfter = respawnAfter; return opts } +} + +func applyOptions(opts ...Option) actorOptions { + var options actorOptions + for _, fn := range opts { + options = fn(options) + } + return options +} + +var ( + started func(pool Mailbox) + stopped func(pool Mailbox) +) diff --git a/actor/actor.go b/actor/actor.go new file mode 100644 index 0000000..ea500de --- /dev/null +++ b/actor/actor.go @@ -0,0 +1,116 @@ +// see LICENSE file + +package actor + +import ( + "context" + "time" +) + +func Start(ctx context.Context, mailbox Mailbox, callbacks Callbacks, options ...Option) { + opts := applyOptions(options...) + start(ctx, mailbox, callbacks, opts) +} + +func start(ctx context.Context, mailbox Mailbox, callbacks Callbacks, opts actorOptions) { + go func() { + if started != nil { + started(mailbox) + } + if stopped != nil { + defer stopped(mailbox) + } + + var ( + absoluteTimeout = opts.absoluteTimeout + idleTimeout = opts.idleTimeout + ) + + var absoluteTimeoutSignal, idleTimeoutSignal <-chan time.Time + if absoluteTimeout > 0 { + absoluteTimeoutSignal = time.After(absoluteTimeout) + } + + var requestCount RequestCount + for { + if requestCount > 0 && opts.respawnAfter > 0 && opts.respawnAfter <= requestCount { + start(ctx, mailbox, callbacks, opts) + return + } + + if idleTimeout > 0 { + idleTimeoutSignal = time.After(idleTimeout) + } + + select { + case <-absoluteTimeoutSignal: + callbacks.Stopped() + return + case <-idleTimeoutSignal: + if opts.respawnAfter > 0 { + start(ctx, mailbox, callbacks, opts) + return + } + callbacks.Stopped() + return + case <-ctx.Done(): + callbacks.Stopped() + return + case v, ok := <-mailbox: + if !ok { + callbacks.Stopped() + return + } + callbacks.Received(v) + requestCount++ + } + } + }() +} + +type ( + Callbacks interface { + Received(T) + Stopped() + } + + Mailbox <-chan T + + Option func(actorOptions) actorOptions + + actorOptions struct { + absoluteTimeout time.Duration + idleTimeout time.Duration + respawnAfter RequestCount + } + + RequestCount int + MailboxSize int + + T = interface{} +) + +func WithAbsoluteTimeout(timeout time.Duration) Option { + return func(opts actorOptions) actorOptions { opts.absoluteTimeout = timeout; return opts } +} + +func WithIdleTimeout(timeout time.Duration) Option { + return func(opts actorOptions) actorOptions { opts.idleTimeout = timeout; return opts } +} + +func WithRespawnAfter(respawnAfter RequestCount) Option { + return func(opts actorOptions) actorOptions { opts.respawnAfter = respawnAfter; return opts } +} + +func applyOptions(opts ...Option) actorOptions { + var options actorOptions + for _, fn := range opts { + options = fn(options) + } + return options +} + +var ( + started func(pool Mailbox) + stopped func(pool Mailbox) +) diff --git a/actor/actor_test.go b/actor/actor_test.go new file mode 100644 index 0000000..d2c630b --- /dev/null +++ b/actor/actor_test.go @@ -0,0 +1,223 @@ +//go:generate moq -out callbacks_spy_test.go . Callbacks:CallbacksSpy +// see LICENSE file + +// install moq: +// $ go install github.com/matryer/moq@latest + +package actor + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMain(m *testing.M) { + started = func(box Mailbox) { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + mailboxStateWorkerStartCount[box]++ + } + stopped = func(box Mailbox) { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + mailboxStateWorkerStopCount[box]++ + } + + exitVal := m.Run() + + os.Exit(exitVal) +} + +func Test_should_call_received_concurrently(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.ReceivedFunc = func(T) {} + + Start(context.TODO(), mailbox, callbacks) + + mailbox <- 1 + mailbox <- 2 + mailbox <- 3 + + assert.Eventually(t, func() bool { return len(callbacks.ReceivedCalls()) == 3 }, + time.Millisecond*300, time.Millisecond*20) + + assert.EqualValues(t, 1, callbacks.ReceivedCalls()[0].IfaceVal) + assert.EqualValues(t, 2, callbacks.ReceivedCalls()[1].IfaceVal) + assert.EqualValues(t, 3, callbacks.ReceivedCalls()[2].IfaceVal) +} + +func Test_should_call_stopped_when_context_is_canceled(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + ctx, cancel := context.WithCancel(context.Background()) + callbacks.StoppedFunc = func() {} + + Start(ctx, mailbox, callbacks) + cancel() + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_not_call_received_when_context_is_canceled(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + ctx, cancel := context.WithCancel(context.Background()) + callbacks.StoppedFunc = func() {} + + Start(ctx, mailbox, callbacks) + cancel() + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) + + sendingStarted := make(chan struct{}) + go func() { + close(sendingStarted) + mailbox <- 1 + }() + <-sendingStarted + + assert.Never(t, func() bool { return len(callbacks.ReceivedCalls()) > 0 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_stop_after_absolute_timeout(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, WithAbsoluteTimeout(time.Millisecond*50)) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_stop_when_mailbox_is_closed(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks) + close(mailbox) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) + assert.Equal(t, 0, len(callbacks.ReceivedCalls())) +} + +func Test_should_stop_after_idle_timeout_elapsed(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, WithIdleTimeout(time.Millisecond*100)) + + assert.Never(t, func() bool { return len(callbacks.StoppedCalls()) > 0 }, + time.Millisecond*100, time.Millisecond*20) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_respawn_after_receiving_n_messages(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + callbacks.ReceivedFunc = func(T) {} + + Start(context.Background(), mailbox, callbacks, WithRespawnAfter(10)) + + go func() { + for i := 0; i < 20; i++ { + mailbox <- i + } + }() + + assert.Eventually(t, func() bool { return assert.EqualValues(t, 3, getNumberOfStarts(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return assert.EqualValues(t, 2, getNumberOfStops(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_not_respawn_if_not_provided(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + callbacks.ReceivedFunc = func(T) {} + + Start(context.Background(), mailbox, callbacks) + + go func() { + for i := 0; i < 20; i++ { + mailbox <- i + } + }() + + assert.Eventually(t, func() bool { return assert.EqualValues(t, 1, getNumberOfStarts(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return assert.EqualValues(t, 0, getNumberOfStops(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_respawn_after_idle_timeout_elapsed_if_respawn_count_is_provided(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, + WithIdleTimeout(time.Millisecond*100), + WithRespawnAfter(100)) + + assert.Eventually(t, func() bool { return getNumberOfStarts(mailbox) == 2 }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return getNumberOfStops(mailbox) == 1 }, + time.Millisecond*300, time.Millisecond*20) + + assert.Never(t, func() bool { return len(callbacks.StoppedCalls()) > 0 }, + time.Millisecond*100, time.Millisecond*20) +} + +func getNumberOfStarts(box Mailbox) int { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + + return mailboxStateWorkerStartCount[box] +} + +func getNumberOfStops(box Mailbox) int { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + + return mailboxStateWorkerStopCount[box] +} + +var ( + mailboxStateWorkerStopCount = make(map[Mailbox]int) + mailboxStateWorkerStartCount = make(map[Mailbox]int) + accessMailboxState = &sync.Mutex{} +) diff --git a/actor/callbacks_spy_test.go b/actor/callbacks_spy_test.go new file mode 100644 index 0000000..3464f6c --- /dev/null +++ b/actor/callbacks_spy_test.go @@ -0,0 +1,109 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package actor + +import ( + "sync" +) + +// Ensure, that CallbacksSpy does implement Callbacks. +// If this is not the case, regenerate this file with moq. +var _ Callbacks = &CallbacksSpy{} + +// CallbacksSpy is a mock implementation of Callbacks. +// +// func TestSomethingThatUsesCallbacks(t *testing.T) { +// +// // make and configure a mocked Callbacks +// mockedCallbacks := &CallbacksSpy{ +// ReceivedFunc: func(ifaceVal interface{}) { +// panic("mock out the Received method") +// }, +// StoppedFunc: func() { +// panic("mock out the Stopped method") +// }, +// } +// +// // use mockedCallbacks in code that requires Callbacks +// // and then make assertions. +// +// } +type CallbacksSpy struct { + // ReceivedFunc mocks the Received method. + ReceivedFunc func(ifaceVal interface{}) + + // StoppedFunc mocks the Stopped method. + StoppedFunc func() + + // calls tracks calls to the methods. + calls struct { + // Received holds details about calls to the Received method. + Received []struct { + // IfaceVal is the ifaceVal argument value. + IfaceVal interface{} + } + // Stopped holds details about calls to the Stopped method. + Stopped []struct { + } + } + lockReceived sync.RWMutex + lockStopped sync.RWMutex +} + +// Received calls ReceivedFunc. +func (mock *CallbacksSpy) Received(ifaceVal interface{}) { + if mock.ReceivedFunc == nil { + panic("CallbacksSpy.ReceivedFunc: method is nil but Callbacks.Received was just called") + } + callInfo := struct { + IfaceVal interface{} + }{ + IfaceVal: ifaceVal, + } + mock.lockReceived.Lock() + mock.calls.Received = append(mock.calls.Received, callInfo) + mock.lockReceived.Unlock() + mock.ReceivedFunc(ifaceVal) +} + +// ReceivedCalls gets all the calls that were made to Received. +// Check the length with: +// len(mockedCallbacks.ReceivedCalls()) +func (mock *CallbacksSpy) ReceivedCalls() []struct { + IfaceVal interface{} +} { + var calls []struct { + IfaceVal interface{} + } + mock.lockReceived.RLock() + calls = mock.calls.Received + mock.lockReceived.RUnlock() + return calls +} + +// Stopped calls StoppedFunc. +func (mock *CallbacksSpy) Stopped() { + if mock.StoppedFunc == nil { + panic("CallbacksSpy.StoppedFunc: method is nil but Callbacks.Stopped was just called") + } + callInfo := struct { + }{} + mock.lockStopped.Lock() + mock.calls.Stopped = append(mock.calls.Stopped, callInfo) + mock.lockStopped.Unlock() + mock.StoppedFunc() +} + +// StoppedCalls gets all the calls that were made to Stopped. +// Check the length with: +// len(mockedCallbacks.StoppedCalls()) +func (mock *CallbacksSpy) StoppedCalls() []struct { +} { + var calls []struct { + } + mock.lockStopped.RLock() + calls = mock.calls.Stopped + mock.lockStopped.RUnlock() + return calls +} diff --git a/actor_test.go b/actor_test.go new file mode 100644 index 0000000..0129c7e --- /dev/null +++ b/actor_test.go @@ -0,0 +1,236 @@ +//go:generate moq -out callbacks_spy_test.go . Callbacks:CallbacksSpy +// see LICENSE file + +// install moq: +// $ go install github.com/matryer/moq@latest + +package spool + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMain(m *testing.M) { + started = func(box Mailbox) { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + mailboxStateWorkerStartCount[box]++ + } + stopped = func(box Mailbox) { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + mailboxStateWorkerStopCount[box]++ + } + + exitVal := m.Run() + + os.Exit(exitVal) +} + +func Test_should_call_received_concurrently(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.ReceivedFunc = func(T) {} + + Start(context.TODO(), mailbox, callbacks) + + ids := make(chan int, 6) + fn1 := func() { ids <- 1 } + fn2 := func() { ids <- 2 } + fn3 := func() { ids <- 3 } + + mailbox <- fn1 + mailbox <- fn2 + mailbox <- fn3 + + assert.Eventually(t, func() bool { return len(callbacks.ReceivedCalls()) == 3 }, + time.Millisecond*300, time.Millisecond*20) + + fn1() + callbacks.ReceivedCalls()[0].Fn() + assert.EqualValues(t, <-ids, <-ids) + + fn2() + callbacks.ReceivedCalls()[1].Fn() + assert.EqualValues(t, <-ids, <-ids) + + fn3() + callbacks.ReceivedCalls()[2].Fn() + assert.EqualValues(t, <-ids, <-ids) +} + +func Test_should_call_stopped_when_context_is_canceled(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + ctx, cancel := context.WithCancel(context.Background()) + callbacks.StoppedFunc = func() {} + + Start(ctx, mailbox, callbacks) + cancel() + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_not_call_received_when_context_is_canceled(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + ctx, cancel := context.WithCancel(context.Background()) + callbacks.StoppedFunc = func() {} + + Start(ctx, mailbox, callbacks) + cancel() + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) + + sendingStarted := make(chan struct{}) + go func() { + close(sendingStarted) + mailbox <- func() {} + }() + <-sendingStarted + + assert.Never(t, func() bool { return len(callbacks.ReceivedCalls()) > 0 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_stop_after_absolute_timeout(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, WithAbsoluteTimeout(time.Millisecond*50)) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_stop_when_mailbox_is_closed(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks) + close(mailbox) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) + assert.Equal(t, 0, len(callbacks.ReceivedCalls())) +} + +func Test_should_stop_after_idle_timeout_elapsed(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, WithIdleTimeout(time.Millisecond*100)) + + assert.Never(t, func() bool { return len(callbacks.StoppedCalls()) > 0 }, + time.Millisecond*100, time.Millisecond*20) + + assert.Eventually(t, func() bool { return len(callbacks.StoppedCalls()) == 1 }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_respawn_after_receiving_n_messages(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + callbacks.ReceivedFunc = func(T) {} + + Start(context.Background(), mailbox, callbacks, WithRespawnAfter(10)) + + go func() { + for i := 0; i < 20; i++ { + mailbox <- func() {} + } + }() + + assert.Eventually(t, func() bool { return assert.EqualValues(t, 3, getNumberOfStarts(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return assert.EqualValues(t, 2, getNumberOfStops(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_not_respawn_if_not_provided(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + callbacks.ReceivedFunc = func(T) {} + + Start(context.Background(), mailbox, callbacks) + + go func() { + for i := 0; i < 20; i++ { + mailbox <- func() {} + } + }() + + assert.Eventually(t, func() bool { return assert.EqualValues(t, 1, getNumberOfStarts(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return assert.EqualValues(t, 0, getNumberOfStops(mailbox)) }, + time.Millisecond*300, time.Millisecond*20) +} + +func Test_should_respawn_after_idle_timeout_elapsed_if_respawn_count_is_provided(t *testing.T) { + var ( + mailbox = make(chan T) + callbacks = &CallbacksSpy{} + ) + callbacks.StoppedFunc = func() {} + + Start(context.Background(), mailbox, callbacks, + WithIdleTimeout(time.Millisecond*100), + WithRespawnAfter(100)) + + assert.Eventually(t, func() bool { return getNumberOfStarts(mailbox) == 2 }, + time.Millisecond*300, time.Millisecond*20) + assert.Eventually(t, func() bool { return getNumberOfStops(mailbox) == 1 }, + time.Millisecond*300, time.Millisecond*20) + + assert.Never(t, func() bool { return len(callbacks.StoppedCalls()) > 0 }, + time.Millisecond*100, time.Millisecond*20) +} + +func getNumberOfStarts(box Mailbox) int { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + + return mailboxStateWorkerStartCount[box] +} + +func getNumberOfStops(box Mailbox) int { + accessMailboxState.Lock() + defer accessMailboxState.Unlock() + + return mailboxStateWorkerStopCount[box] +} + +var ( + mailboxStateWorkerStopCount = make(map[Mailbox]int) + mailboxStateWorkerStartCount = make(map[Mailbox]int) + accessMailboxState = &sync.Mutex{} +) diff --git a/callbacks_spy_test.go b/callbacks_spy_test.go new file mode 100644 index 0000000..4025a96 --- /dev/null +++ b/callbacks_spy_test.go @@ -0,0 +1,109 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package spool + +import ( + "sync" +) + +// Ensure, that CallbacksSpy does implement Callbacks. +// If this is not the case, regenerate this file with moq. +var _ Callbacks = &CallbacksSpy{} + +// CallbacksSpy is a mock implementation of Callbacks. +// +// func TestSomethingThatUsesCallbacks(t *testing.T) { +// +// // make and configure a mocked Callbacks +// mockedCallbacks := &CallbacksSpy{ +// ReceivedFunc: func(fn func()) { +// panic("mock out the Received method") +// }, +// StoppedFunc: func() { +// panic("mock out the Stopped method") +// }, +// } +// +// // use mockedCallbacks in code that requires Callbacks +// // and then make assertions. +// +// } +type CallbacksSpy struct { + // ReceivedFunc mocks the Received method. + ReceivedFunc func(fn func()) + + // StoppedFunc mocks the Stopped method. + StoppedFunc func() + + // calls tracks calls to the methods. + calls struct { + // Received holds details about calls to the Received method. + Received []struct { + // Fn is the fn argument value. + Fn func() + } + // Stopped holds details about calls to the Stopped method. + Stopped []struct { + } + } + lockReceived sync.RWMutex + lockStopped sync.RWMutex +} + +// Received calls ReceivedFunc. +func (mock *CallbacksSpy) Received(fn func()) { + if mock.ReceivedFunc == nil { + panic("CallbacksSpy.ReceivedFunc: method is nil but Callbacks.Received was just called") + } + callInfo := struct { + Fn func() + }{ + Fn: fn, + } + mock.lockReceived.Lock() + mock.calls.Received = append(mock.calls.Received, callInfo) + mock.lockReceived.Unlock() + mock.ReceivedFunc(fn) +} + +// ReceivedCalls gets all the calls that were made to Received. +// Check the length with: +// len(mockedCallbacks.ReceivedCalls()) +func (mock *CallbacksSpy) ReceivedCalls() []struct { + Fn func() +} { + var calls []struct { + Fn func() + } + mock.lockReceived.RLock() + calls = mock.calls.Received + mock.lockReceived.RUnlock() + return calls +} + +// Stopped calls StoppedFunc. +func (mock *CallbacksSpy) Stopped() { + if mock.StoppedFunc == nil { + panic("CallbacksSpy.StoppedFunc: method is nil but Callbacks.Stopped was just called") + } + callInfo := struct { + }{} + mock.lockStopped.Lock() + mock.calls.Stopped = append(mock.calls.Stopped, callInfo) + mock.lockStopped.Unlock() + mock.StoppedFunc() +} + +// StoppedCalls gets all the calls that were made to Stopped. +// Check the length with: +// len(mockedCallbacks.StoppedCalls()) +func (mock *CallbacksSpy) StoppedCalls() []struct { +} { + var calls []struct { + } + mock.lockStopped.RLock() + calls = mock.calls.Stopped + mock.lockStopped.RUnlock() + return calls +} diff --git a/go.mod b/go.mod index 381f469..71c8be3 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,11 @@ module github.com/dc0d/spool -go 1.16 +go 1.17 require github.com/stretchr/testify v1.7.0 + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/worker-pool.go b/worker-pool.go index 1ded92b..dc1e693 100644 --- a/worker-pool.go +++ b/worker-pool.go @@ -3,25 +3,18 @@ package spool import ( + "context" "log" - "time" ) type WorkerPool chan func() -// New creates a new workerpool. If initialPoolSize is zero, no initial workers will be started. -// To have more workers, the Grow method should be used. -func New(mailboxSize MailboxSize, initialPoolSize int, opts ...GrowthOption) WorkerPool { +// New creates a new WorkerPool without any initial workers. To spawn workers, Grow must be called. +func New(mailboxSize MailboxSize) WorkerPool { if mailboxSize < 0 { mailboxSize = 0 } - - var pool WorkerPool = make(chan func(), mailboxSize) - if initialPoolSize > 0 { - pool.Grow(initialPoolSize, opts...) - } - - return pool + return make(chan func(), mailboxSize) } func (pool WorkerPool) Stop() { @@ -29,136 +22,186 @@ func (pool WorkerPool) Stop() { } // Blocking will panic, if the workerpool is stopped. -func (pool WorkerPool) Blocking(callback func()) { +func (pool WorkerPool) Blocking(ctx context.Context, callback func()) error { done := make(chan struct{}) - pool <- func() { defer close(done); callback() } + select { + case <-ctx.Done(): + return ctx.Err() + case pool <- func() { defer close(done); callback() }: + } <-done + return nil } // SemiBlocking sends the job to the worker in a non-blocking manner, as long as the mailbox is not full. // After that, it becomes blocking until there is an empty space in the mailbox. // If the workerpool is stopped, SemiBlocking will panic. -func (pool WorkerPool) SemiBlocking(callback func()) { - pool <- callback -} - -func (pool WorkerPool) Grow(growth int, opts ...GrowthOption) { - options := applyOptions(opts...) - - if growth <= 0 { - growth = 1 - } - - for i := 0; i < growth; i++ { - pool.start(options) +func (pool WorkerPool) SemiBlocking(ctx context.Context, callback func()) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pool <- callback: } + return nil } -func (pool WorkerPool) start(options growthOptions) { - go pool.worker(options) +func (pool WorkerPool) Grow(ctx context.Context, growth int, options ...Option) { + pool.grow(ctx, growth, nil, options...) } -func (pool WorkerPool) worker(options growthOptions) { - if workerStarted != nil { - workerStarted(pool) - } - if workerStopped != nil { - defer workerStopped(pool) - } - - var ( - absoluteTimeout = options.absoluteTimeout - idleTimeout = options.idleTimeout - stopSignal = options.stopSignal - ) - - var absoluteTimeoutSignal, idleTimeoutSignal <-chan time.Time - if absoluteTimeout > 0 { - absoluteTimeoutSignal = time.After(absoluteTimeout) - } - - var requestCount RequestCount - for { - if options.respawnAfter > 0 && options.respawnAfter <= requestCount { - pool.start(options) - return - } - - if idleTimeout > 0 { - idleTimeoutSignal = time.After(idleTimeout) - } - - select { - case <-absoluteTimeoutSignal: - return - case <-idleTimeoutSignal: - if options.respawnAfter > 0 { - pool.start(options) - } - return - case <-stopSignal: - return - case callback, ok := <-pool: - if !ok { - return - } - execCallback(callback) - requestCount++ +func (pool WorkerPool) grow(ctx context.Context, growth int, executorFactory func() Callbacks, options ...Option) { + var mailbox <-chan func() = pool + for i := 0; i < growth; i++ { + var exec Callbacks = defaultExecutor{} + if executorFactory != nil { + exec = executorFactory() } + Start(ctx, mailbox, exec, options...) } } -func execCallback(callback func()) { +type defaultExecutor struct{} + +func (obj defaultExecutor) Received(fn T) { defer func() { if e := recover(); e != nil { log.Println(e) // TODO: } }() - - callback() -} - -var ( - workerStarted func(pool WorkerPool) - workerStopped func(pool WorkerPool) -) - -// growth options - -func WithAbsoluteTimeout(timeout time.Duration) GrowthOption { - return func(opts growthOptions) growthOptions { opts.absoluteTimeout = timeout; return opts } -} - -func WithIdleTimeout(timeout time.Duration) GrowthOption { - return func(opts growthOptions) growthOptions { opts.idleTimeout = timeout; return opts } -} - -func WithStopSignal(stopSignal <-chan struct{}) GrowthOption { - return func(opts growthOptions) growthOptions { opts.stopSignal = stopSignal; return opts } -} - -func WithRespawnAfter(respawnAfter RequestCount) GrowthOption { - return func(opts growthOptions) growthOptions { opts.respawnAfter = respawnAfter; return opts } -} - -type GrowthOption func(growthOptions) growthOptions - -type growthOptions struct { - absoluteTimeout time.Duration - idleTimeout time.Duration - stopSignal <-chan struct{} - respawnAfter RequestCount + fn() } -func applyOptions(opts ...GrowthOption) growthOptions { - var options growthOptions - for _, fn := range opts { - options = fn(options) - } - return options -} - -type ( - RequestCount int - MailboxSize int -) +func (obj defaultExecutor) Stopped() {} + +// import ( +// "log" +// "time" +// ) + +// // New creates a new workerpool. If initialPoolSize is zero, no initial workers will be started. +// // To have more workers, the Grow method should be used. +// func New(mailboxSize MailboxSize, initialPoolSize int, opts ...GrowthOption) WorkerPool { +// if mailboxSize < 0 { +// mailboxSize = 0 +// } + +// var pool WorkerPool = make(chan func(), mailboxSize) +// if initialPoolSize > 0 { +// pool.Grow(initialPoolSize, opts...) +// } + +// return pool +// } + +// func (pool WorkerPool) Stop() { +// close(pool) +// } + +// func (pool WorkerPool) start(options growthOptions) { +// go pool.worker(options) +// } + +// func (pool WorkerPool) worker(options growthOptions) { +// if workerStarted != nil { +// workerStarted(pool) +// } +// if workerStopped != nil { +// defer workerStopped(pool) +// } + +// var ( +// absoluteTimeout = options.absoluteTimeout +// idleTimeout = options.idleTimeout +// stopSignal = options.stopSignal +// ) + +// var absoluteTimeoutSignal, idleTimeoutSignal <-chan time.Time +// if absoluteTimeout > 0 { +// absoluteTimeoutSignal = time.After(absoluteTimeout) +// } + +// var requestCount RequestCount +// for { +// if options.respawnAfter > 0 && options.respawnAfter <= requestCount { +// pool.start(options) +// return +// } + +// if idleTimeout > 0 { +// idleTimeoutSignal = time.After(idleTimeout) +// } + +// select { +// case <-absoluteTimeoutSignal: +// return +// case <-idleTimeoutSignal: +// if options.respawnAfter > 0 { +// pool.start(options) +// } +// return +// case <-stopSignal: +// return +// case callback, ok := <-pool: +// if !ok { +// return +// } +// execCallback(callback) +// requestCount++ +// } +// } +// } + +// func execCallback(callback func()) { +// defer func() { +// if e := recover(); e != nil { +// log.Println(e) // TODO: +// } +// }() + +// callback() +// } + +// var ( +// workerStarted func(pool WorkerPool) +// workerStopped func(pool WorkerPool) +// ) + +// // growth options + +// func WithAbsoluteTimeout(timeout time.Duration) GrowthOption { +// return func(opts growthOptions) growthOptions { opts.absoluteTimeout = timeout; return opts } +// } + +// func WithIdleTimeout(timeout time.Duration) GrowthOption { +// return func(opts growthOptions) growthOptions { opts.idleTimeout = timeout; return opts } +// } + +// func WithStopSignal(stopSignal <-chan struct{}) GrowthOption { +// return func(opts growthOptions) growthOptions { opts.stopSignal = stopSignal; return opts } +// } + +// func WithRespawnAfter(respawnAfter RequestCount) GrowthOption { +// return func(opts growthOptions) growthOptions { opts.respawnAfter = respawnAfter; return opts } +// } + +// type GrowthOption func(growthOptions) growthOptions + +// type growthOptions struct { +// absoluteTimeout time.Duration +// idleTimeout time.Duration +// stopSignal <-chan struct{} +// respawnAfter RequestCount +// } + +// func applyOptions(opts ...GrowthOption) growthOptions { +// var options growthOptions +// for _, fn := range opts { +// options = fn(options) +// } +// return options +// } + +// type ( +// RequestCount int +// MailboxSize int +// ) diff --git a/worker-pool_test.go b/worker-pool_test.go index bf44679..ef704c3 100644 --- a/worker-pool_test.go +++ b/worker-pool_test.go @@ -3,8 +3,8 @@ package spool import ( + "context" "fmt" - "os" "sync" "sync/atomic" "testing" @@ -13,42 +13,18 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMain(m *testing.M) { - workerStarted = func(pool WorkerPool) { incNumberOfWorkers(pool, 1) } - workerStopped = func(pool WorkerPool) { decNumberOfWorkers(pool, 1) } - - exitVal := m.Run() - - os.Exit(exitVal) -} - func Test_WorkerPool_New(t *testing.T) { t.Run(`should set default mailbox size to zero`, func(t *testing.T) { - pool := New(-1, 0) + pool := New(-1) defer pool.Stop() assert.True(t, len(pool) == 0) }) - t.Run(`should start n initial workers`, func(t *testing.T) { - n := 100 // initial workers - pool := New(-1, n) + t.Run(`should not start any initial workers`, func(t *testing.T) { + pool := New(-1) defer pool.Stop() - var count int64 - stop := make(chan struct{}) - defer close(stop) - for i := 0; i < n; i++ { - pool.SemiBlocking(func() { - atomic.AddInt64(&count, 1) - <-stop - }) - } - - assert.Eventually(t, func() bool { - return atomic.LoadInt64(&count) == int64(n) - }, time.Millisecond*300, time.Millisecond*20) - assert.Never(t, func() bool { select { case pool <- func() { panic("should not run") }: @@ -60,28 +36,49 @@ func Test_WorkerPool_New(t *testing.T) { }) } +func Test_WorkerPool_grow_should_spawn_workers_equal_to_growth(t *testing.T) { + var ( + ctx = context.Background() + growth = 100 + options []Option + ) + pool := New(-1) + exec := &CallbacksSpy{ + StoppedFunc: func() {}, + } + executorFactory := func() Callbacks { return exec } + + pool.grow(ctx, growth, executorFactory, options...) + pool.Stop() + + assert.Eventually(t, func() bool { + return len(exec.StoppedCalls()) == growth + }, time.Millisecond*300, time.Millisecond*20) +} + func Test_WorkerPool_Blocking_should_serialize_the_jobs(t *testing.T) { const n = 1000 - pool := New(10, 1) + pool := New(10) defer pool.Stop() - + pool.Grow(context.Background(), 1) var ( counter, previous int64 ) - wg := &sync.WaitGroup{} wg.Add(n) start := make(chan struct{}) for i := 0; i < n; i++ { - go pool.Blocking(func() { - defer wg.Done() - <-start - - previous = atomic.LoadInt64(&counter) - next := atomic.AddInt64(&counter, 1) - assert.Equal(t, previous+1, next) - }) + go func() { + _ = pool.Blocking(context.Background(), func() { + defer wg.Done() + <-start + + previous = atomic.LoadInt64(&counter) + next := atomic.AddInt64(&counter, 1) + assert.Equal(t, previous+1, next) + }) + }() } close(start) // signal all jobs they are green to go wg.Wait() @@ -89,17 +86,28 @@ func Test_WorkerPool_Blocking_should_serialize_the_jobs(t *testing.T) { assert.Equal(t, int64(n), counter) } -func Test_WorkerPool_Nonblocking_should_just_put_job_in_the_mailbox(t *testing.T) { - const n = 1000 - pool := New(n, 1) +func Test_WorkerPool_Blocking_should_respect_context_cancellation(t *testing.T) { + pool := New(-1) defer pool.Stop() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := pool.Blocking(ctx, func() { panic("should not be called") }) + assert.Equal(t, context.Canceled, err) +} + +func Test_WorkerPool_SemiBlocking_should_just_put_job_in_the_mailbox(t *testing.T) { + const n = 1000 + pool := New(n) + defer pool.Stop() + pool.Grow(context.Background(), 1) var counter int64 = 0 wg := &sync.WaitGroup{} wg.Add(n) for i := 0; i < n; i++ { - pool.SemiBlocking(func() { + _ = pool.SemiBlocking(context.Background(), func() { defer wg.Done() atomic.AddInt64(&counter, 1) }) @@ -109,237 +117,149 @@ func Test_WorkerPool_Nonblocking_should_just_put_job_in_the_mailbox(t *testing.T assert.Equal(t, int64(n), counter) } +func Test_WorkerPool_SemiBlocking_should_respect_context_cancellation(t *testing.T) { + pool := New(-1) + defer pool.Stop() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := pool.SemiBlocking(ctx, func() { panic("should not be called") }) + + assert.Equal(t, context.Canceled, err) +} + func Test_WorkerPool_should_not_stop_because_of_panic(t *testing.T) { - pool := New(1, 1) + pool := New(1) defer pool.Stop() + pool.Grow(context.Background(), 1) - pool.Blocking(func() { + _ = pool.Blocking(context.Background(), func() { panic("some error") }) counter := 0 - pool.Blocking(func() { + _ = pool.Blocking(context.Background(), func() { counter++ }) assert.Equal(t, 1, counter) } -// these tests are good enough for now - still the temporal dependency - -func Test_WorkerPool_Grow_should_spin_up_at_least_one_new_worker(t *testing.T) { - increased := 1 - pool := New(9, 1) - defer pool.Stop() - - negativeOrZero := 0 - pool.Grow(negativeOrZero) - - expectedNumberOfWorkers := increased /* the one extra worker */ + 1 /* the default worker */ - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) -} - -func Test_WorkerPool_Grow_should_spin_up_multiple_new_workers(t *testing.T) { - increased := 10 - pool := New(9, 1) - defer pool.Stop() - - pool.Grow(increased) - - expectedNumberOfWorkers := increased + 1 - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) -} - func Test_WorkerPool_Grow_should_stop_extra_workers_with_absolute_timeout(t *testing.T) { increased := 10 absoluteTimeout := time.Millisecond * 10 - pool := New(9, 1) + pool := New(9) defer pool.Stop() + exec := &CallbacksSpy{ + StoppedFunc: func() {}, + } + executorFactory := func() Callbacks { return exec } - pool.Grow(increased, WithAbsoluteTimeout(absoluteTimeout)) + pool.grow(context.Background(), increased, executorFactory, WithAbsoluteTimeout(absoluteTimeout)) - expectedNumberOfWorkers := 1 - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) + assert.Eventually(t, func() bool { + return len(exec.StoppedCalls()) == increased + }, time.Millisecond*500, time.Millisecond*50) } func Test_WorkerPool_Grow_should_stop_extra_workers_with_idle_timeout_when_there_are_no_more_jobs(t *testing.T) { const n = 1000 increased := 10 idleTimeout := time.Millisecond * 50 - pool := New(100, 1) + pool := New(100) defer pool.Stop() + exec := &CallbacksSpy{ + StoppedFunc: func() {}, + ReceivedFunc: func(fn func()) { fn() }, + } + executorFactory := func() Callbacks { return exec } start := make(chan struct{}, n) wg := &sync.WaitGroup{} wg.Add(n) for i := 0; i < n; i++ { - go pool.SemiBlocking(func() { - defer wg.Done() - <-start - }) + go func() { + _ = pool.SemiBlocking(context.Background(), func() { + defer wg.Done() + <-start + }) + }() } + pool.grow(context.Background(), increased, executorFactory, WithIdleTimeout(idleTimeout)) - pool.Grow(increased, WithIdleTimeout(idleTimeout)) - expectedNumberOfWorkers := 1 + increased + expectedNumberOfWorkers := increased assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) - - go func() { - for i := 0; i < n; i++ { - start <- struct{}{} - } - }() + return expectedNumberOfWorkers == len(exec.ReceivedCalls()) + }, time.Millisecond*1000, time.Millisecond*20, + "expected %v actual %v", expectedNumberOfWorkers, func() int { return len(exec.ReceivedCalls()) }()) + + close(start) wg.Wait() - expectedNumberOfWorkers = 1 - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) + expectedNumberOfStoppedWorkers := 10 + assert.Eventually(t, func() bool { + return expectedNumberOfStoppedWorkers == len(exec.StoppedCalls()) + }, time.Millisecond*5000, time.Millisecond*50) } -func Test_WorkerPool_Grow_should_stop_extra_workers_with_explicit_stop_signal(t *testing.T) { +func Test_WorkerPool_Grow_should_stop_extra_workers_when_context_is_canceled(t *testing.T) { increased := 10 - stopSignal := make(chan struct{}) - pool := New(9, 1) + pool := New(10) defer pool.Stop() - - pool.Grow(increased, WithStopSignal(stopSignal)) - - expectedNumberOfWorkers := 1 + increased - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) - - close(stopSignal) - expectedNumberOfWorkers = 1 - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) -} - -func Test_WorkerPool_Grow_should_respawn_after_a_certain_number_of_requests(t *testing.T) { - pool := New(9, 1, WithRespawnAfter(10)) - defer pool.Stop() - - expectedNumberOfStarts := 1 // one initial start - assert.Eventually(t, func() bool { - return expectedNumberOfStarts == getNumberOfStarts(pool) - }, time.Millisecond*500, time.Millisecond*50) - - for i := 0; i < 11; i++ { - pool.Blocking(func() {}) + exec := &CallbacksSpy{ + StoppedFunc: func() {}, + ReceivedFunc: func(fn func()) { fn() }, } + executorFactory := func() Callbacks { return exec } + pool.grow(context.Background(), 1, executorFactory) - expectedNumberOfStarts = 2 - assert.Eventually(t, func() bool { - return expectedNumberOfStarts == getNumberOfStarts(pool) - }, time.Millisecond*500, time.Millisecond*50) -} + ctx, cancel := context.WithCancel(context.Background()) + pool.grow(ctx, increased, executorFactory) -func Test_WorkerPool_Grow_should_respawn_after_a_certain_timespan_if_reapawnAfter_is_provided(t *testing.T) { - pool := New(9, 1, WithRespawnAfter(1000), WithIdleTimeout(time.Millisecond*50)) - defer pool.Stop() + cancel() - time.Sleep(time.Millisecond * 190) - expectedNumberOfStarts := 4 - assert.Equal(t, expectedNumberOfStarts, getNumberOfStarts(pool)) + expectedNumberOfStoppedWorkers := increased + assert.Eventually(t, func() bool { + return expectedNumberOfStoppedWorkers == len(exec.StoppedCalls()) + }, time.Millisecond*5000, time.Millisecond*50) } -// - func Test_WorkerPool_Stop_should_close_the_pool(t *testing.T) { - pool := New(9, 1) + pool := New(9) pool.Stop() assert.Panics(t, func() { - pool.SemiBlocking(func() {}) + _ = pool.SemiBlocking(context.Background(), func() {}) }) } func Test_WorkerPool_Stop_should_stop_the_workers(t *testing.T) { - pool := New(9, 1) - + pool := New(9) increased := 10 - pool.Grow(increased) - - expectedNumberOfWorkers := 1 + increased - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) + exec := &CallbacksSpy{ + StoppedFunc: func() {}, + ReceivedFunc: func(fn func()) { fn() }, + } + executorFactory := func() Callbacks { return exec } + pool.grow(context.Background(), increased, executorFactory) pool.Stop() - assert.Panics(t, func() { - pool.SemiBlocking(func() {}) - }) - - expectedNumberOfWorkers = 0 - assert.Eventuallyf(t, func() bool { - return expectedNumberOfWorkers == getNumberOfWorkers(pool) - }, time.Millisecond*500, time.Millisecond*50, - "expectedNumberOfWorkers: %v, actual: %v", expectedNumberOfWorkers, getNumberOfWorkers(pool)) -} - -// - -func getNumberOfWorkers(pool WorkerPool) int { - accessWorkerPoolState.RLock() - defer accessWorkerPoolState.RUnlock() - - return workerpoolStateWorkerCount[pool] -} - -func getNumberOfStarts(pool WorkerPool) int { - accessWorkerPoolState.RLock() - defer accessWorkerPoolState.RUnlock() - - return workerpoolStateWorkerStartCount[pool] -} - -func incNumberOfWorkers(pool WorkerPool, count int) { - accessWorkerPoolState.Lock() - defer accessWorkerPoolState.Unlock() - - workerpoolStateWorkerCount[pool] += count - workerpoolStateWorkerStartCount[pool] += count -} - -func decNumberOfWorkers(pool WorkerPool, count int) { - accessWorkerPoolState.Lock() - defer accessWorkerPoolState.Unlock() - - workerpoolStateWorkerCount[pool] -= count + expectedNumberOfStoppedWorkers := increased + assert.Eventually(t, func() bool { + return expectedNumberOfStoppedWorkers == len(exec.StoppedCalls()) + }, time.Millisecond*5000, time.Millisecond*50) } -var ( - workerpoolStateWorkerCount = make(map[WorkerPool]int) - workerpoolStateWorkerStartCount = make(map[WorkerPool]int) - accessWorkerPoolState = &sync.RWMutex{} -) - func ExampleWorkerPool_Blocking() { - pool := New(1, 1) + pool := New(1) defer pool.Stop() + pool.Grow(context.Background(), 1) var state int64 job := func() { atomic.AddInt64(&state, 19) } - pool.Blocking(job) + _ = pool.Blocking(context.Background(), job) fmt.Println(atomic.LoadInt64(&state)) @@ -348,8 +268,9 @@ func ExampleWorkerPool_Blocking() { } func ExampleWorkerPool_SemiBlocking() { - pool := New(1, 1) + pool := New(1) defer pool.Stop() + pool.Grow(context.Background(), 1) var state int64 jobDone := make(chan struct{}) @@ -358,7 +279,7 @@ func ExampleWorkerPool_SemiBlocking() { atomic.AddInt64(&state, 19) } - pool.SemiBlocking(job) + _ = pool.SemiBlocking(context.Background(), job) <-jobDone fmt.Println(state) @@ -369,16 +290,16 @@ func ExampleWorkerPool_SemiBlocking() { func ExampleWorkerPool_Grow() { const n = 19 - pool := New(10, 1) + pool := New(1) defer pool.Stop() - pool.Grow(3) // spin up three new workers + pool.Grow(context.Background(), 3) // spin up three new workers var state int64 wg := &sync.WaitGroup{} wg.Add(n) for i := 0; i < n; i++ { - pool.SemiBlocking(func() { defer wg.Done(); atomic.AddInt64(&state, 1) }) + _ = pool.SemiBlocking(context.Background(), func() { defer wg.Done(); atomic.AddInt64(&state, 1) }) } wg.Wait()