Clean up the semantics of ACS Handler #3225

Merged (1 commit) on Mar 16, 2023
108 changes: 42 additions & 66 deletions agent/acs/handler/acs_handler.go
@@ -203,70 +203,54 @@ func NewSession(
// Start starts the session. It'll forever keep trying to connect to ACS unless
// the context is cancelled.
//
// If the context is cancelled, Start() returns.
// If the instance is deregistered, Start() emits an event to the
// deregister-instance event stream and sets the connection backoff time to 1 hour.
// Always returns nil. TODO: consider removing the error return value completely.
func (acsSession *session) Start() error {
// connectToACS channel is used to indicate the intent to connect to ACS
// It's processed by the select loop to connect to ACS
connectToACS := make(chan struct{}, 1)
// This is required to trigger the first connection to ACS. Subsequent
// connections are triggered by the handleACSError() method
connectToACS <- struct{}{}
// Loop continuously until context is closed/cancelled
for {
select {
case <-connectToACS:
seelog.Debugf("Received connect to ACS message")
// Start a session with ACS
acsError := acsSession.startSessionOnce()
select {
case <-acsSession.ctx.Done():
// agent is shutting down, exiting cleanly
return nil
default:
}
// Session with ACS was stopped with some error, start processing the error
isInactiveInstance := isInactiveInstanceError(acsError)
if isInactiveInstance {
// If the instance was deregistered, send an event to the event stream
// for the same
seelog.Debug("Container instance is deregistered, notifying listeners")
err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{})
if err != nil {
seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err)
}
}
if shouldReconnectWithoutBackoff(acsError) {
// If ACS or agent closed the connection, there's no need to backoff,
// reconnect immediately
seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError)
acsSession.backoff.Reset()
sendEmptyMessageOnChannel(connectToACS)
} else {
// Disconnected unexpectedly from ACS, compute backoff duration to
// reconnect
reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance)
seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String())
waitComplete := acsSession.waitForDuration(reconnectDelay)
if waitComplete {
// If the context was not cancelled and we've waited for the
// wait duration without any errors, send the message to the channel
// to reconnect to ACS
seelog.Info("Done waiting; reconnecting to ACS")
sendEmptyMessageOnChannel(connectToACS)
} else {
// Wait was interrupted. We expect the session to close as canceling
// the session context is the only way to end up here. Print a message
// to indicate the same
seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close")
}
seelog.Debugf("Attempting connect to ACS")
// Start a session with ACS
acsError := acsSession.startSessionOnce()

// If the session is over check for shutdown first
if err := acsSession.ctx.Err(); err != nil {
return nil
}

// If ACS closed the connection, reconnect immediately
if shouldReconnectWithoutBackoff(acsError) {
seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError)
acsSession.backoff.Reset()
continue
}

// Session with ACS was stopped with some error, start processing the error
isInactiveInstance := isInactiveInstanceError(acsError)
if isInactiveInstance {
// If the instance was deregistered, send an event to the event stream
// for the same
seelog.Debug("Container instance is deregistered, notifying listeners")
err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{})
if err != nil {
seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err)
}
case <-acsSession.ctx.Done():
// agent is shutting down, exiting cleanly
}

// Disconnected unexpectedly from ACS, compute backoff duration to
// reconnect
reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance)
seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String())
waitComplete := acsSession.waitForDuration(reconnectDelay)
if !waitComplete {
// Wait was interrupted. We expect the session to close as canceling
// the session context is the only way to end up here. Print a message
// to indicate the same
seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close")
return nil
}

// If the context was not cancelled and we've waited for the
// wait duration without any errors, reconnect to ACS
seelog.Info("Done waiting; reconnecting to ACS")
}
}

@@ -592,11 +576,3 @@ func shouldReconnectWithoutBackoff(acsError error) bool {
func isInactiveInstanceError(acsError error) bool {
return acsError != nil && strings.HasPrefix(acsError.Error(), inactiveInstanceExceptionPrefix)
}

// sendEmptyMessageOnChannel sends an empty message using a go-routine on the
// specified channel
func sendEmptyMessageOnChannel(channel chan<- struct{}) {
go func() {
channel <- struct{}{}
}()
}
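
For readers skimming the diff, the control flow introduced above (a plain for loop with an early context check, continue for clean closes, and an interruptible backoff wait, replacing the old channel-fed select) can be illustrated with a standalone sketch. This is not the agent's code: connectOnce and shouldRetryImmediately are stand-ins for startSessionOnce and shouldReconnectWithoutBackoff, and the fixed delay stands in for the exponential backoff.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// connectOnce stands in for startSessionOnce: run one session and return
// the error that ended it.
func connectOnce() error { return errors.New("connection closed unexpectedly") }

// shouldRetryImmediately stands in for shouldReconnectWithoutBackoff.
func shouldRetryImmediately(err error) bool { return false }

// run mirrors the reworked Start() loop: always return nil once the context
// is done, skip the backoff for clean closes, and otherwise wait out a
// reconnect delay that the context can interrupt.
func run(ctx context.Context) error {
	reconnectDelay := 100 * time.Millisecond
	for {
		err := connectOnce()

		// If the session is over, check for shutdown first.
		if ctx.Err() != nil {
			return nil
		}

		// Clean close: reconnect immediately without backoff.
		if shouldRetryImmediately(err) {
			continue
		}

		// Unexpected disconnect: wait before reconnecting, but give up
		// the wait if the context is cancelled in the meantime.
		fmt.Printf("reconnecting in %s after: %v\n", reconnectDelay, err)
		select {
		case <-time.After(reconnectDelay):
		case <-ctx.Done():
			return nil
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	_ = run(ctx)
}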
111 changes: 103 additions & 8 deletions agent/acs/handler/acs_handler_test.go
@@ -281,9 +281,6 @@ func TestComputeReconnectDelayForActiveInstance(t *testing.T) {
// waitForDurationOrCancelledSession method behaves correctly when the session context
// is not cancelled
func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

@@ -300,9 +297,6 @@ func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) {
// waitForDurationOrCancelledSession method behaves correctly when the session context
// is cancelled
func TestWaitForDurationReturnsFalseWhenContextCancelled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
acsSession := session{
ctx: ctx,
@@ -532,7 +526,8 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) {
taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx)
deregisterInstanceEventStream.StartListening()
// Don't start to ensure an error doesn't affect reconnect
// deregisterInstanceEventStream.StartListening()

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
@@ -695,7 +690,107 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) {
go func() {
sessionError <- acsSession.Start()
}()
<-sessionError
response := <-sessionError
assert.Nil(t, response)
}

// TestHandlerStopsWhenContextIsError tests if the session's Start() method returns
// when the session context is in an error state
func TestHandlerStopsWhenContextIsError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
taskEngine := mock_engine.NewMockTaskEngine(ctrl)
taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

ecsClient := mock_api.NewMockECSClient(ctrl)
ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()

ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
mockWsClient.EXPECT().Serve().Do(func() {
time.Sleep(5 * time.Millisecond)
}).Return(io.EOF).AnyTimes()

acsSession := session{
containerInstanceARN: "myArn",
credentialsProvider: testCreds,
agentConfig: testConfig,
taskEngine: taskEngine,
ecsClient: ecsClient,
dataClient: data.NewNoopClient(),
taskHandler: taskHandler,
backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
ctx: ctx,
cancel: cancel,
resources: &mockSessionResources{mockWsClient},
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
}

// The session error channel receives an event when the Start() method returns;
// the context timing out should trigger this
sessionError := make(chan error)
go func() {
sessionError <- acsSession.Start()
}()
response := <-sessionError
assert.Nil(t, response)
}

// TestHandlerStopsWhenContextIsErrorReconnectDelay tests if the session's Start() method returns
// when the session context errors while waiting out the inactive-instance reconnect delay
func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
taskEngine := mock_engine.NewMockTaskEngine(ctrl)
taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

ecsClient := mock_api.NewMockECSClient(ctrl)
ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()

ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
mockWsClient.EXPECT().Serve().Return(errors.New("InactiveInstanceException")).AnyTimes()

acsSession := session{
containerInstanceARN: "myArn",
credentialsProvider: testCreds,
agentConfig: testConfig,
taskEngine: taskEngine,
ecsClient: ecsClient,
dataClient: data.NewNoopClient(),
taskHandler: taskHandler,
backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
ctx: ctx,
cancel: cancel,
resources: &mockSessionResources{mockWsClient},
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
_inactiveInstanceReconnectDelay: 1 * time.Hour,
}

// The session error channel receives an event when the Start() method returns;
// the context timing out should trigger this
sessionError := make(chan error)
go func() {
sessionError <- acsSession.Start()
}()
response := <-sessionError
assert.Nil(t, response)
}

// TestHandlerReconnectsOnDiscoverPollEndpointError tests if handler retries