diff --git a/agent/app/agent.go b/agent/app/agent.go index 0dc86278ebe..c5720811c03 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -87,10 +87,10 @@ const ( targetLifecycleBackoffMax = time.Second * 5 targetLifecycleBackoffJitter = 0.2 targetLifecycleBackoffMultiple = 1.3 - targetLifecycleMaxRetryCount = 3 + targetLifecycleMaxRetryCount = 5 inServiceState = "InService" - asgLifeCyclePollWait = time.Minute - asgLifeCyclePollMax = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out + asgLifecyclePollWait = time.Minute + asgLifecyclePollMax = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out ) var ( @@ -294,7 +294,7 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre // If part of ASG, wait until instance is being set up to go in service before registering with cluster if agent.cfg.WarmPoolsSupport.Enabled() { - err := agent.waitUntilInstanceInService(asgLifeCyclePollWait, asgLifeCyclePollMax) + err := agent.waitUntilInstanceInService(asgLifecyclePollWait, asgLifecyclePollMax, targetLifecycleMaxRetryCount) if err != nil && err.Error() != blackholed { seelog.Criticalf("Could not determine target lifecycle of instance: %v", err) return exitcodes.ExitTerminal @@ -397,45 +397,60 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre // waitUntilInstanceInService Polls IMDS until the target lifecycle state indicates that the instance is going in // service. 
This is to avoid instances going to a warm pool being registered as container instances with the cluster -func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int) error { +func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) error { var err error var targetState string - // Poll while the instance is in a warmed state or while waiting for the data to be populated. - // If the data is not populated after a certain number of polls, then stop polling and return the not found error. - // The polling maximum does not apply to instances in the warmed states - for i := 0; i < pollMaxTimes || targetState != ""; i++ { - targetState, err = agent.getTargetLifeCycle() - // stop polling if the retrieved state is in service or we get an unexpected error - if targetState == inServiceState { - break - } + // Poll until a target lifecycle state is obtained from IMDS, or an unexpected error occurs + targetState, err = agent.pollUntilTargetLifecyclePresent(pollWaitDuration, pollMaxTimes, maxRetries) + if err != nil { + return err + } + // Poll while the instance is in a warmed state until it is going to go into service + for targetState != "InService" { + time.Sleep(pollWaitDuration) + targetState, err = agent.getTargetLifecycle(maxRetries) if err != nil { - var statusCode int - if reqErr, ok := err.(awserr.RequestFailure); ok { - statusCode = reqErr.StatusCode() - } - if statusCode != 404 { - break + // Do not exit if error is due to throttling or temporary server errors + // These are likely transient, as at this point IMDS has been successfully queried for state + switch utils.GetRequestFailureStatusCode(err) { + case 429, 500, 502, 503, 504: + seelog.Warnf("Encountered error while waiting for warmed instance to go in service: %v", err) + default: + return err + } } - time.Sleep(pollWaitDuration) } return err } +// pollUntilTargetLifecyclePresent polls until it obtains a target 
state or receives an unexpected error +func (agent *ecsAgent) pollUntilTargetLifecyclePresent(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) (string, error) { + var err error + var targetState string + for i := 0; i < pollMaxTimes; i++ { + targetState, err = agent.getTargetLifecycle(maxRetries) + if targetState != "" || + (err != nil && utils.GetRequestFailureStatusCode(err) != 404) { + break + } + time.Sleep(pollWaitDuration) + } + return targetState, err +} + // getTargetLifecycle obtains the target lifecycle state for the instance from IMDS. This is populated for instances // associated with an ASG -func (agent *ecsAgent) getTargetLifeCycle() (string, error) { +func (agent *ecsAgent) getTargetLifecycle(maxRetries int) (string, error) { var targetState string var err error backoff := retry.NewExponentialBackoff(targetLifecycleBackoffMin, targetLifecycleBackoffMax, targetLifecycleBackoffJitter, targetLifecycleBackoffMultiple) - for i := 0; i < targetLifecycleMaxRetryCount; i++ { + for i := 0; i < maxRetries; i++ { targetState, err = agent.ec2MetadataClient.TargetLifecycleState() if err == nil { break } seelog.Debugf("Error when getting intended lifecycle state: %v", err) - if i < targetLifecycleMaxRetryCount { + if i < maxRetries { time.Sleep(backoff.Duration()) } } diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go index 6221462987e..2e10a6fd12f 100644 --- a/agent/app/agent_test.go +++ b/agent/app/agent_test.go @@ -60,16 +60,19 @@ import ( ) const ( - clusterName = "some-cluster" - containerInstanceARN = "container-instance1" - availabilityZone = "us-west-2b" - hostPrivateIPv4Address = "127.0.0.1" - hostPublicIPv4Address = "127.0.0.1" - instanceID = "i-123" - warmedState = "Warmed:Pending" + clusterName = "some-cluster" + containerInstanceARN = "container-instance1" + availabilityZone = "us-west-2b" + hostPrivateIPv4Address = "127.0.0.1" + hostPublicIPv4Address = "127.0.0.1" + instanceID = "i-123" + warmedState = "Warmed:Running" + 
testTargetLifecycleMaxRetryCount = 1 ) var notFoundErr = awserr.NewRequestFailure(awserr.Error(awserr.New("NotFound", "", errors.New(""))), 404, "") +var badReqErr = awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, "") +var serverErr = awserr.NewRequestFailure(awserr.Error(awserr.New("InternalServerError", "", errors.New(""))), 500, "") var apiVersions = []dockerclient.DockerVersion{ dockerclient.Version_1_21, dockerclient.Version_1_22, @@ -310,7 +313,7 @@ func TestDoStartWarmPoolsError(t *testing.T) { } err := errors.New("error") - mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(3) + mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(targetLifecycleMaxRetryCount) exitCode := agent.doStart(eventstream.NewEventStream("events", ctx), credentialsManager, state, imageManager, client, execCmdMgr) @@ -345,7 +348,7 @@ func testDoStartHappyPathWithConditions(t *testing.T, blackholed bool, warmPools if blackholed { if warmPoolsEnv { - ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(3) + ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(targetLifecycleMaxRetryCount) } ec2MetadataClient.EXPECT().InstanceID().Return("", errors.New("blackholed")) } else { @@ -1534,17 +1537,30 @@ func newTestDataClient(t *testing.T) (data.Client, func()) { return testClient, cleanup } +type targetLifecycleFuncDetail struct { + val string + err error + returnTimes int +} + func TestWaitUntilInstanceInServicePolling(t *testing.T) { + warmedResult := targetLifecycleFuncDetail{warmedState, nil, 1} + inServiceResult := targetLifecycleFuncDetail{inServiceState, nil, 1} + notFoundErrResult := targetLifecycleFuncDetail{"", notFoundErr, testTargetLifecycleMaxRetryCount} + unexpectedErrResult := targetLifecycleFuncDetail{"", badReqErr, testTargetLifecycleMaxRetryCount} + serverErrResult := targetLifecycleFuncDetail{"", serverErr, 
testTargetLifecycleMaxRetryCount} testCases := []struct { - name string - states []string - err error - returnsState bool - maxPolls int + name string + funcTestDetails []targetLifecycleFuncDetail + result error + maxPolls int }{ - {"TestWaitUntilInstanceInServicePollsWarmed", []string{warmedState, inServiceState}, nil, true, asgLifeCyclePollMax}, - {"TestWaitUntilInstanceInServicePollsMissing", []string{inServiceState}, notFoundErr, true, asgLifeCyclePollMax}, - {"TestWaitUntilInstanceInServicePollingMaxReached", nil, notFoundErr, false, 1}, + {"TestWaitUntilInServicePollWarmed", []targetLifecycleFuncDetail{warmedResult, warmedResult, inServiceResult}, nil, asgLifecyclePollMax}, + {"TestWaitUntilInServicePollMissing", []targetLifecycleFuncDetail{notFoundErrResult, inServiceResult}, nil, asgLifecyclePollMax}, + {"TestWaitUntilInServiceErrPollMaxReached", []targetLifecycleFuncDetail{notFoundErrResult}, notFoundErr, 1}, + {"TestWaitUntilInServiceNoStateUnexpectedErr", []targetLifecycleFuncDetail{unexpectedErrResult}, badReqErr, asgLifecyclePollMax}, + {"TestWaitUntilInServiceUnexpectedErr", []targetLifecycleFuncDetail{warmedResult, unexpectedErrResult}, badReqErr, asgLifecyclePollMax}, + {"TestWaitUntilInServiceServerErrContinue", []targetLifecycleFuncDetail{warmedResult, serverErrResult, inServiceResult}, nil, asgLifecyclePollMax}, } for _, tc := range testCases { @@ -1555,20 +1571,10 @@ func TestWaitUntilInstanceInServicePolling(t *testing.T) { cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl) agent := &ecsAgent{ec2MetadataClient: ec2MetadataClient, cfg: &cfg} - - if tc.err != nil { - ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", tc.err).Times(3) - } - for _, state := range tc.states { - ec2MetadataClient.EXPECT().TargetLifecycleState().Return(state, nil) - } - var expectedResult error - if tc.returnsState { - expectedResult = nil - } else { - expectedResult 
= tc.err + for _, detail := range tc.funcTestDetails { + ec2MetadataClient.EXPECT().TargetLifecycleState().Return(detail.val, detail.err).Times(detail.returnTimes) } - assert.Equal(t, expectedResult, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls)) + assert.Equal(t, tc.result, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls, testTargetLifecycleMaxRetryCount)) }) } } diff --git a/agent/utils/utils.go b/agent/utils/utils.go index 3cbb8ccf2a4..778338d5f3a 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -156,6 +156,16 @@ func IsAWSErrorCodeEqual(err error, code string) bool { return ok && awsErr.Code() == code } +// GetRequestFailureStatusCode returns the status code from a +// RequestFailure error, or 0 if the error is not of that type +func GetRequestFailureStatusCode(err error) int { + var statusCode int + if reqErr, ok := err.(awserr.RequestFailure); ok { + statusCode = reqErr.StatusCode() + } + return statusCode +} + // MapToTags converts a map to a slice of tags. func MapToTags(tagsMap map[string]string) []*ecs.Tag { tags := make([]*ecs.Tag, 0) diff --git a/agent/utils/utils_test.go b/agent/utils/utils_test.go index ce2057cdfa6..1ba5d9cb1a7 100644 --- a/agent/utils/utils_test.go +++ b/agent/utils/utils_test.go @@ -162,6 +162,31 @@ func TestIsAWSErrorCodeEqual(t *testing.T) { } } +func TestGetRequestFailureStatusCode(t *testing.T) { + testcases := []struct { + name string + err error + res int + }{ + { + name: "TestGetRequestFailureStatusCodeSuccess", + err: awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, ""), + res: 400, + }, + { + name: "TestGetRequestFailureStatusCodeWrongErrType", + err: errors.New("err"), + res: 0, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.res, GetRequestFailureStatusCode(tc.err)) + }) + } +} + func TestMapToTags(t *testing.T) { tagKey1 := "tagKey1" tagKey2 := "tagKey2"