From 2c63238bc5f6b9e10b537921dfc9749cee2b8681 Mon Sep 17 00:00:00 2001 From: Kevin Gibbs Date: Tue, 24 May 2022 20:29:35 +0000 Subject: [PATCH] Cleanup the symantics of ACS Handler The ACS handler was using a channel to support looping, but the channel was only sent to from the current thread via a go routine that would just write to the channel. The rewrite maintains the behavior of always connecting when disconnected potentially after a retry backoff. However it uses a simple endless for loop rather than reading from channels to mimic this straightforward behavior. The rewrite also adheres to the interface contract on returning the error from the context if it is Done for any reason other than cancelled. --- agent/acs/handler/acs_handler.go | 108 ++++++++++--------------- agent/acs/handler/acs_handler_test.go | 111 ++++++++++++++++++++++++-- 2 files changed, 145 insertions(+), 74 deletions(-) diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index 0c090474c09..aa36795477e 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -203,70 +203,54 @@ func NewSession( // Start starts the session. It'll forever keep trying to connect to ACS unless // the context is cancelled. // -// If the context is cancelled, Start() would return with the error code returned -// by the context. -// If the instance is deregistered, Start() would emit an event to the -// deregister-instance event stream and sets the connection backoff time to 1 hour. +// Returns nil always TODO: consider removing error return value completely func (acsSession *session) Start() error { - // connectToACS channel is used to indicate the intent to connect to ACS - // It's processed by the select loop to connect to ACS - connectToACS := make(chan struct{}, 1) - // This is required to trigger the first connection to ACS. Subsequent - // connections are triggered by the handleACSError() method - connectToACS <- struct{}{} + // Loop continuously until context is closed/cancelled for { - select { - case <-connectToACS: - seelog.Debugf("Received connect to ACS message") - // Start a session with ACS - acsError := acsSession.startSessionOnce() - select { - case <-acsSession.ctx.Done(): - // agent is shutting down, exiting cleanly - return nil - default: - } - // Session with ACS was stopped with some error, start processing the error - isInactiveInstance := isInactiveInstanceError(acsError) - if isInactiveInstance { - // If the instance was deregistered, send an event to the event stream - // for the same - seelog.Debug("Container instance is deregistered, notifying listeners") - err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{}) - if err != nil { - seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err) - } - } - if shouldReconnectWithoutBackoff(acsError) { - // If ACS or agent closed the connection, there's no need to backoff, - // reconnect immediately - seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError) - acsSession.backoff.Reset() - sendEmptyMessageOnChannel(connectToACS) - } else { - // Disconnected unexpectedly from ACS, compute backoff duration to - // reconnect - reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance) - seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String()) - waitComplete := acsSession.waitForDuration(reconnectDelay) - if waitComplete { - // If the context was not cancelled and we've waited for the - // wait duration without any errors, send the message to the channel - // to reconnect to ACS - seelog.Info("Done waiting; reconnecting to ACS") - sendEmptyMessageOnChannel(connectToACS) - } else { - // Wait was interrupted. We expect the session to close as canceling - // the session context is the only way to end up here. Print a message - // to indicate the same - seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close") - } + seelog.Debugf("Attempting connect to ACS") + // Start a session with ACS + acsError := acsSession.startSessionOnce() + + // If the session is over check for shutdown first + if err := acsSession.ctx.Err(); err != nil { + return nil + } + + // If ACS closed the connection, reconnect immediately + if shouldReconnectWithoutBackoff(acsError) { + seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError) + acsSession.backoff.Reset() + continue + } + + // Session with ACS was stopped with some error, start processing the error + isInactiveInstance := isInactiveInstanceError(acsError) + if isInactiveInstance { + // If the instance was deregistered, send an event to the event stream + // for the same + seelog.Debug("Container instance is deregistered, notifying listeners") + err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{}) + if err != nil { + seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err) } - case <-acsSession.ctx.Done(): - // agent is shutting down, exiting cleanly + } + + // Disconnected unexpectedly from ACS, compute backoff duration to + // reconnect + reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance) + seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String()) + waitComplete := acsSession.waitForDuration(reconnectDelay) + if !waitComplete { + // Wait was interrupted. We expect the session to close as canceling + // the session context is the only way to end up here. Print a message + // to indicate the same + seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close") return nil } + // If the context was not cancelled and we've waited for the + // wait duration without any errors, reconnect to ACS + seelog.Info("Done waiting; reconnecting to ACS") } } @@ -592,11 +576,3 @@ func shouldReconnectWithoutBackoff(acsError error) bool { func isInactiveInstanceError(acsError error) bool { return acsError != nil && strings.HasPrefix(acsError.Error(), inactiveInstanceExceptionPrefix) } - -// sendEmptyMessageOnChannel sends an empty message using a go-routine on the -// specified channel -func sendEmptyMessageOnChannel(channel chan<- struct{}) { - go func() { - channel <- struct{}{} - }() -} diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 6a1ceab6f8a..259591353ff 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -281,9 +281,6 @@ func TestComputeReconnectDelayForActiveInstance(t *testing.T) { // waitForDurationOrCancelledSession method behaves correctly when the session context // is not cancelled func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -300,9 +297,6 @@ func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) { // waitForDurationOrCancelledSession method behaves correctly when the session contexnt // is cancelled func TestWaitForDurationReturnsFalseWhenContextCancelled(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - ctx, cancel := context.WithCancel(context.Background()) acsSession := session{ ctx: ctx, @@ -532,7 +526,8 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - deregisterInstanceEventStream.StartListening() + // Don't start to ensure an error doesn't affect reconnect + // deregisterInstanceEventStream.StartListening() mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() @@ -695,7 +690,107 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { go func() { sessionError <- acsSession.Start() }() - <-sessionError + response := <-sessionError + assert.Nil(t, response) +} + +// TestHandlerStopsWhenContextIsError tests if the session's Start() method returns +// when session context is in error +func TestHandlerStopsWhenContextIsError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskEngine := mock_engine.NewMockTaskEngine(ctrl) + taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() + + ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) + taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve().Do(func() { + time.Sleep(5 * time.Millisecond) + }).Return(io.EOF).AnyTimes() + + acsSession := session{ + containerInstanceARN: "myArn", + credentialsProvider: testCreds, + agentConfig: testConfig, + taskEngine: taskEngine, + ecsClient: ecsClient, + dataClient: data.NewNoopClient(), + taskHandler: taskHandler, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), + ctx: ctx, + cancel: cancel, + resources: &mockSessionResources{mockWsClient}, + _heartbeatTimeout: 20 * time.Millisecond, + _heartbeatJitter: 10 * time.Millisecond, + } + + // The session error channel would have an event when the Start() method returns + // Cancelling the context should trigger this + sessionError := make(chan error) + go func() { + sessionError <- acsSession.Start() + }() + response := <-sessionError + assert.Nil(t, response) +} + +// TestHandlerStopsWhenContextIsErrorReconnectDelay tests if the session's Start() method returns +// when session context is in error +func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskEngine := mock_engine.NewMockTaskEngine(ctrl) + taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() + + ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) + taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve().Return(errors.New("InactiveInstanceException")).AnyTimes() + + acsSession := session{ + containerInstanceARN: "myArn", + credentialsProvider: testCreds, + agentConfig: testConfig, + taskEngine: taskEngine, + ecsClient: ecsClient, + dataClient: data.NewNoopClient(), + taskHandler: taskHandler, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), + ctx: ctx, + cancel: cancel, + resources: &mockSessionResources{mockWsClient}, + _heartbeatTimeout: 20 * time.Millisecond, + _heartbeatJitter: 10 * time.Millisecond, + _inactiveInstanceReconnectDelay: 1 * time.Hour, + } + + // The session error channel would have an event when the Start() method returns + // Cancelling the context should trigger this + sessionError := make(chan error) + go func() { + sessionError <- acsSession.Start() + }() + response := <-sessionError + assert.Nil(t, response) } // TestHandlerReconnectsOnDiscoverPollEndpointError tests if handler retries