diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b25b87e74b..27ae02577bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 1.59.0 +* Feature - prevent instances in EC2 Autoscaling warm pool from being registered with cluster [#3123](https://github.com/aws/amazon-ecs-agent/pull/3123) +* Enhancement - DiscoverPollEndpoint: lengthen cache ttl and improve resiliency [#3109](https://github.com/aws/amazon-ecs-agent/pull/3109) + ## 1.58.0 * Enhancement - Update agent build go version to 1.17.5 [#3105](https://github.com/aws/amazon-ecs-agent/pull/3105) * Enhancement - bumped pause container gcc build version [#3108](https://github.com/aws/amazon-ecs-agent/pull/3108) diff --git a/README.md b/README.md index 71e4f9c2f5b..b7d68e4ec95 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,8 @@ additional details on each available environment variable. | `ECS_FSX_WINDOWS_FILE_SERVER_SUPPORTED` | `true` | Whether FSx for Windows File Server volume type is supported on the container instance. This variable is only supported on agent versions 1.47.0 and later. | `false` | `true` | | `ECS_ENABLE_RUNTIME_STATS` | `true` | Determines if [pprof](https://pkg.go.dev/net/http/pprof) is enabled for the agent. If enabled, the different profiles can be accessed through the agent's introspection port (e.g. `curl http://localhost:51678/debug/pprof/heap > heap.pprof`). In addition, agent's [runtime stats](https://pkg.go.dev/runtime#ReadMemStats) are logged to `/var/log/ecs/runtime-stats.log` file. | `false` | `false` | | `ECS_EXCLUDE_IPV6_PORTBINDING` | `true` | Determines if agent should exclude IPv6 port binding using default network mode. If enabled, IPv6 port binding will be filtered out, and the response of DescribeTasks API call will not show tasks' IPv6 port bindings, but it is still included in Task metadata endpoint. 
| `true` | `true` | - +| `ECS_WARM_POOLS_CHECK` | `true` | Whether to ensure instances going into an [EC2 Auto Scaling group warm pool](https://docs.aws.amazon.com/autoscaling/ec2/userguide/ec2-auto-scaling-warm-pools.html) are prevented from being registered with the cluster. Set to true only if using EC2 Auto Scaling. | `false` | `false` | + ### Persistence When you run the Amazon ECS Container Agent in production, its `datadir` should be persisted between runs of the Docker diff --git a/VERSION b/VERSION index 79f82f6b8e0..bb120e876c6 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.58.0 +1.59.0 diff --git a/agent/api/ecsclient/client.go b/agent/api/ecsclient/client.go index cf3cecd29e0..cb5c357d5a6 100644 --- a/agent/api/ecsclient/client.go +++ b/agent/api/ecsclient/client.go @@ -40,8 +40,7 @@ const ( ecsMaxImageDigestLength = 255 ecsMaxReasonLength = 255 ecsMaxRuntimeIDLength = 255 - pollEndpointCacheSize = 1 - pollEndpointCacheTTL = 20 * time.Minute + pollEndpointCacheTTL = 12 * time.Hour roundtripTimeout = 5 * time.Second azAttrName = "ecs.availability-zone" cpuArchAttrName = "ecs.cpu-architecture" @@ -56,7 +55,7 @@ type APIECSClient struct { standardClient api.ECSSDK submitStateChangeClient api.ECSSubmitStateSDK ec2metadata ec2.EC2MetadataClient - pollEndpoinCache async.Cache + pollEndpointCache async.TTLCache } // NewECSClient creates a new ECSClient interface object @@ -74,14 +73,13 @@ func NewECSClient( } standardClient := ecs.New(session.New(&ecsConfig)) submitStateChangeClient := newSubmitStateChangeClient(&ecsConfig) - pollEndpoinCache := async.NewLRUCache(pollEndpointCacheSize, pollEndpointCacheTTL) return &APIECSClient{ credentialProvider: credentialProvider, config: config, standardClient: standardClient, submitStateChangeClient: submitStateChangeClient, ec2metadata: ec2MetadataClient, - pollEndpoinCache: pollEndpoinCache, + pollEndpointCache: async.NewTTLCache(pollEndpointCacheTTL), } } @@ -585,26 +583,37 @@ func (client *APIECSClient)
DiscoverTelemetryEndpoint(containerInstanceArn strin func (client *APIECSClient) discoverPollEndpoint(containerInstanceArn string) (*ecs.DiscoverPollEndpointOutput, error) { // Try getting an entry from the cache - cachedEndpoint, found := client.pollEndpoinCache.Get(containerInstanceArn) - if found { - // Cache hit. Return the output. + cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn) + if !expired && found { + // Cache hit and not expired. Return the output. if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Using cached DiscoverPollEndpoint. endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) return output, nil } } - // Cache miss, invoke the ECS DiscoverPollEndpoint API. + // Cache miss or expired, invoke the ECS DiscoverPollEndpoint API. seelog.Debugf("Invoking DiscoverPollEndpoint for '%s'", containerInstanceArn) output, err := client.standardClient.DiscoverPollEndpoint(&ecs.DiscoverPollEndpointInput{ ContainerInstance: &containerInstanceArn, Cluster: &client.config.Cluster, }) if err != nil { + // if we got an error calling the API, fallback to an expired cached endpoint if + // we have it. + if expired { + if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Error calling DiscoverPollEndpoint. Using cached but expired endpoint as a fallback. error=%s endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + err, aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) + return output, nil + } + } return nil, err } // Cache the response from ECS. 
- client.pollEndpoinCache.Set(containerInstanceArn, output) + client.pollEndpointCache.Set(containerInstanceArn, output) return output, nil } diff --git a/agent/api/ecsclient/client_test.go b/agent/api/ecsclient/client_test.go index 2457c80919d..21fedd0ba82 100644 --- a/agent/api/ecsclient/client_test.go +++ b/agent/api/ecsclient/client_test.go @@ -838,23 +838,23 @@ func TestDiscoverPollEndpointCacheHit(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" - pollEndpoinCache.EXPECT().Get("containerInstance").Return( + pollEndpointCache.EXPECT().Get("containerInstance").Return( &ecs.DiscoverPollEndpointOutput{ Endpoint: aws.String(pollEndpoint), - }, true) + }, false, true) output, err := client.discoverPollEndpoint("containerInstance") if err != nil { t.Fatalf("Error in discoverPollEndpoint: %v", err) @@ -869,16 +869,16 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: 
ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ @@ -886,9 +886,44 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { } gomock.InOrder( - pollEndpoinCache.EXPECT().Get("containerInstance").Return(nil, false), + pollEndpointCache.EXPECT().Get("containerInstance").Return(nil, false, false), mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil), - pollEndpoinCache.EXPECT().Set("containerInstance", pollEndpointOutput), + pollEndpointCache.EXPECT().Set("containerInstance", pollEndpointOutput), + ) + + output, err := client.discoverPollEndpoint("containerInstance") + if err != nil { + t.Fatalf("Error in discoverPollEndpoint: %v", err) + } + if aws.StringValue(output.Endpoint) != pollEndpoint { + t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) + } +} + +func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockSDK := mock_api.NewMockECSSDK(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) + client := &APIECSClient{ + credentialProvider: credentials.AnonymousCredentials, + config: &config.Config{ + Cluster: configuredCluster, + AWSRegion: "us-east-1", + }, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, + } + pollEndpoint := "http://127.0.0.1" + pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ + Endpoint: &pollEndpoint, + } + + gomock.InOrder( + pollEndpointCache.EXPECT().Get("containerInstance").Return(pollEndpointOutput, true, false), + mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")), ) output, err := client.discoverPollEndpoint("containerInstance") @@ -905,16 +940,16 @@ func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) { defer 
mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := async.NewLRUCache(1, 10*time.Minute) + pollEndpointCache := async.NewTTLCache(10 * time.Minute) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" diff --git a/agent/app/agent.go b/agent/app/agent.go index 4505c0ab917..bfe71daa52a 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -83,6 +83,15 @@ const ( instanceIdBackoffJitter = 0.2 instanceIdBackoffMultiple = 1.3 instanceIdMaxRetryCount = 3 + + targetLifecycleBackoffMin = time.Second + targetLifecycleBackoffMax = time.Second * 5 + targetLifecycleBackoffJitter = 0.2 + targetLifecycleBackoffMultiple = 1.3 + targetLifecycleMaxRetryCount = 3 + inServiceState = "InService" + asgLifecyclePollWait = time.Minute + asgLifecyclePollMax = 120 // given each poll cycle waits for about a minute, this gives 2-3 hours before timing out ) var ( @@ -291,6 +300,19 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre seelog.Criticalf("Unable to initialize new task engine: %v", err) return exitcodes.ExitTerminal } + + // Start termination handler in goroutine + go agent.terminationHandler(state, agent.dataClient, taskEngine, agent.cancel) + + // If part of ASG, wait until instance is being set up to go in service before registering with cluster + if agent.cfg.WarmPoolsSupport.Enabled() { + err := agent.waitUntilInstanceInService(asgLifecyclePollWait, asgLifecyclePollMax, targetLifecycleMaxRetryCount) + if err != nil && err.Error() != blackholed { + seelog.Criticalf("Could not determine target lifecycle of instance: %v", err) + return 
exitcodes.ExitTerminal + } + } + agent.initMetricsEngine() loadPauseErr := agent.loadPauseContainer() @@ -387,6 +409,70 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre deregisterInstanceEventStream, client, state, taskHandler, doctor) } +// waitUntilInstanceInService polls IMDS until the target lifecycle state indicates that the instance is going in +// service. This is to avoid instances going to a warm pool being registered as container instances with the cluster. +func (agent *ecsAgent) waitUntilInstanceInService(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) error { + seelog.Info("Waiting for instance to go InService") + var err error + var targetState string + // Poll until a target lifecycle state is obtained from IMDS, or an unexpected error occurs + targetState, err = agent.pollUntilTargetLifecyclePresent(pollWaitDuration, pollMaxTimes, maxRetries) + if err != nil { + return err + } + // Poll while the instance is in a warmed state until it is going to go into service + for targetState != inServiceState { + time.Sleep(pollWaitDuration) + targetState, err = agent.getTargetLifecycle(maxRetries) + if err != nil { + // Do not exit if error is due to throttling or temporary server errors + // These are likely transient, as at this point IMDS has been successfully queried for state + switch utils.GetRequestFailureStatusCode(err) { + case 429, 500, 502, 503, 504: + seelog.Warnf("Encountered error while waiting for warmed instance to go in service: %v", err) + default: + return err + } + } + } + return err +} + +// pollUntilTargetLifecyclePresent polls until it obtains a target lifecycle state or receives an unexpected error +func (agent *ecsAgent) pollUntilTargetLifecyclePresent(pollWaitDuration time.Duration, pollMaxTimes int, maxRetries int) (string, error) { + var err error + var targetState string + for i := 0; i < pollMaxTimes; i++ { + targetState, err = agent.getTargetLifecycle(maxRetries) + if targetState != "" || +
(err != nil && utils.GetRequestFailureStatusCode(err) != 404) { + break + } + time.Sleep(pollWaitDuration) + } + return targetState, err +} + +// getTargetLifecycle obtains the target lifecycle state for the instance from IMDS. This is populated for instances +// associated with an ASG +func (agent *ecsAgent) getTargetLifecycle(maxRetries int) (string, error) { + var targetState string + var err error + backoff := retry.NewExponentialBackoff(targetLifecycleBackoffMin, targetLifecycleBackoffMax, targetLifecycleBackoffJitter, targetLifecycleBackoffMultiple) + for i := 0; i < maxRetries; i++ { + targetState, err = agent.ec2MetadataClient.TargetLifecycleState() + if err == nil { + break + } + seelog.Debugf("Error when getting intended lifecycle state: %v", err) + if i < maxRetries { + time.Sleep(backoff.Duration()) + } + } + seelog.Debugf("Target lifecycle state of instance: %v", targetState) + return targetState, err +} + // newTaskEngine creates a new docker task engine object. It tries to load the // local state if needed, else initializes a new one func (agent *ecsAgent) newTaskEngine(containerChangeEventStream *eventstream.EventStream, @@ -687,8 +773,6 @@ func (agent *ecsAgent) startAsyncRoutines( go agent.startSpotInstanceDrainingPoller(agent.ctx, client) } - go agent.terminationHandler(state, agent.dataClient, taskEngine, agent.cancel) - // Agent introspection api go handlers.ServeIntrospectionHTTPEndpoint(agent.ctx, &agent.containerInstanceARN, taskEngine, agent.cfg) diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go index e12cecad3b7..79210cb8569 100644 --- a/agent/app/agent_test.go +++ b/agent/app/agent_test.go @@ -24,6 +24,7 @@ import ( "sort" "sync" "testing" + "time" apierrors "github.com/aws/amazon-ecs-agent/agent/api/errors" mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" @@ -50,7 +51,6 @@ import ( mock_statemanager "github.com/aws/amazon-ecs-agent/agent/statemanager/mocks" mock_mobypkgwrapper 
"github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper/mocks" "github.com/aws/amazon-ecs-agent/agent/version" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" aws_credentials "github.com/aws/aws-sdk-go/aws/credentials" @@ -60,14 +60,19 @@ import ( ) const ( - clusterName = "some-cluster" - containerInstanceARN = "container-instance1" - availabilityZone = "us-west-2b" - hostPrivateIPv4Address = "127.0.0.1" - hostPublicIPv4Address = "127.0.0.1" - instanceID = "i-123" + clusterName = "some-cluster" + containerInstanceARN = "container-instance1" + availabilityZone = "us-west-2b" + hostPrivateIPv4Address = "127.0.0.1" + hostPublicIPv4Address = "127.0.0.1" + instanceID = "i-123" + warmedState = "Warmed:Running" + testTargetLifecycleMaxRetryCount = 1 ) +var notFoundErr = awserr.NewRequestFailure(awserr.Error(awserr.New("NotFound", "", errors.New(""))), 404, "") +var badReqErr = awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, "") +var serverErr = awserr.NewRequestFailure(awserr.Error(awserr.New("InternalServerError", "", errors.New(""))), 500, "") var apiVersions = []dockerclient.DockerVersion{ dockerclient.Version_1_21, dockerclient.Version_1_22, @@ -235,6 +240,8 @@ func TestDoStartRegisterContainerInstanceErrorTerminal(t *testing.T) { dockerClient: dockerClient, mobyPlugins: mockMobyPlugins, ec2MetadataClient: mockEC2Metadata, + terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) { + }, } exitCode := agent.doStart(eventstream.NewEventStream("events", ctx), @@ -279,6 +286,8 @@ func TestDoStartRegisterContainerInstanceErrorNonTerminal(t *testing.T) { credentialProvider: aws_credentials.NewCredentials(mockCredentialsProvider), mobyPlugins: mockMobyPlugins, ec2MetadataClient: mockEC2Metadata, + terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine 
engine.TaskEngine, cancel context.CancelFunc) { + }, } exitCode := agent.doStart(eventstream.NewEventStream("events", ctx), @@ -286,7 +295,60 @@ func TestDoStartRegisterContainerInstanceErrorNonTerminal(t *testing.T) { assert.Equal(t, exitcodes.ExitError, exitCode) } +func TestDoStartWarmPoolsError(t *testing.T) { + ctrl, credentialsManager, state, imageManager, client, + dockerClient, _, _, execCmdMgr := setup(t) + defer ctrl.Finish() + mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) + gomock.InOrder( + dockerClient.EXPECT().SupportedVersions().Return(apiVersions), + ) + + cfg := getTestConfig() + cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} + ctx, cancel := context.WithCancel(context.TODO()) + // Cancel the context to cancel async routines + defer cancel() + terminationHandlerChan := make(chan bool) + terminationHandlerInvoked := false + agent := &ecsAgent{ + ctx: ctx, + cfg: &cfg, + dockerClient: dockerClient, + ec2MetadataClient: mockEC2Metadata, + terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) { + terminationHandlerChan <- true + }, + } + + err := errors.New("error") + mockEC2Metadata.EXPECT().TargetLifecycleState().Return("", err).Times(targetLifecycleMaxRetryCount) + + exitCode := agent.doStart(eventstream.NewEventStream("events", ctx), + credentialsManager, state, imageManager, client, execCmdMgr) + + select { + case terminationHandlerInvoked = <-terminationHandlerChan: + case <-time.After(10 * time.Second): + } + assert.Equal(t, exitcodes.ExitTerminal, exitCode) + // verify that termination handler had been started before polling + assert.True(t, terminationHandlerInvoked) +} + func TestDoStartHappyPath(t *testing.T) { + testDoStartHappyPathWithConditions(t, false, false) +} + +func TestDoStartWarmPoolsEnabled(t *testing.T) { + testDoStartHappyPathWithConditions(t, false, true) +} + +func
TestDoStartWarmPoolsBlackholed(t *testing.T) { + testDoStartHappyPathWithConditions(t, true, true) +} + +func testDoStartHappyPathWithConditions(t *testing.T, blackholed bool, warmPoolsEnv bool) { ctrl, credentialsManager, _, imageManager, client, dockerClient, stateManagerFactory, saveableOptionFactory, execCmdMgr := setup(t) defer ctrl.Finish() @@ -299,7 +361,19 @@ func TestDoStartHappyPath(t *testing.T) { ec2MetadataClient.EXPECT().PrivateIPv4Address().Return(hostPrivateIPv4Address, nil) ec2MetadataClient.EXPECT().PublicIPv4Address().Return(hostPublicIPv4Address, nil) ec2MetadataClient.EXPECT().OutpostARN().Return("", nil) - ec2MetadataClient.EXPECT().InstanceID().Return(instanceID, nil) + + if blackholed { + if warmPoolsEnv { + ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("blackholed")).Times(targetLifecycleMaxRetryCount) + } + ec2MetadataClient.EXPECT().InstanceID().Return("", errors.New("blackholed")) + } else { + if warmPoolsEnv { + ec2MetadataClient.EXPECT().TargetLifecycleState().Return("", errors.New("error")) + ec2MetadataClient.EXPECT().TargetLifecycleState().Return(inServiceState, nil) + } + ec2MetadataClient.EXPECT().InstanceID().Return(instanceID, nil) + } var discoverEndpointsInvoked sync.WaitGroup discoverEndpointsInvoked.Add(2) @@ -347,6 +421,9 @@ func TestDoStartHappyPath(t *testing.T) { cfg := getTestConfig() cfg.ContainerMetadataEnabled = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} cfg.Checkpoint = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} + if warmPoolsEnv { + cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} + } cfg.Cluster = clusterName ctx, cancel := context.WithCancel(context.TODO()) @@ -386,7 +463,9 @@ func TestDoStartHappyPath(t *testing.T) { assertMetadata(t, data.AvailabilityZoneKey, availabilityZone, dataClient) assertMetadata(t, data.ClusterNameKey, clusterName, dataClient) assertMetadata(t, data.ContainerInstanceARNKey, 
containerInstanceARN, dataClient) - assertMetadata(t, data.EC2InstanceIDKey, instanceID, dataClient) + if !blackholed { + assertMetadata(t, data.EC2InstanceIDKey, instanceID, dataClient) + } } func assertMetadata(t *testing.T, key, expectedVal string, dataClient data.Client) { @@ -1195,6 +1274,8 @@ func TestRegisterContainerInstanceInvalidParameterTerminalError(t *testing.T) { credentialProvider: aws_credentials.NewCredentials(mockCredentialsProvider), dockerClient: dockerClient, mobyPlugins: mockMobyPlugins, + terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) { + }, } exitCode := agent.doStart(eventstream.NewEventStream("events", ctx), @@ -1473,3 +1554,45 @@ func newTestDataClient(t *testing.T) (data.Client, func()) { } return testClient, cleanup } + +type targetLifecycleFuncDetail struct { + val string + err error + returnTimes int +} + +func TestWaitUntilInstanceInServicePolling(t *testing.T) { + warmedResult := targetLifecycleFuncDetail{warmedState, nil, 1} + inServiceResult := targetLifecycleFuncDetail{inServiceState, nil, 1} + notFoundErrResult := targetLifecycleFuncDetail{"", notFoundErr, testTargetLifecycleMaxRetryCount} + unexpectedErrResult := targetLifecycleFuncDetail{"", badReqErr, testTargetLifecycleMaxRetryCount} + serverErrResult := targetLifecycleFuncDetail{"", serverErr, testTargetLifecycleMaxRetryCount} + testCases := []struct { + name string + funcTestDetails []targetLifecycleFuncDetail + result error + maxPolls int + }{ + {"TestWaitUntilInServicePollWarmed", []targetLifecycleFuncDetail{warmedResult, warmedResult, inServiceResult}, nil, asgLifecyclePollMax}, + {"TestWaitUntilInServicePollMissing", []targetLifecycleFuncDetail{notFoundErrResult, inServiceResult}, nil, asgLifecyclePollMax}, + {"TestWaitUntilInServiceErrPollMaxReached", []targetLifecycleFuncDetail{notFoundErrResult}, notFoundErr, 1}, + {"TestWaitUntilInServiceNoStateUnexpectedErr", 
[]targetLifecycleFuncDetail{unexpectedErrResult}, badReqErr, asgLifecyclePollMax}, + {"TestWaitUntilInServiceUnexpectedErr", []targetLifecycleFuncDetail{warmedResult, unexpectedErrResult}, badReqErr, asgLifecyclePollMax}, + {"TestWaitUntilInServiceServerErrContinue", []targetLifecycleFuncDetail{warmedResult, serverErrResult, inServiceResult}, nil, asgLifecyclePollMax}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cfg := getTestConfig() + cfg.WarmPoolsSupport = config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled} + ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl) + agent := &ecsAgent{ec2MetadataClient: ec2MetadataClient, cfg: &cfg} + for _, detail := range tc.funcTestDetails { + ec2MetadataClient.EXPECT().TargetLifecycleState().Return(detail.val, detail.err).Times(detail.returnTimes) + } + assert.Equal(t, tc.result, agent.waitUntilInstanceInService(1*time.Millisecond, tc.maxPolls, testTargetLifecycleMaxRetryCount)) + }) + } +} diff --git a/agent/async/generate_mocks.go b/agent/async/generate_mocks.go index 04eeabc3f52..4b424471422 100644 --- a/agent/async/generate_mocks.go +++ b/agent/async/generate_mocks.go @@ -13,4 +13,4 @@ package async -//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache +//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache,TTLCache diff --git a/agent/async/mocks/async_mocks.go b/agent/async/mocks/async_mocks.go index 45e5c29900d..02c61092176 100644 --- a/agent/async/mocks/async_mocks.go +++ b/agent/async/mocks/async_mocks.go @@ -13,7 +13,7 @@ // // Code generated by MockGen. DO NOT EDIT. 
-// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache) +// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache,TTLCache) // Package mock_async is a generated GoMock package. package mock_async @@ -86,3 +86,66 @@ func (mr *MockCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCache)(nil).Set), arg0, arg1) } + +// MockTTLCache is a mock of TTLCache interface +type MockTTLCache struct { + ctrl *gomock.Controller + recorder *MockTTLCacheMockRecorder +} + +// MockTTLCacheMockRecorder is the mock recorder for MockTTLCache +type MockTTLCacheMockRecorder struct { + mock *MockTTLCache +} + +// NewMockTTLCache creates a new mock instance +func NewMockTTLCache(ctrl *gomock.Controller) *MockTTLCache { + mock := &MockTTLCache{ctrl: ctrl} + mock.recorder = &MockTTLCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTTLCache) EXPECT() *MockTTLCacheMockRecorder { + return m.recorder +} + +// Delete mocks base method +func (m *MockTTLCache) Delete(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Delete", arg0) +} + +// Delete indicates an expected call of Delete +func (mr *MockTTLCacheMockRecorder) Delete(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockTTLCache)(nil).Delete), arg0) +} + +// Get mocks base method +func (m *MockTTLCache) Get(arg0 string) (interface{}, bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// Get indicates an expected call of Get +func (mr *MockTTLCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, 
"Get", reflect.TypeOf((*MockTTLCache)(nil).Get), arg0) +} + +// Set mocks base method +func (m *MockTTLCache) Set(arg0 string, arg1 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Set", arg0, arg1) +} + +// Set indicates an expected call of Set +func (mr *MockTTLCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockTTLCache)(nil).Set), arg0, arg1) +} diff --git a/agent/async/ttl_cache.go b/agent/async/ttl_cache.go new file mode 100644 index 00000000000..d90aa175c02 --- /dev/null +++ b/agent/async/ttl_cache.go @@ -0,0 +1,80 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "sync" + "time" +) + +type TTLCache interface { + // Get fetches a value from cache, returns nil, false, false on miss + Get(key string) (value interface{}, expired bool, ok bool) + // Set sets a value in cache. Overwrites any existing value + Set(key string, value interface{}) + // Delete deletes the value from the cache + Delete(key string) +} + +// NewTTLCache creates a TTL cache with ttl for items.
+func NewTTLCache(ttl time.Duration) TTLCache { + return &ttlCache{ + ttl: ttl, + cache: make(map[string]*ttlCacheEntry), + } +} + +type ttlCacheEntry struct { + value interface{} + expiry time.Time +} + +type ttlCache struct { + mu sync.RWMutex + cache map[string]*ttlCacheEntry + ttl time.Duration +} + +// Get returns the value associated with the key. +// returns if the item is expired (true if key is expired). +// ok result indicates whether value was found in the map. +// Note that items are not automatically deleted from the map when they expire. They will continue to be +// returned with expired=true. +func (t *ttlCache) Get(key string) (value interface{}, expired bool, ok bool) { + t.mu.RLock() + defer t.mu.RUnlock() + if _, iok := t.cache[key]; !iok { + return nil, false, false + } + entry := t.cache[key] + expired = time.Now().After(entry.expiry) + return entry.value, expired, true +} + +// Set sets the key-value pair in the cache +func (t *ttlCache) Set(key string, value interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + t.cache[key] = &ttlCacheEntry{ + value: value, + expiry: time.Now().Add(t.ttl), + } +} + +// Delete removes the entry associated with the key from cache +func (t *ttlCache) Delete(key string) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.cache, key) +} diff --git a/agent/async/ttl_cache_test.go b/agent/async/ttl_cache_test.go new file mode 100644 index 00000000000..0a319d77b28 --- /dev/null +++ b/agent/async/ttl_cache_test.go @@ -0,0 +1,83 @@ +//go:build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. 
See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTTLSimple(t *testing.T) { + ttl := NewTTLCache(time.Minute) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + baz, expired, ok := ttl.Get("fooz") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, baz) + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLSetDelete(t *testing.T) { + ttl := NewTTLCache(time.Minute) + + ttl.Set("foo", "bar") + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + ttl.Set("foo", "bar2") + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar2") + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLCache(t *testing.T) { + ttl := NewTTLCache(50 * time.Millisecond) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.False(t, expired) + require.True(t, ok) + require.Equal(t, bar, "bar") + + time.Sleep(100 * time.Millisecond) + + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.True(t, expired) + require.Equal(t, bar, "bar") +} diff --git a/agent/config/config.go b/agent/config/config.go index d9307b8cad3..b7e296df33a 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -592,6 +592,7 @@ func environmentConfig() (Config, error) { External: parseBooleanDefaultFalseConfig("ECS_EXTERNAL"), EnableRuntimeStats: parseBooleanDefaultFalseConfig("ECS_ENABLE_RUNTIME_STATS"), ShouldExcludeIPv6PortBinding: parseBooleanDefaultTrueConfig("ECS_EXCLUDE_IPV6_PORTBINDING"), + 
WarmPoolsSupport: parseBooleanDefaultFalseConfig("ECS_WARM_POOLS_CHECK"), }, err } diff --git a/agent/config/config_test.go b/agent/config/config_test.go index dd4aac1bb62..c2d14738364 100644 --- a/agent/config/config_test.go +++ b/agent/config/config_test.go @@ -158,6 +158,7 @@ func TestEnvironmentConfig(t *testing.T) { defer setTestEnv("ECS_PULL_DEPENDENT_CONTAINERS_UPFRONT", "true")() defer setTestEnv("ECS_ENABLE_RUNTIME_STATS", "true")() defer setTestEnv("ECS_EXCLUDE_IPV6_PORTBINDING", "true")() + defer setTestEnv("ECS_WARM_POOLS_CHECK", "false")() additionalLocalRoutesJSON := `["1.2.3.4/22","5.6.7.8/32"]` setTestEnv("ECS_AWSVPC_ADDITIONAL_LOCAL_ROUTES", additionalLocalRoutesJSON) setTestEnv("ECS_ENABLE_CONTAINER_METADATA", "true") @@ -216,6 +217,7 @@ func TestEnvironmentConfig(t *testing.T) { assert.True(t, conf.DependentContainersPullUpfront.Enabled(), "Wrong value for DependentContainersPullUpfront") assert.True(t, conf.EnableRuntimeStats.Enabled(), "Wrong value for EnableRuntimeStats") assert.True(t, conf.ShouldExcludeIPv6PortBinding.Enabled(), "Wrong value for ShouldExcludeIPv6PortBinding") + assert.False(t, conf.WarmPoolsSupport.Enabled(), "Wrong value for WarmPoolsSupport") } func TestTrimWhitespaceWhenCreating(t *testing.T) { diff --git a/agent/config/types.go b/agent/config/types.go index d0eafd2f9b1..2e533f41bf7 100644 --- a/agent/config/types.go +++ b/agent/config/types.go @@ -354,4 +354,8 @@ type Config struct { // is set to true by default, and can be overridden by the ECS_EXCLUDE_IPV6_PORTBINDING environment variable. This is a workaround // for docker's bug as detailed in https://github.com/aws/amazon-ecs-agent/issues/2870. 
ShouldExcludeIPv6PortBinding BooleanDefaultTrue + + // WarmPoolsSupport specifies whether the agent should poll IMDS to check the target lifecycle state for a starting + // instance + WarmPoolsSupport BooleanDefaultFalse } diff --git a/agent/ec2/blackhole_ec2_metadata_client.go b/agent/ec2/blackhole_ec2_metadata_client.go index eb7d33d8758..6e3ed5d5523 100644 --- a/agent/ec2/blackhole_ec2_metadata_client.go +++ b/agent/ec2/blackhole_ec2_metadata_client.go @@ -84,3 +84,7 @@ func (blackholeMetadataClient) SpotInstanceAction() (string, error) { func (blackholeMetadataClient) OutpostARN() (string, error) { return "", errors.New("blackholed") } + +func (blackholeMetadataClient) TargetLifecycleState() (string, error) { + return "", errors.New("blackholed") +} diff --git a/agent/ec2/ec2_metadata_client.go b/agent/ec2/ec2_metadata_client.go index 3ab16928e70..3486a4d8c8c 100644 --- a/agent/ec2/ec2_metadata_client.go +++ b/agent/ec2/ec2_metadata_client.go @@ -40,6 +40,7 @@ const ( PublicIPv4Resource = "public-ipv4" OutpostARN = "outpost-arn" PrimaryIPV4VPCCIDRResourceFormat = "network/interfaces/macs/%s/vpc-ipv4-cidr-block" + TargetLifecycleState = "autoscaling/target-lifecycle-state" ) const ( @@ -82,6 +83,7 @@ type EC2MetadataClient interface { PublicIPv4Address() (string, error) SpotInstanceAction() (string, error) OutpostARN() (string, error) + TargetLifecycleState() (string, error) } type ec2MetadataClientImpl struct { @@ -203,3 +205,7 @@ func (c *ec2MetadataClientImpl) SpotInstanceAction() (string, error) { func (c *ec2MetadataClientImpl) OutpostARN() (string, error) { return c.client.GetMetadata(OutpostARN) } + +func (c *ec2MetadataClientImpl) TargetLifecycleState() (string, error) { + return c.client.GetMetadata(TargetLifecycleState) +} diff --git a/agent/ec2/mocks/ec2_mocks.go b/agent/ec2/mocks/ec2_mocks.go index 8257b54cabf..bf25d6e2be5 100644 --- a/agent/ec2/mocks/ec2_mocks.go +++ b/agent/ec2/mocks/ec2_mocks.go @@ -261,6 +261,21 @@ func (mr 
*MockEC2MetadataClientMockRecorder) SubnetID(arg0 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubnetID", reflect.TypeOf((*MockEC2MetadataClient)(nil).SubnetID), arg0) } +// TargetLifecycleState mocks base method +func (m *MockEC2MetadataClient) TargetLifecycleState() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TargetLifecycleState") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TargetLifecycleState indicates an expected call of TargetLifecycleState +func (mr *MockEC2MetadataClientMockRecorder) TargetLifecycleState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TargetLifecycleState", reflect.TypeOf((*MockEC2MetadataClient)(nil).TargetLifecycleState)) +} + // VPCID mocks base method func (m *MockEC2MetadataClient) VPCID(arg0 string) (string, error) { m.ctrl.T.Helper() diff --git a/agent/utils/utils.go b/agent/utils/utils.go index 78a51b2293f..341e3bc5dfb 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -159,6 +159,16 @@ func IsAWSErrorCodeEqual(err error, code string) bool { return ok && awsErr.Code() == code } +// GetRequestFailureStatusCode returns the status code from a +// RequestFailure error, or 0 if the error is not of that type +func GetRequestFailureStatusCode(err error) int { + var statusCode int + if reqErr, ok := err.(awserr.RequestFailure); ok { + statusCode = reqErr.StatusCode() + } + return statusCode +} + // MapToTags converts a map to a slice of tags. 
func MapToTags(tagsMap map[string]string) []*ecs.Tag { tags := make([]*ecs.Tag, 0) diff --git a/agent/utils/utils_test.go b/agent/utils/utils_test.go index 0b3fd70471f..1e477b6bfd9 100644 --- a/agent/utils/utils_test.go +++ b/agent/utils/utils_test.go @@ -162,6 +162,31 @@ func TestIsAWSErrorCodeEqual(t *testing.T) { } } +func TestGetRequestFailureStatusCode(t *testing.T) { + testcases := []struct { + name string + err error + res int + }{ + { + name: "TestGetRequestFailureStatusCodeSuccess", + err: awserr.NewRequestFailure(awserr.Error(awserr.New("BadRequest", "", errors.New(""))), 400, ""), + res: 400, + }, + { + name: "TestGetRequestFailureStatusCodeWrongErrType", + err: errors.New("err"), + res: 0, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.res, GetRequestFailureStatusCode(tc.err)) + }) + } +} + func TestMapToTags(t *testing.T) { tagKey1 := "tagKey1" tagKey2 := "tagKey2" diff --git a/agent/version/version.go b/agent/version/version.go index 7efe0b74ba6..1e674b8546c 100644 --- a/agent/version/version.go +++ b/agent/version/version.go @@ -22,10 +22,10 @@ package version // repository. Only the 'Version' const should change in checked-in source code // Version is the version of the Agent -const Version = "1.58.0" +const Version = "1.59.0" // GitDirty indicates the cleanliness of the git repo when this agent was built const GitDirty = true // GitShortHash is the short hash of this agent build -const GitShortHash = "6ff5d76d" +const GitShortHash = "3909ba91"