Skip to content

Commit

Permalink
Clean up the semantics of the ACS handler
Browse files Browse the repository at this point in the history
The ACS handler was using a channel to support looping, but the
channel was only ever sent to by the handler itself, via a goroutine
whose sole job was to write to the channel.

The rewrite maintains the behavior of always reconnecting when
disconnected, potentially after a retry backoff. However, it uses
a simple endless for loop rather than reading from channels to
implement this straightforward behavior.

The rewrite also adheres to the interface contract on returning
the error from the context if it is Done for any reason other
than cancelled.
  • Loading branch information
aws-gibbskt committed Mar 7, 2023
1 parent 5cf4103 commit 871e3c7
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 74 deletions.
108 changes: 42 additions & 66 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,70 +203,54 @@ func NewSession(
// Start starts the session. It'll forever keep trying to connect to ACS unless
// the context is cancelled.
//
// If the context is cancelled, Start() would return with the error code returned
// by the context.
// If the instance is deregistered, Start() would emit an event to the
// deregister-instance event stream and sets the connection backoff time to 1 hour.
// Returns nil always TODO: consider removing error return value completely
func (acsSession *session) Start() error {
// NOTE(review): this listing interleaves the lines removed and the lines
// added by the commit (the channel-driven select loop and the plain for
// loop), so it does not compile as shown; read each variant separately.
//
// connectToACS signals the intent to connect to ACS. The select loop below
// receives from it to start each connection attempt.
connectToACS := make(chan struct{}, 1)
// Trigger the first connection to ACS. The channel has a buffer of one, so
// this send does not block. Subsequent connections are triggered by the
// sendEmptyMessageOnChannel() calls further down.
connectToACS <- struct{}{}
// Loop continuously until the context is closed/cancelled
for {
select {
case <-connectToACS:
seelog.Debugf("Received connect to ACS message")
// Start a session with ACS
acsError := acsSession.startSessionOnce()
select {
case <-acsSession.ctx.Done():
// The agent is shutting down; exit cleanly
return nil
default:
}
// The session with ACS stopped with some error; start processing the error
isInactiveInstance := isInactiveInstanceError(acsError)
if isInactiveInstance {
// If the instance was deregistered, send an event to the
// deregister-instance event stream to notify listeners
seelog.Debug("Container instance is deregistered, notifying listeners")
err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{})
if err != nil {
seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err)
}
}
if shouldReconnectWithoutBackoff(acsError) {
// If ACS or the agent closed the connection, there's no need to back off;
// reconnect immediately
seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError)
acsSession.backoff.Reset()
sendEmptyMessageOnChannel(connectToACS)
} else {
// Disconnected unexpectedly from ACS; compute the backoff duration
// before reconnecting
reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance)
seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String())
waitComplete := acsSession.waitForDuration(reconnectDelay)
if waitComplete {
// The context was not cancelled and the full wait duration has
// elapsed, so send the message on the channel to reconnect
// to ACS
seelog.Info("Done waiting; reconnecting to ACS")
sendEmptyMessageOnChannel(connectToACS)
} else {
// The wait was interrupted. We expect the session to close, as cancelling
// the session context is the only way to end up here. Log a message
// to indicate the same
seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close")
}
seelog.Debugf("Attempting connect to ACS")
// Start a session with ACS
acsError := acsSession.startSessionOnce()

// If the session is over, check for shutdown first. ctx.Err() is non-nil
// once the context is done for any reason (cancellation or deadline), and
// nil is returned in either case.
if err := acsSession.ctx.Err(); err != nil {
return nil
}

// If ACS closed the connection, reset the backoff and reconnect immediately
if shouldReconnectWithoutBackoff(acsError) {
seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError)
acsSession.backoff.Reset()
continue
}

// The session with ACS stopped with some error; start processing the error
isInactiveInstance := isInactiveInstanceError(acsError)
if isInactiveInstance {
// If the instance was deregistered, send an event to the
// deregister-instance event stream to notify listeners. A failed write
// is only logged and does not stop the reconnect.
seelog.Debug("Container instance is deregistered, notifying listeners")
err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{})
if err != nil {
seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err)
}
case <-acsSession.ctx.Done():
// The agent is shutting down; exit cleanly
}

// Disconnected unexpectedly from ACS; compute the backoff duration
// before reconnecting
reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance)
seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String())
waitComplete := acsSession.waitForDuration(reconnectDelay)
if !waitComplete {
// The wait was interrupted. We expect the session to close, as cancelling
// the session context is the only way to end up here. Log a message
// to indicate the same
seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close")
return nil
}

// The context was not cancelled and the full wait duration has elapsed
// without any errors; loop around to reconnect to ACS
seelog.Info("Done waiting; reconnecting to ACS")
}
}

