Skip to content

Commit

Permalink
DiscoverPollEndpoint: lengthen cache ttl and improve resiliency (#3109)
Browse files Browse the repository at this point in the history
The acs/tacs endpoints that ecs agent connects to have never changed
before, so the current behavior of calling the API every 20 minutes is
not necessary. Lengthening this cache TTL to 12 hours will reduce load
on the ECS service and reduce the chance of customers being throttled.

Additionally, the LRU cache we were using automatically evicted any
value once it expired. This means that this API call could become a
source of complete ecs agent failure in the event of an LSE, if the
agent threw away its cached endpoint and then failed to get
a new one via DiscoverPollEndpoint.

By keeping expired values in the cache, we can fall back to using the
expired value in the event that the DiscoverPollEndpoint API is failing,
thus keeping agent connected to ACS and TACS.
  • Loading branch information
sparrc authored Feb 3, 2022
1 parent 61668df commit 3650e1f
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 28 deletions.
29 changes: 19 additions & 10 deletions agent/api/ecsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ const (
ecsMaxImageDigestLength = 255
ecsMaxReasonLength = 255
ecsMaxRuntimeIDLength = 255
pollEndpointCacheSize = 1
pollEndpointCacheTTL = 20 * time.Minute
pollEndpointCacheTTL = 12 * time.Hour
roundtripTimeout = 5 * time.Second
azAttrName = "ecs.availability-zone"
cpuArchAttrName = "ecs.cpu-architecture"
Expand All @@ -56,7 +55,7 @@ type APIECSClient struct {
standardClient api.ECSSDK
submitStateChangeClient api.ECSSubmitStateSDK
ec2metadata ec2.EC2MetadataClient
pollEndpoinCache async.Cache
pollEndpointCache async.TTLCache
}

// NewECSClient creates a new ECSClient interface object
Expand All @@ -74,14 +73,13 @@ func NewECSClient(
}
standardClient := ecs.New(session.New(&ecsConfig))
submitStateChangeClient := newSubmitStateChangeClient(&ecsConfig)
pollEndpoinCache := async.NewLRUCache(pollEndpointCacheSize, pollEndpointCacheTTL)
return &APIECSClient{
credentialProvider: credentialProvider,
config: config,
standardClient: standardClient,
submitStateChangeClient: submitStateChangeClient,
ec2metadata: ec2MetadataClient,
pollEndpoinCache: pollEndpoinCache,
pollEndpointCache: async.NewTTLCache(pollEndpointCacheTTL),
}
}

Expand Down Expand Up @@ -585,26 +583,37 @@ func (client *APIECSClient) DiscoverTelemetryEndpoint(containerInstanceArn strin

func (client *APIECSClient) discoverPollEndpoint(containerInstanceArn string) (*ecs.DiscoverPollEndpointOutput, error) {
// Try getting an entry from the cache
cachedEndpoint, found := client.pollEndpoinCache.Get(containerInstanceArn)
if found {
// Cache hit. Return the output.
cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn)
if !expired && found {
// Cache hit and not expired. Return the output.
if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok {
seelog.Infof("Using cached DiscoverPollEndpoint. endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s",
aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn)
return output, nil
}
}

// Cache miss, invoke the ECS DiscoverPollEndpoint API.
// Cache miss or expired, invoke the ECS DiscoverPollEndpoint API.
seelog.Debugf("Invoking DiscoverPollEndpoint for '%s'", containerInstanceArn)
output, err := client.standardClient.DiscoverPollEndpoint(&ecs.DiscoverPollEndpointInput{
ContainerInstance: &containerInstanceArn,
Cluster: &client.config.Cluster,
})
if err != nil {
// if we got an error calling the API, fallback to an expired cached endpoint if
// we have it.
if expired {
if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok {
seelog.Infof("Error calling DiscoverPollEndpoint. Using cached but expired endpoint as a fallback. error=%s endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s",
err, aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn)
return output, nil
}
}
return nil, err
}

// Cache the response from ECS.
client.pollEndpoinCache.Set(containerInstanceArn, output)
client.pollEndpointCache.Set(containerInstanceArn, output)
return output, nil
}

Expand Down
67 changes: 51 additions & 16 deletions agent/api/ecsclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,23 +838,23 @@ func TestDiscoverPollEndpointCacheHit(t *testing.T) {
defer mockCtrl.Finish()

mockSDK := mock_api.NewMockECSSDK(mockCtrl)
pollEndpoinCache := mock_async.NewMockCache(mockCtrl)
pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl)
client := &APIECSClient{
credentialProvider: credentials.AnonymousCredentials,
config: &config.Config{
Cluster: configuredCluster,
AWSRegion: "us-east-1",
},
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpoinCache: pollEndpoinCache,
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpointCache: pollEndpointCache,
}

pollEndpoint := "http://127.0.0.1"
pollEndpoinCache.EXPECT().Get("containerInstance").Return(
pollEndpointCache.EXPECT().Get("containerInstance").Return(
&ecs.DiscoverPollEndpointOutput{
Endpoint: aws.String(pollEndpoint),
}, true)
}, false, true)
output, err := client.discoverPollEndpoint("containerInstance")
if err != nil {
t.Fatalf("Error in discoverPollEndpoint: %v", err)
Expand All @@ -869,26 +869,61 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) {
defer mockCtrl.Finish()

mockSDK := mock_api.NewMockECSSDK(mockCtrl)
pollEndpoinCache := mock_async.NewMockCache(mockCtrl)
pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl)
client := &APIECSClient{
credentialProvider: credentials.AnonymousCredentials,
config: &config.Config{
Cluster: configuredCluster,
AWSRegion: "us-east-1",
},
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpoinCache: pollEndpoinCache,
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpointCache: pollEndpointCache,
}
pollEndpoint := "http://127.0.0.1"
pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
}

