diff --git a/pkg/rungroup/rungroup.go b/pkg/rungroup/rungroup.go index c618928df..b1cd065a2 100644 --- a/pkg/rungroup/rungroup.go +++ b/pkg/rungroup/rungroup.go @@ -6,10 +6,13 @@ package rungroup // timeout. See: https://github.com/kolide/launcher/issues/1205 import ( + "context" "fmt" + "time" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" + "golang.org/x/sync/semaphore" ) type ( @@ -30,6 +33,11 @@ type ( } ) +const ( + interruptTimeout = 5 * time.Second // How long for all actors to return from their `interrupt` function + executeReturnTimeout = 5 * time.Second // After interrupted, how long for all actors to exit their `execute` functions +) + func NewRunGroup(logger log.Logger) *Group { return &Group{ logger: log.With(logger, "component", "run_group"), @@ -66,15 +74,39 @@ func (g *Group) Run() error { defer level.Debug(g.logger).Log("msg", "done shutting down actors", "actor_count", len(g.actors), "initial_err", initialActorErr) // Signal all actors to stop. + numActors := int64(len(g.actors)) + interruptWait := semaphore.NewWeighted(numActors) for _, a := range g.actors { - level.Debug(g.logger).Log("msg", "interrupting actor", "actor", a.name) - a.interrupt(initialActorErr.err) + interruptWait.Acquire(context.Background(), 1) + go func(a rungroupActor) { + defer interruptWait.Release(1) + level.Debug(g.logger).Log("msg", "interrupting actor", "actor", a.name) + a.interrupt(initialActorErr.err) + level.Debug(g.logger).Log("msg", "interrupt complete", "actor", a.name) + }(a) } - // Wait for all actors to stop. + interruptCtx, interruptCancel := context.WithTimeout(context.Background(), interruptTimeout) + defer interruptCancel() + + // Wait for interrupts to complete, but only until we hit our interruptCtx timeout + if err := interruptWait.Acquire(interruptCtx, numActors); err != nil { + level.Debug(g.logger).Log("msg", "timeout waiting for interrupts to complete, proceeding with shutdown", "err", err) + } + + // Wait for all other actors to stop, but only until we hit our executeReturnTimeout + timeoutTimer := time.NewTimer(executeReturnTimeout) + defer timeoutTimer.Stop() for i := 1; i < cap(errors); i++ { - e := <-errors - level.Debug(g.logger).Log("msg", "successfully interrupted actor", "actor", e.errorSourceName, "index", i) + select { + case <-timeoutTimer.C: + level.Debug(g.logger).Log("msg", "rungroup shutdown deadline exceeded, not waiting for any more actors to return") + + // Return the original error so we can proceed with shutdown + return initialActorErr.err + case e := <-errors: + level.Debug(g.logger).Log("msg", "execute returned", "actor", e.errorSourceName, "index", i) + } } // Return the original error. diff --git a/pkg/rungroup/rungroup_test.go b/pkg/rungroup/rungroup_test.go index 06eeb0a36..ae7b45935 100644 --- a/pkg/rungroup/rungroup_test.go +++ b/pkg/rungroup/rungroup_test.go @@ -21,7 +21,7 @@ func TestRun_MultipleActors(t *testing.T) { testRunGroup := NewRunGroup(log.NewNopLogger()) - groupReceivedInterrupts := make(chan struct{}) + groupReceivedInterrupts := make(chan struct{}, 3) // First actor waits for interrupt and alerts groupReceivedInterrupts when it's interrupted firstActorInterrupt := make(chan struct{}) @@ -52,23 +52,189 @@ func TestRun_MultipleActors(t *testing.T) { anotherActorInterrupt <- struct{}{} }) + runCompleted := make(chan struct{}) go func() { err := testRunGroup.Run() + runCompleted <- struct{}{} require.Error(t, err) }() + // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer + runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + interruptCheckTimer := time.NewTicker(runDuration) + defer interruptCheckTimer.Stop() + + receivedInterrupts := 0 + gotRunCompleted := false + for { + if gotRunCompleted { + break + } + + select { + case <-groupReceivedInterrupts: + receivedInterrupts += 1 + case <-runCompleted: + gotRunCompleted = true + case <-interruptCheckTimer.C: + t.Errorf("did not receive expected interrupts within reasonable time, got %d", receivedInterrupts) + t.FailNow() + } + } + + require.True(t, gotRunCompleted, "rungroup.Run did not terminate within time limit") + + require.Equal(t, 3, receivedInterrupts) +} + +func TestRun_MultipleActors_InterruptTimeout(t *testing.T) { + t.Parallel() + + testRunGroup := NewRunGroup(log.NewNopLogger()) + + groupReceivedInterrupts := make(chan struct{}, 3) + + // First actor waits for interrupt and alerts groupReceivedInterrupts when it's interrupted + firstActorInterrupt := make(chan struct{}) + testRunGroup.Add("firstActor", func() error { + <-firstActorInterrupt + return nil + }, func(error) { + groupReceivedInterrupts <- struct{}{} + firstActorInterrupt <- struct{}{} + }) + + // Second actor returns error on `execute`, and then alerts groupReceivedInterrupts when it's interrupted + expectedError := errors.New("test error from interruptingActor") + testRunGroup.Add("interruptingActor", func() error { + time.Sleep(1 * time.Second) + return expectedError + }, func(error) { + groupReceivedInterrupts <- struct{}{} + }) + + // Third actor blocks in interrupt for longer than the interrupt timeout + blockingActorInterrupt := make(chan struct{}) + testRunGroup.Add("blockingActor", func() error { + <-blockingActorInterrupt + return nil + }, func(error) { + time.Sleep(4 * interruptTimeout) + groupReceivedInterrupts <- struct{}{} + blockingActorInterrupt <- struct{}{} + }) + + runCompleted := make(chan struct{}) + go func() { + err := testRunGroup.Run() + require.Error(t, err) + runCompleted <- struct{}{} + }() + + // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer + runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + interruptCheckTimer := time.NewTicker(runDuration) + defer interruptCheckTimer.Stop() + receivedInterrupts := 0 + gotRunCompleted := false for { - if receivedInterrupts >= 3 { + if gotRunCompleted { break } + + select { + case <-groupReceivedInterrupts: + receivedInterrupts += 1 + case <-runCompleted: + gotRunCompleted = true + case <-interruptCheckTimer.C: + t.Errorf("did not receive expected interrupts within reasonable time, got %d", receivedInterrupts) + t.FailNow() + } + } + + require.True(t, gotRunCompleted, "rungroup.Run did not terminate within time limit") + + // We only want two interrupts -- we should not be waiting on the blocking actor + require.Equal(t, 2, receivedInterrupts) +} + +func TestRun_MultipleActors_ExecuteReturnTimeout(t *testing.T) { + t.Parallel() + + testRunGroup := NewRunGroup(log.NewNopLogger()) + + groupReceivedInterrupts := make(chan struct{}, 3) + // Keep track of when `execute`s return so we give testRunGroup.Run enough time to do its thing + groupReceivedExecuteReturns := make(chan struct{}, 2) + + // First actor waits for interrupt and alerts groupReceivedInterrupts when it's interrupted + firstActorInterrupt := make(chan struct{}) + testRunGroup.Add("firstActor", func() error { + <-firstActorInterrupt + groupReceivedExecuteReturns <- struct{}{} + return nil + }, func(error) { + groupReceivedInterrupts <- struct{}{} + firstActorInterrupt <- struct{}{} + }) + + // Second actor returns error on `execute`, and then alerts groupReceivedInterrupts when it's interrupted + expectedError := errors.New("test error from interruptingActor") + testRunGroup.Add("interruptingActor", func() error { + time.Sleep(1 * time.Second) + groupReceivedExecuteReturns <- struct{}{} + return expectedError + }, func(error) { + groupReceivedInterrupts <- struct{}{} + }) + + // Third actor never signals to `execute` to return + blockingActorInterrupt := make(chan struct{}) + testRunGroup.Add("blockingActor", func() error { + <-blockingActorInterrupt // will never happen + groupReceivedExecuteReturns <- struct{}{} // will never happen + return nil + }, func(error) { + groupReceivedInterrupts <- struct{}{} + }) + + runCompleted := make(chan struct{}) + go func() { + err := testRunGroup.Run() + runCompleted <- struct{}{} + require.Error(t, err) + }() + + // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer + runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + interruptCheckTimer := time.NewTicker(runDuration) + defer interruptCheckTimer.Stop() + + // Make sure all three actors are interrupted, and that two of them terminate their execute + receivedInterrupts := 0 + receivedExecuteReturns := 0 + gotRunCompleted := false + for { + if gotRunCompleted { + break + } + select { case <-groupReceivedInterrupts: receivedInterrupts += 1 - case <-time.After(3 * time.Second): - t.Error("did not receive expected interrupts within reasonable time") + case <-groupReceivedExecuteReturns: + receivedExecuteReturns += 1 + case <-runCompleted: + gotRunCompleted = true + case <-interruptCheckTimer.C: + t.Errorf("did not receive expected interrupts within reasonable time, got %d", receivedInterrupts) + t.FailNow() } } + require.True(t, gotRunCompleted, "rungroup.Run did not terminate within time limit") require.Equal(t, 3, receivedInterrupts) + require.Equal(t, 2, receivedExecuteReturns) }