From 3650e1f25554a1b49a3e2d3c07948383f12820b5 Mon Sep 17 00:00:00 2001 From: Cameron Sparr Date: Thu, 3 Feb 2022 11:28:55 -0800 Subject: [PATCH] DiscoverPollEndpoint: lengthen cache ttl and improve resiliency (#3109) The acs/tacs endpoints that ecs agent connects to have never changed before, so the current behavior of calling the API every 20 minutes is not necessary. Lengthening this cache TTL to 12 hours will reduce load on the ECS service and reduce the chance of customers being throttled. Additionally, the LRU cache we were using automatically evicted any value once it expired. This means that this API call could become a source of complete ecs agent failure in the event of an LSE, if agent threw away its cached endpoint and then failed to get a new one via DiscoverPollEndpoint. By keeping expired values in the cache, we can fall back to using the expired value in the event that the DiscoverPollEndpoint API is failing, thus keeping agent connected to ACS and TACS. --- agent/api/ecsclient/client.go | 29 +++++++---- agent/api/ecsclient/client_test.go | 67 ++++++++++++++++++------ agent/async/generate_mocks.go | 2 +- agent/async/mocks/async_mocks.go | 65 ++++++++++++++++++++++- agent/async/ttl_cache.go | 80 ++++++++++++++++++++++++++++ agent/async/ttl_cache_test.go | 83 ++++++++++++++++++++++++++++++ 6 files changed, 298 insertions(+), 28 deletions(-) create mode 100644 agent/async/ttl_cache.go create mode 100644 agent/async/ttl_cache_test.go diff --git a/agent/api/ecsclient/client.go b/agent/api/ecsclient/client.go index cf3cecd29e0..cb5c357d5a6 100644 --- a/agent/api/ecsclient/client.go +++ b/agent/api/ecsclient/client.go @@ -40,8 +40,7 @@ const ( ecsMaxImageDigestLength = 255 ecsMaxReasonLength = 255 ecsMaxRuntimeIDLength = 255 - pollEndpointCacheSize = 1 - pollEndpointCacheTTL = 20 * time.Minute + pollEndpointCacheTTL = 12 * time.Hour roundtripTimeout = 5 * time.Second azAttrName = "ecs.availability-zone" cpuArchAttrName = "ecs.cpu-architecture" @@ 
-56,7 +55,7 @@ type APIECSClient struct { standardClient api.ECSSDK submitStateChangeClient api.ECSSubmitStateSDK ec2metadata ec2.EC2MetadataClient - pollEndpoinCache async.Cache + pollEndpointCache async.TTLCache } // NewECSClient creates a new ECSClient interface object @@ -74,14 +73,13 @@ func NewECSClient( } standardClient := ecs.New(session.New(&ecsConfig)) submitStateChangeClient := newSubmitStateChangeClient(&ecsConfig) - pollEndpoinCache := async.NewLRUCache(pollEndpointCacheSize, pollEndpointCacheTTL) return &APIECSClient{ credentialProvider: credentialProvider, config: config, standardClient: standardClient, submitStateChangeClient: submitStateChangeClient, ec2metadata: ec2MetadataClient, - pollEndpoinCache: pollEndpoinCache, + pollEndpointCache: async.NewTTLCache(pollEndpointCacheTTL), } } @@ -585,26 +583,37 @@ func (client *APIECSClient) DiscoverTelemetryEndpoint(containerInstanceArn strin func (client *APIECSClient) discoverPollEndpoint(containerInstanceArn string) (*ecs.DiscoverPollEndpointOutput, error) { // Try getting an entry from the cache - cachedEndpoint, found := client.pollEndpoinCache.Get(containerInstanceArn) - if found { - // Cache hit. Return the output. + cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn) + if !expired && found { + // Cache hit and not expired. Return the output. if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Using cached DiscoverPollEndpoint. endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) return output, nil } } - // Cache miss, invoke the ECS DiscoverPollEndpoint API. + // Cache miss or expired, invoke the ECS DiscoverPollEndpoint API. 
seelog.Debugf("Invoking DiscoverPollEndpoint for '%s'", containerInstanceArn) output, err := client.standardClient.DiscoverPollEndpoint(&ecs.DiscoverPollEndpointInput{ ContainerInstance: &containerInstanceArn, Cluster: &client.config.Cluster, }) if err != nil { + // if we got an error calling the API, fallback to an expired cached endpoint if + // we have it. + if expired { + if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Error calling DiscoverPollEndpoint. Using cached but expired endpoint as a fallback. error=%s endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + err, aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) + return output, nil + } + } return nil, err } // Cache the response from ECS. - client.pollEndpoinCache.Set(containerInstanceArn, output) + client.pollEndpointCache.Set(containerInstanceArn, output) return output, nil } diff --git a/agent/api/ecsclient/client_test.go b/agent/api/ecsclient/client_test.go index 2457c80919d..21fedd0ba82 100644 --- a/agent/api/ecsclient/client_test.go +++ b/agent/api/ecsclient/client_test.go @@ -838,23 +838,23 @@ func TestDiscoverPollEndpointCacheHit(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" - pollEndpoinCache.EXPECT().Get("containerInstance").Return( + pollEndpointCache.EXPECT().Get("containerInstance").Return( &ecs.DiscoverPollEndpointOutput{ 
Endpoint: aws.String(pollEndpoint), - }, true) + }, false, true) output, err := client.discoverPollEndpoint("containerInstance") if err != nil { t.Fatalf("Error in discoverPollEndpoint: %v", err) @@ -869,16 +869,16 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ @@ -886,9 +886,44 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { } gomock.InOrder( - pollEndpoinCache.EXPECT().Get("containerInstance").Return(nil, false), + pollEndpointCache.EXPECT().Get("containerInstance").Return(nil, false, false), mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil), - pollEndpoinCache.EXPECT().Set("containerInstance", pollEndpointOutput), + pollEndpointCache.EXPECT().Set("containerInstance", pollEndpointOutput), + ) + + output, err := client.discoverPollEndpoint("containerInstance") + if err != nil { + t.Fatalf("Error in discoverPollEndpoint: %v", err) + } + if aws.StringValue(output.Endpoint) != pollEndpoint { + t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) + } +} + +func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockSDK := mock_api.NewMockECSSDK(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) + client := &APIECSClient{ 
+ credentialProvider: credentials.AnonymousCredentials, + config: &config.Config{ + Cluster: configuredCluster, + AWSRegion: "us-east-1", + }, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, + } + pollEndpoint := "http://127.0.0.1" + pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ + Endpoint: &pollEndpoint, + } + + gomock.InOrder( + pollEndpointCache.EXPECT().Get("containerInstance").Return(pollEndpointOutput, true, false), + mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")), ) output, err := client.discoverPollEndpoint("containerInstance") @@ -905,16 +940,16 @@ func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := async.NewLRUCache(1, 10*time.Minute) + pollEndpointCache := async.NewTTLCache(10 * time.Minute) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" diff --git a/agent/async/generate_mocks.go b/agent/async/generate_mocks.go index 04eeabc3f52..4b424471422 100644 --- a/agent/async/generate_mocks.go +++ b/agent/async/generate_mocks.go @@ -13,4 +13,4 @@ package async -//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache +//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache,TTLCache diff --git a/agent/async/mocks/async_mocks.go b/agent/async/mocks/async_mocks.go index 
45e5c29900d..02c61092176 100644 --- a/agent/async/mocks/async_mocks.go +++ b/agent/async/mocks/async_mocks.go @@ -13,7 +13,7 @@ // // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache) +// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache,TTLCache) // Package mock_async is a generated GoMock package. package mock_async @@ -86,3 +86,66 @@ func (mr *MockCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCache)(nil).Set), arg0, arg1) } + +// MockTTLCache is a mock of TTLCache interface +type MockTTLCache struct { + ctrl *gomock.Controller + recorder *MockTTLCacheMockRecorder +} + +// MockTTLCacheMockRecorder is the mock recorder for MockTTLCache +type MockTTLCacheMockRecorder struct { + mock *MockTTLCache +} + +// NewMockTTLCache creates a new mock instance +func NewMockTTLCache(ctrl *gomock.Controller) *MockTTLCache { + mock := &MockTTLCache{ctrl: ctrl} + mock.recorder = &MockTTLCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTTLCache) EXPECT() *MockTTLCacheMockRecorder { + return m.recorder +} + +// Delete mocks base method +func (m *MockTTLCache) Delete(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Delete", arg0) +} + +// Delete indicates an expected call of Delete +func (mr *MockTTLCacheMockRecorder) Delete(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockTTLCache)(nil).Delete), arg0) +} + +// Get mocks base method +func (m *MockTTLCache) Get(arg0 string) (interface{}, bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// Get indicates an 
expected call of Get +func (mr *MockTTLCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockTTLCache)(nil).Get), arg0) +} + +// Set mocks base method +func (m *MockTTLCache) Set(arg0 string, arg1 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Set", arg0, arg1) +} + +// Set indicates an expected call of Set +func (mr *MockTTLCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockTTLCache)(nil).Set), arg0, arg1) +} diff --git a/agent/async/ttl_cache.go b/agent/async/ttl_cache.go new file mode 100644 index 00000000000..d90aa175c02 --- /dev/null +++ b/agent/async/ttl_cache.go @@ -0,0 +1,80 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "sync" + "time" +) + +type TTLCache interface { + // Get fetches a value from cache, returns nil, false on miss + Get(key string) (value interface{}, expired bool, ok bool) + // Set sets a value in cache. overrites any existing value + Set(key string, value interface{}) + // Delete deletes the value from the cache + Delete(key string) +} + +// Creates a TTL cache with ttl for items. 
+func NewTTLCache(ttl time.Duration) TTLCache { + return &ttlCache{ + ttl: ttl, + cache: make(map[string]*ttlCacheEntry), + } +} + +type ttlCacheEntry struct { + value interface{} + expiry time.Time +} + +type ttlCache struct { + mu sync.RWMutex + cache map[string]*ttlCacheEntry + ttl time.Duration +} + +// Get returns the value associated with the key. +// returns if the item is expired (true if key is expired). +// ok result indicates whether value was found in the map. +// Note that items are not automatically deleted from the map when they expire. They will continue to be +// returned with expired=true. +func (t *ttlCache) Get(key string) (value interface{}, expired bool, ok bool) { + t.mu.RLock() + defer t.mu.RUnlock() + if _, iok := t.cache[key]; !iok { + return nil, false, false + } + entry := t.cache[key] + expired = time.Now().After(entry.expiry) + return entry.value, expired, true +} + +// Set sets the key-value pair in the cache +func (t *ttlCache) Set(key string, value interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + t.cache[key] = &ttlCacheEntry{ + value: value, + expiry: time.Now().Add(t.ttl), + } +} + +// Delete removes the entry associated with the key from cache +func (t *ttlCache) Delete(key string) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.cache, key) +} diff --git a/agent/async/ttl_cache_test.go b/agent/async/ttl_cache_test.go new file mode 100644 index 00000000000..0a319d77b28 --- /dev/null +++ b/agent/async/ttl_cache_test.go @@ -0,0 +1,83 @@ +//go:build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. 
See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTTLSimple(t *testing.T) { + ttl := NewTTLCache(time.Minute) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + baz, expired, ok := ttl.Get("fooz") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, baz) + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLSetDelete(t *testing.T) { + ttl := NewTTLCache(time.Minute) + + ttl.Set("foo", "bar") + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + ttl.Set("foo", "bar2") + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar2") + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLCache(t *testing.T) { + ttl := NewTTLCache(50 * time.Millisecond) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.False(t, expired) + require.True(t, ok) + require.Equal(t, bar, "bar") + + time.Sleep(100 * time.Millisecond) + + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.True(t, expired) + require.Equal(t, bar, "bar") +}