gomock.InOrder(
pollEndpoinCache.EXPECT().Get("containerInstance").Return(nil, false),
pollEndpointCache.EXPECT().Get("containerInstance").Return(nil, false, false),
mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil),
pollEndpoinCache.EXPECT().Set("containerInstance", pollEndpointOutput),
pollEndpointCache.EXPECT().Set("containerInstance", pollEndpointOutput),
)

output, err := client.discoverPollEndpoint("containerInstance")
if err != nil {
t.Fatalf("Error in discoverPollEndpoint: %v", err)
}
if aws.StringValue(output.Endpoint) != pollEndpoint {
t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint)
}
}

func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockSDK := mock_api.NewMockECSSDK(mockCtrl)
pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl)
client := &APIECSClient{
credentialProvider: credentials.AnonymousCredentials,
config: &config.Config{
Cluster: configuredCluster,
AWSRegion: "us-east-1",
},
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpointCache: pollEndpointCache,
}
pollEndpoint := "http://127.0.0.1"
pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
}

gomock.InOrder(
pollEndpointCache.EXPECT().Get("containerInstance").Return(pollEndpointOutput, true, false),
mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")),
)

output, err := client.discoverPollEndpoint("containerInstance")
Expand All @@ -905,16 +940,16 @@ func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) {
defer mockCtrl.Finish()

mockSDK := mock_api.NewMockECSSDK(mockCtrl)
pollEndpoinCache := async.NewLRUCache(1, 10*time.Minute)
pollEndpointCache := async.NewTTLCache(10 * time.Minute)
client := &APIECSClient{
credentialProvider: credentials.AnonymousCredentials,
config: &config.Config{
Cluster: configuredCluster,
AWSRegion: "us-east-1",
},
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpoinCache: pollEndpoinCache,
standardClient: mockSDK,
ec2metadata: ec2.NewBlackholeEC2MetadataClient(),
pollEndpointCache: pollEndpointCache,
}

pollEndpoint := "http://127.0.0.1"
Expand Down
2 changes: 1 addition & 1 deletion agent/async/generate_mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

package async

//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache
//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache,TTLCache
65 changes: 64 additions & 1 deletion agent/async/mocks/async_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 80 additions & 0 deletions agent/async/ttl_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package async

import (
"sync"
"time"
)

// TTLCache is a string-keyed cache whose Get reports, alongside the value,
// whether that value has expired.
type TTLCache interface {
// Get fetches a value from cache. ok is false on a miss, in which case
// value is nil and expired is false. On a hit, expired reports whether
// the stored value has expired.
Get(key string) (value interface{}, expired bool, ok bool)
// Set sets a value in cache, overwriting any existing value.
Set(key string, value interface{})
// Delete deletes the value from the cache.
Delete(key string)
}

// NewTTLCache creates an empty TTL cache that uses ttl as the lifetime of its items.
func NewTTLCache(ttl time.Duration) TTLCache {
	c := new(ttlCache)
	c.ttl = ttl
	c.cache = map[string]*ttlCacheEntry{}
	return c
}

// ttlCacheEntry is a cached value together with the instant at which it expires.
type ttlCacheEntry struct {
	value  interface{}
	expiry time.Time
}

// ttlCache is a map-backed cache guarded by a read/write mutex. Each entry is
// stamped at Set time with an expiry of ttl past the current time.
type ttlCache struct {
	mu    sync.RWMutex
	cache map[string]*ttlCacheEntry
	ttl   time.Duration
}

// Get returns the value associated with the key.
// expired is true when the entry's expiry time has already passed.
// ok indicates whether the key was found in the map; on a miss Get returns
// nil, false, false.
// Note that items are not automatically deleted from the map when they expire. They will continue to be
// returned with expired=true.
func (t *ttlCache) Get(key string) (value interface{}, expired bool, ok bool) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	// A single comma-ok lookup yields both presence and the entry, avoiding
	// a second map access for the same key.
	entry, found := t.cache[key]
	if !found {
		return nil, false, false
	}
	return entry.value, time.Now().After(entry.expiry), true
}

// Set stores the key-value pair in the cache, replacing any existing entry
// for the key and giving the new entry an expiry of ttl from now.
func (t *ttlCache) Set(key string, value interface{}) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.cache[key] = &ttlCacheEntry{
		value:  value,
		expiry: time.Now().Add(t.ttl),
	}
}

// Delete removes the entry associated with the key from cache. Deleting a
// key that is not present is a no-op.
func (t *ttlCache) Delete(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.cache, key)
}
Loading

0 comments on commit 3650e1f

Please sign in to comment.