separated logic for warm pools polling scenarios and do not fail on throttling or transient server errors once state obtained (aws#3055)

Co-authored-by: Lydia Filipe <fillydia@amazon.com>
lydiafilipe and Lydia Filipe committed Nov 17, 2021
1 parent d7b4c3c commit 1322be5
Showing 4 changed files with 109 additions and 53 deletions.
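
In brief: the previous version used one loop both to wait for IMDS to populate the target lifecycle state and to wait for a warmed instance to go in service, and any non-404 error aborted startup. This commit splits those two phases, and once a state has been obtained it tolerates throttling (429) and transient server errors (5xx) instead of failing. Below is a simplified, self-contained sketch of the resulting control flow; the names, the fake fetch function, and the httpError type are stand-ins for illustration, not the agent's actual types.

package sketch

import (
	"log"
	"time"
)

// httpError and statusCode stand in for the new utils.GetRequestFailureStatusCode
// helper: statusCode returns the HTTP status carried by an error, or 0.
type httpError struct{ code int }

func (e httpError) Error() string { return "http error" }

func statusCode(err error) int {
	if he, ok := err.(httpError); ok {
		return he.code
	}
	return 0
}

// pollUntilPresent mirrors pollUntilTargetLifecyclePresent: keep polling while
// the state is empty and the only error is a 404 (data not yet populated).
func pollUntilPresent(fetch func() (string, error), wait time.Duration, maxPolls int) (string, error) {
	var state string
	var err error
	for i := 0; i < maxPolls; i++ {
		state, err = fetch()
		if state != "" || (err != nil && statusCode(err) != 404) {
			break
		}
		time.Sleep(wait)
	}
	return state, err
}

// waitForInService mirrors the reworked waitUntilInstanceInService: after a
// state is first observed, throttling and 5xx errors are logged and retried,
// and only unexpected errors abort the wait.
func waitForInService(fetch func() (string, error), wait time.Duration, maxPolls int) error {
	state, err := pollUntilPresent(fetch, wait, maxPolls)
	if err != nil {
		return err
	}
	for state != "InService" {
		time.Sleep(wait)
		if state, err = fetch(); err != nil {
			switch statusCode(err) {
			case 429, 500, 502, 503, 504:
				log.Printf("transient error while waiting to go in service: %v", err)
			default:
				return err
			}
		}
	}
	return nil
}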
agent/app/agent.go (38 additions, 23 deletions)
@@ -90,8 +90,8 @@ const (
 	targetLifecycleBackoffMultiple = 1.3
 	targetLifecycleMaxRetryCount   = 3
 	inServiceState                 = "InService"
-	asgLifeCyclePollWait           = time.Minute
-	asgLifeCyclePollMax            = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out
+	asgLifecyclePollWait           = time.Minute
+	asgLifecyclePollMax            = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out
 )
 
 var (
@@ -295,7 +295,7 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre
 
 	// If part of ASG, wait until instance is being set up to go in service before registering with cluster
 	if agent.cfg.WarmPoolsSupport.Enabled() {
-		err := agent.waitUntilInstanceInService(asgLifeCyclePollWait, asgLifeCyclePollMax)
+		err := agent.waitUntilInstanceInService(asgLifecyclePollWait, asgLifecyclePollMax, targetLifecycleMaxRetryCount)
 		if err != nil && err.Error() != blackholed {
 			seelog.Criticalf("Could not determine target lifecycle of instance: %v", err)
 			return exitcodes.ExitTerminal
@@ -407,45 +407,60 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre
 
 // waitUntilInstanceInService polls IMDS until the target lifecycle state indicates that the instance is going in
 // service. This is to avoid instances going to a warm pool being registered as container instances with the cluster
-func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int) error {
+func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) error {
 	var err error
 	var targetState string
-	// Poll while the instance is in a warmed state or while waiting for the data to be populated.
-	// If the data is not populated after a certain number of polls, then stop polling and return the not found error.
-	// The polling maximum does not apply to instances in the warmed states
-	for i := 0; i < pollMaxTimes || targetState != ""; i++ {
-		targetState, err = agent.getTargetLifeCycle()
-		// stop polling if the retrieved state is in service or we get an unexpected error
-		if targetState == inServiceState {
-			break
-		}
+	// Poll until a target lifecycle state is obtained from IMDS, or an unexpected error occurs
+	targetState, err = agent.pollUntilTargetLifecyclePresent(pollWaitDuration, pollMaxTimes, maxRetries)
+	if err != nil {
+		return err
+	}
+	// Poll while the instance is in a warmed state, until it is about to go in service
+	for targetState != inServiceState {
+		time.Sleep(pollWaitDuration)
+		targetState, err = agent.getTargetLifecycle(maxRetries)
 		if err != nil {
-			var statusCode int
-			if reqErr, ok := err.(awserr.RequestFailure); ok {
-				statusCode = reqErr.StatusCode()
-			}
-			if statusCode != 404 {
-				break
+			// Do not exit if the error is due to throttling or temporary server errors;
+			// these are likely transient, as at this point IMDS has been successfully queried for state
+			switch utils.GetRequestFailureStatusCode(err) {
+			case 429, 500, 502, 503, 504:
+				seelog.Warnf("Encountered error while waiting for warmed instance to go in service: %v", err)
+			default:
+				return err
 			}
 		}
-		time.Sleep(pollWaitDuration)
 	}
 	return err
 }
 
+// pollUntilTargetLifecyclePresent polls until it obtains a target state or receives an unexpected error
+func (agent *ecsAgent) pollUntilTargetLifecyclePresent(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) (string, error) {
+	var err error
+	var targetState string
+	for i := 0; i < pollMaxTimes; i++ {
+		targetState, err = agent.getTargetLifecycle(maxRetries)
+		if targetState != "" ||
+			(err != nil && utils.GetRequestFailureStatusCode(err) != 404) {
+			break
+		}
+		time.Sleep(pollWaitDuration)
+	}
+	return targetState, err
+}
+
 // getTargetLifecycle obtains the target lifecycle state for the instance from IMDS. This is populated for instances
 // associated with an ASG
-func (agent *ecsAgent) getTargetLifeCycle() (string, error) {
+func (agent *ecsAgent) getTargetLifecycle(maxRetries int) (string, error) {
 	var targetState string
 	var err error
 	backoff := retry.NewExponentialBackoff(targetLifecycleBackoffMin, targetLifecycleBackoffMax, targetLifecycleBackoffJitter, targetLifecycleBackoffMultiple)
-	for i := 0; i < targetLifecycleMaxRetryCount; i++ {
+	for i := 0; i < maxRetries; i++ {
 		targetState, err = agent.ec2MetadataClient.TargetLifecycleState()
 		if err == nil {
 			break
 		}
 		seelog.Debugf("Error when getting intended lifecycle state: %v", err)
-		if i < targetLifecycleMaxRetryCount {
+		if i < maxRetries {
 			time.Sleep(backoff.Duration())
 		}
 	}
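For reference, the TargetLifecycleState call above ultimately reads the autoscaling/target-lifecycle-state instance metadata category, which Auto Scaling populates for instances in a group with a warm pool (values such as Warmed:Pending, Warmed:Running, and InService). A rough way to inspect it by hand, assuming IMDSv1 is reachable; the agent's real client additionally handles IMDSv2 tokens, timeouts, and retries:

package sketch

import (
	"io"
	"net/http"
)

// targetLifecycleState fetches the raw target lifecycle state from IMDS.
// Sketch only: no IMDSv2 session token, timeouts, or retries.
func targetLifecycleState() (string, error) {
	resp, err := http.Get("http://169.254.169.254/latest/meta-data/autoscaling/target-lifecycle-state")
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil // e.g. "Warmed:Running" or "InService"
}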
agent/app/agent_test.go (36 additions, 30 deletions)
@@ -60,16 +60,19 @@ import (
 )
 
 const (
-	clusterName            = "some-cluster"
-	containerInstanceARN   = "container-instance1"
-	availabilityZone       = "us-west-2b"
-	hostPrivateIPv4Address = "127.0.0.1"
-	hostPublicIPv4Address  = "127.0.0.1"
-	instanceID             = "i-123"
-	warmedState            = "Warmed:Pending"
+	clusterName                      = "some-cluster"
+	containerInstanceARN             = "container-instance1"
+	availabilityZone                 = "us-west-2b"
+	hostPrivateIPv4Address           = "127.0.0.1"
+	hostPublicIPv4Address            = "127.0.0.1"
+	instanceID                       = "i-123"
+	warmedState                      = "Warmed:Running"
+	testTargetLifecycleMaxRetryCount = 1
 )
 
 var notFoundErr = awserr.NewRequestFailure(awserr.Error(awserr.New("NotFound", "", errors.New(""))), 404, "")
+var badReqErr = awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, "")
+var serverErr = awserr.NewRequestFailure(awserr.Error(awserr.New("InternalServerError", "", errors.New(""))), 500, "")
 var apiVersions = []dockerclient.DockerVersion{
 	dockerclient.Version_1_21,
 	dockerclient.Version_1_22,
@@ -310,7 +313,7 @@ func TestDoStartWarmPoolsError(t *testing.T) {
 	}
 
 	err := errors.New("error")
-	mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(3)
+	mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(targetLifecycleMaxRetryCount)
 
 	exitCode := agent.doStart(eventstream.NewEventStream("events", ctx),
 		credentialsManager, state, imageManager, client, execCmdMgr)
@@ -345,7 +348,7 @@ func testDoStartHappyPathWithConditions(t *testing.T, blackholed bool, warmPools
 
 	if blackholed {
 		if warmPoolsEnv {
-			ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(3)
+			ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(targetLifecycleMaxRetryCount)
 		}
 		ec2MetadataClient.EXPECT().InstanceID().Return("", errors.New("blackholed"))
 	} else {
@@ -1534,17 +1537,30 @@ func newTestDataClient(t *testing.T) (data.Client, func()) {
 	return testClient, cleanup
 }
 
+type targetLifecycleFuncDetail struct {
+	val         string
+	err         error
+	returnTimes int
+}
+
 func TestWaitUntilInstanceInServicePolling(t *testing.T) {
+	warmedResult := targetLifecycleFuncDetail{warmedState, nil, 1}
+	inServiceResult := targetLifecycleFuncDetail{inServiceState, nil, 1}
+	notFoundErrResult := targetLifecycleFuncDetail{"", notFoundErr, testTargetLifecycleMaxRetryCount}
+	unexpectedErrResult := targetLifecycleFuncDetail{"", badReqErr, testTargetLifecycleMaxRetryCount}
+	serverErrResult := targetLifecycleFuncDetail{"", serverErr, testTargetLifecycleMaxRetryCount}
 	testCases := []struct {
-		name         string
-		states       []string
-		err          error
-		returnsState bool
-		maxPolls     int
+		name            string
+		funcTestDetails []targetLifecycleFuncDetail
+		result          error
+		maxPolls        int
 	}{
-		{"TestWaitUntilInstanceInServicePollsWarmed", []string{warmedState, inServiceState}, nil, true, asgLifeCyclePollMax},
-		{"TestWaitUntilInstanceInServicePollsMissing", []string{inServiceState}, notFoundErr, true, asgLifeCyclePollMax},
-		{"TestWaitUntilInstanceInServicePollingMaxReached", nil, notFoundErr, false, 1},
+		{"TestWaitUntilInServicePollWarmed", []targetLifecycleFuncDetail{warmedResult, warmedResult, inServiceResult}, nil, asgLifecyclePollMax},
+		{"TestWaitUntilInServicePollMissing", []targetLifecycleFuncDetail{notFoundErrResult, inServiceResult}, nil, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceErrPollMaxReached", []targetLifecycleFuncDetail{notFoundErrResult}, notFoundErr, 1},
+		{"TestWaitUntilInServiceNoStateUnexpectedErr", []targetLifecycleFuncDetail{unexpectedErrResult}, badReqErr, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceUnexpectedErr", []targetLifecycleFuncDetail{warmedResult, unexpectedErrResult}, badReqErr, asgLifecyclePollMax},
+		{"TestWaitUntilInServiceServerErrContinue", []targetLifecycleFuncDetail{warmedResult, serverErrResult, inServiceResult}, nil, asgLifecyclePollMax},
 	}
 
 	for _, tc := range testCases {
@@ -1555,20 +1571,10 @@ func TestWaitUntilInstanceInServicePolling(t *testing.T) {
 			cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}
 			ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl)
 			agent := &ecsAgent{ec2MetadataClient: ec2MetadataClient, cfg: &cfg}
-
-			if tc.err != nil {
-				ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", tc.err).Times(3)
-			}
-			for _, state := range tc.states {
-				ec2MetadataClient.EXPECT().TargetLifecycleState().Return(state, nil)
-			}
-			var expectedResult error
-			if tc.returnsState {
-				expectedResult = nil
-			} else {
-				expectedResult = tc.err
+			for _, detail := range tc.funcTestDetails {
+				ec2MetadataClient.EXPECT().TargetLifecycleState().Return(detail.val, detail.err).Times(detail.returnTimes)
 			}
-			assert.Equal(t, expectedResult, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls))
+			assert.Equal(t, tc.result, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls, testTargetLifecycleMaxRetryCount))
 		})
 	}
 }
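
To trace one row through the new loop: TestWaitUntilInServiceServerErrContinue queues three mock results, and the wait must ride out the 500 in the middle. Since gomock serves same-method expectations in declaration order, the calls play out as in this illustration (not code from the commit):

// Successive TargetLifecycleState() results for TestWaitUntilInServiceServerErrContinue
// (each expectation uses Times(1), since testTargetLifecycleMaxRetryCount is 1):
//
//   call 1: "Warmed:Running", nil  -> state obtained; enter the in-service wait loop
//   call 2: "", serverErr (500)    -> transient per the new switch; logged, loop continues
//   call 3: "InService", nil       -> loop exits; waitUntilInstanceInService returns nil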
agent/utils/utils.go (10 additions, 0 deletions)
@@ -159,6 +159,16 @@ func IsAWSErrorCodeEqual(err error, code string) bool {
 	return ok && awsErr.Code() == code
 }
 
+// GetRequestFailureStatusCode returns the status code from a
+// RequestFailure error, or 0 if the error is not of that type
+func GetRequestFailureStatusCode(err error) int {
+	var statusCode int
+	if reqErr, ok := err.(awserr.RequestFailure); ok {
+		statusCode = reqErr.StatusCode()
+	}
+	return statusCode
+}
+
 // MapToTags converts a map to a slice of tags.
 func MapToTags(tagsMap map[string]string) []*ecs.Tag {
 	tags := make([]*ecs.Tag, 0)
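
A quick usage sketch for the new helper, showing the two error shapes it distinguishes. The import path is the agent's utils package; the error values here are made up for illustration:

package sketch

import (
	"errors"
	"fmt"

	"github.com/aws/amazon-ecs-agent/agent/utils"
	"github.com/aws/aws-sdk-go/aws/awserr"
)

func demo() {
	throttled := awserr.NewRequestFailure(awserr.New("Throttling", "rate exceeded", nil), 429, "req-id")
	fmt.Println(utils.GetRequestFailureStatusCode(throttled))           // 429: a RequestFailure carries its HTTP status
	fmt.Println(utils.GetRequestFailureStatusCode(errors.New("plain"))) // 0: not a RequestFailure
}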
agent/utils/utils_test.go (25 additions, 0 deletions)
@@ -162,6 +162,31 @@ func TestIsAWSErrorCodeEqual(t *testing.T) {
 	}
 }
 
+func TestGetRequestFailureStatusCode(t *testing.T) {
+	testcases := []struct {
+		name string
+		err  error
+		res  int
+	}{
+		{
+			name: "TestGetRequestFailureStatusCodeSuccess",
+			err:  awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, ""),
+			res:  400,
+		},
+		{
+			name: "TestGetRequestFailureStatusCodeWrongErrType",
+			err:  errors.New("err"),
+			res:  0,
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.res, GetRequestFailureStatusCode(tc.err))
+		})
+	}
+}
+
 func TestMapToTags(t *testing.T) {
 	tagKey1 := "tagKey1"
 	tagKey2 := "tagKey2"