Expand Down Expand Up @@ -592,11 +576,3 @@ func shouldReconnectWithoutBackoff(acsError error) bool {
func isInactiveInstanceError(acsError error) bool {
	// A nil error can never be an inactive-instance error.
	if acsError == nil {
		return false
	}
	// Classify by message: the error counts as inactive-instance when its
	// text begins with inactiveInstanceExceptionPrefix.
	message := acsError.Error()
	return strings.HasPrefix(message, inactiveInstanceExceptionPrefix)
}

// sendEmptyMessageOnChannel delivers a single empty struct to the given
// channel from a freshly spawned goroutine, so the caller returns immediately
// even when no receiver is ready yet.
func sendEmptyMessageOnChannel(channel chan<- struct{}) {
	var signal struct{}
	go func(out chan<- struct{}) {
		out <- signal
	}(channel)
}
109 changes: 101 additions & 8 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,6 @@ func TestComputeReconnectDelayForActiveInstance(t *testing.T) {
// waitForDurationOrCancelledSession method behaves correctly when the session context
// is not cancelled
func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -300,9 +297,6 @@ func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) {
// waitForDurationOrCancelledSession method behaves correctly when the session context
// is cancelled
func TestWaitForDurationReturnsFalseWhenContextCancelled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
acsSession := session{
ctx: ctx,
Expand Down Expand Up @@ -532,7 +526,8 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) {
taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx)
deregisterInstanceEventStream.StartListening()
// Don't start to ensure an error doesn't affect reconnect
// deregisterInstanceEventStream.StartListening()

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
Expand Down Expand Up @@ -695,7 +690,105 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) {
go func() {
sessionError <- acsSession.Start()
}()
<-sessionError
response := <-sessionError
assert.Nil(t, response)
}

// TestHandlerStopsWhenContextIsError tests that the session's Start() method
// returns nil when the session context ends with an error other than
// cancellation: here the context's 4ms deadline expires while the mocked
// Serve() call is still sleeping.
func TestHandlerStopsWhenContextIsError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := mock_engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ecsClient := mock_api.NewMockECSClient(ctrl)
	ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()

	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
	// Always release the timeout context's timer, even though the deadline is
	// what ends the session. Calling cancel more than once is harmless.
	defer cancel()
	taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
	mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
	// Serve outlasts the 4ms context deadline, so the context is already done
	// by the time Serve returns io.EOF.
	mockWsClient.EXPECT().Serve().Do(func() {
		time.Sleep(5 * time.Millisecond)
	}).Return(io.EOF).AnyTimes()

	acsSession := session{
		containerInstanceARN: "myArn",
		credentialsProvider:  testCreds,
		agentConfig:          testConfig,
		taskEngine:           taskEngine,
		ecsClient:            ecsClient,
		dataClient:           data.NewNoopClient(),
		taskHandler:          taskHandler,
		backoff:              retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
		ctx:                  ctx,
		cancel:               cancel,
		resources:            &mockSessionResources{mockWsClient},
		_heartbeatTimeout:    20 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
	}

	// The session error channel receives a value when the Start() method
	// returns. The context deadline expiring (not an explicit cancel) should
	// trigger this.
	sessionError := make(chan error)
	go func() {
		sessionError <- acsSession.Start()
	}()
	response := <-sessionError
	assert.Nil(t, response)
}

// TestHandlerStopsWhenContextIsErrorReconnectDelay tests that the session's
// Start() method returns nil when the session context's 4ms deadline expires
// after Serve() has failed with an inactive-instance error. With the inactive
// instance reconnect delay set to one hour, Start() is presumably still
// waiting out that delay when the deadline fires.
func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := mock_engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ecsClient := mock_api.NewMockECSClient(ctrl)
	ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes()

	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
	// Always release the timeout context's timer, even though the deadline is
	// what ends the session. Calling cancel more than once is harmless.
	defer cancel()
	taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
	mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
	// Serve fails immediately with an error whose message starts with
	// "InactiveInstanceException".
	mockWsClient.EXPECT().Serve().Return(errors.New("InactiveInstanceException")).AnyTimes()

	acsSession := session{
		containerInstanceARN:            "myArn",
		credentialsProvider:             testCreds,
		agentConfig:                     testConfig,
		taskEngine:                      taskEngine,
		ecsClient:                       ecsClient,
		dataClient:                      data.NewNoopClient(),
		taskHandler:                     taskHandler,
		backoff:                         retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
		ctx:                             ctx,
		cancel:                          cancel,
		resources:                       &mockSessionResources{mockWsClient},
		_heartbeatTimeout:               20 * time.Millisecond,
		_heartbeatJitter:                10 * time.Millisecond,
		_inactiveInstanceReconnectDelay: 1 * time.Hour,
	}

	// The session error channel receives a value when the Start() method
	// returns. The context deadline expiring (not an explicit cancel) should
	// trigger this.
	sessionError := make(chan error)
	go func() {
		sessionError <- acsSession.Start()
	}()
	response := <-sessionError
	assert.Nil(t, response)
}

// TestHandlerReconnectsOnDiscoverPollEndpointError tests if handler retries
Expand Down

0 comments on commit 871e3c7

Please sign in to comment.