diff --git a/agent/api/ecsclient/client.go b/agent/api/ecsclient/client.go index cf3cecd29e0..cb5c357d5a6 100644 --- a/agent/api/ecsclient/client.go +++ b/agent/api/ecsclient/client.go @@ -40,8 +40,7 @@ const ( ecsMaxImageDigestLength = 255 ecsMaxReasonLength = 255 ecsMaxRuntimeIDLength = 255 - pollEndpointCacheSize = 1 - pollEndpointCacheTTL = 20 * time.Minute + pollEndpointCacheTTL = 12 * time.Hour roundtripTimeout = 5 * time.Second azAttrName = "ecs.availability-zone" cpuArchAttrName = "ecs.cpu-architecture" @@ -56,7 +55,7 @@ type APIECSClient struct { standardClient api.ECSSDK submitStateChangeClient api.ECSSubmitStateSDK ec2metadata ec2.EC2MetadataClient - pollEndpoinCache async.Cache + pollEndpointCache async.TTLCache } // NewECSClient creates a new ECSClient interface object @@ -74,14 +73,13 @@ func NewECSClient( } standardClient := ecs.New(session.New(&ecsConfig)) submitStateChangeClient := newSubmitStateChangeClient(&ecsConfig) - pollEndpoinCache := async.NewLRUCache(pollEndpointCacheSize, pollEndpointCacheTTL) return &APIECSClient{ credentialProvider: credentialProvider, config: config, standardClient: standardClient, submitStateChangeClient: submitStateChangeClient, ec2metadata: ec2MetadataClient, - pollEndpoinCache: pollEndpoinCache, + pollEndpointCache: async.NewTTLCache(pollEndpointCacheTTL), } } @@ -585,26 +583,37 @@ func (client *APIECSClient) DiscoverTelemetryEndpoint(containerInstanceArn strin func (client *APIECSClient) discoverPollEndpoint(containerInstanceArn string) (*ecs.DiscoverPollEndpointOutput, error) { // Try getting an entry from the cache - cachedEndpoint, found := client.pollEndpoinCache.Get(containerInstanceArn) - if found { - // Cache hit. Return the output. + cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn) + if !expired && found { + // Cache hit and not expired. Return the output. if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Using cached DiscoverPollEndpoint. endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) return output, nil } } - // Cache miss, invoke the ECS DiscoverPollEndpoint API. + // Cache miss or expired, invoke the ECS DiscoverPollEndpoint API. seelog.Debugf("Invoking DiscoverPollEndpoint for '%s'", containerInstanceArn) output, err := client.standardClient.DiscoverPollEndpoint(&ecs.DiscoverPollEndpointInput{ ContainerInstance: &containerInstanceArn, Cluster: &client.config.Cluster, }) if err != nil { + // if we got an error calling the API, fallback to an expired cached endpoint if + // we have it. + if expired { + if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { + seelog.Infof("Error calling DiscoverPollEndpoint. Using cached but expired endpoint as a fallback. error=%s endpoint=%s telemetryEndpoint=%s containerInstanceARN=%s", + err, aws.StringValue(output.Endpoint), aws.StringValue(output.TelemetryEndpoint), containerInstanceArn) + return output, nil + } + } return nil, err } // Cache the response from ECS. - client.pollEndpoinCache.Set(containerInstanceArn, output) + client.pollEndpointCache.Set(containerInstanceArn, output) return output, nil } diff --git a/agent/api/ecsclient/client_test.go b/agent/api/ecsclient/client_test.go index 2457c80919d..21fedd0ba82 100644 --- a/agent/api/ecsclient/client_test.go +++ b/agent/api/ecsclient/client_test.go @@ -838,23 +838,23 @@ func TestDiscoverPollEndpointCacheHit(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" - pollEndpoinCache.EXPECT().Get("containerInstance").Return( + pollEndpointCache.EXPECT().Get("containerInstance").Return( &ecs.DiscoverPollEndpointOutput{ Endpoint: aws.String(pollEndpoint), - }, true) + }, false, true) output, err := client.discoverPollEndpoint("containerInstance") if err != nil { t.Fatalf("Error in discoverPollEndpoint: %v", err) @@ -869,16 +869,16 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := mock_async.NewMockCache(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ @@ -886,9 +886,44 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) { } gomock.InOrder( - pollEndpoinCache.EXPECT().Get("containerInstance").Return(nil, false), + pollEndpointCache.EXPECT().Get("containerInstance").Return(nil, false, false), mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil), - pollEndpoinCache.EXPECT().Set("containerInstance", pollEndpointOutput), + pollEndpointCache.EXPECT().Set("containerInstance", pollEndpointOutput), + ) + + output, err := client.discoverPollEndpoint("containerInstance") + if err != nil { + t.Fatalf("Error in discoverPollEndpoint: %v", err) + } + if aws.StringValue(output.Endpoint) != pollEndpoint { + t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) + } +} + +func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockSDK := mock_api.NewMockECSSDK(mockCtrl) + pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) + client := &APIECSClient{ + credentialProvider: credentials.AnonymousCredentials, + config: &config.Config{ + Cluster: configuredCluster, + AWSRegion: "us-east-1", + }, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, + } + pollEndpoint := "http://127.0.0.1" + pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ + Endpoint: &pollEndpoint, + } + + gomock.InOrder( + pollEndpointCache.EXPECT().Get("containerInstance").Return(pollEndpointOutput, true, false), + mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")), ) output, err := client.discoverPollEndpoint("containerInstance") @@ -905,16 +940,16 @@ func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) { defer mockCtrl.Finish() mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpoinCache := async.NewLRUCache(1, 10*time.Minute) + pollEndpointCache := async.NewTTLCache(10 * time.Minute) client := &APIECSClient{ credentialProvider: credentials.AnonymousCredentials, config: &config.Config{ Cluster: configuredCluster, AWSRegion: "us-east-1", }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpoinCache: pollEndpoinCache, + standardClient: mockSDK, + ec2metadata: ec2.NewBlackholeEC2MetadataClient(), + pollEndpointCache: pollEndpointCache, } pollEndpoint := "http://127.0.0.1" diff --git a/agent/async/generate_mocks.go b/agent/async/generate_mocks.go index 04eeabc3f52..4b424471422 100644 --- a/agent/async/generate_mocks.go +++ b/agent/async/generate_mocks.go @@ -13,4 +13,4 @@ package async -//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache +//go:generate mockgen -destination=mocks/async_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/async Cache,TTLCache diff --git a/agent/async/mocks/async_mocks.go b/agent/async/mocks/async_mocks.go index 45e5c29900d..02c61092176 100644 --- a/agent/async/mocks/async_mocks.go +++ b/agent/async/mocks/async_mocks.go @@ -13,7 +13,7 @@ // // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache) +// Source: github.com/aws/amazon-ecs-agent/agent/async (interfaces: Cache,TTLCache) // Package mock_async is a generated GoMock package. package mock_async @@ -86,3 +86,66 @@ func (mr *MockCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCache)(nil).Set), arg0, arg1) } + +// MockTTLCache is a mock of TTLCache interface +type MockTTLCache struct { + ctrl *gomock.Controller + recorder *MockTTLCacheMockRecorder +} + +// MockTTLCacheMockRecorder is the mock recorder for MockTTLCache +type MockTTLCacheMockRecorder struct { + mock *MockTTLCache +} + +// NewMockTTLCache creates a new mock instance +func NewMockTTLCache(ctrl *gomock.Controller) *MockTTLCache { + mock := &MockTTLCache{ctrl: ctrl} + mock.recorder = &MockTTLCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTTLCache) EXPECT() *MockTTLCacheMockRecorder { + return m.recorder +} + +// Delete mocks base method +func (m *MockTTLCache) Delete(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Delete", arg0) +} + +// Delete indicates an expected call of Delete +func (mr *MockTTLCacheMockRecorder) Delete(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockTTLCache)(nil).Delete), arg0) +} + +// Get mocks base method +func (m *MockTTLCache) Get(arg0 string) (interface{}, bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// Get indicates an expected call of Get +func (mr *MockTTLCacheMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockTTLCache)(nil).Get), arg0) +} + +// Set mocks base method +func (m *MockTTLCache) Set(arg0 string, arg1 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Set", arg0, arg1) +} + +// Set indicates an expected call of Set +func (mr *MockTTLCacheMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockTTLCache)(nil).Set), arg0, arg1) +} diff --git a/agent/async/ttl_cache.go b/agent/async/ttl_cache.go new file mode 100644 index 00000000000..d90aa175c02 --- /dev/null +++ b/agent/async/ttl_cache.go @@ -0,0 +1,80 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "sync" + "time" +) + +type TTLCache interface { + // Get fetches a value from cache, returns nil, false on miss + Get(key string) (value interface{}, expired bool, ok bool) + // Set sets a value in cache. overrites any existing value + Set(key string, value interface{}) + // Delete deletes the value from the cache + Delete(key string) +} + +// Creates a TTL cache with ttl for items. +func NewTTLCache(ttl time.Duration) TTLCache { + return &ttlCache{ + ttl: ttl, + cache: make(map[string]*ttlCacheEntry), + } +} + +type ttlCacheEntry struct { + value interface{} + expiry time.Time +} + +type ttlCache struct { + mu sync.RWMutex + cache map[string]*ttlCacheEntry + ttl time.Duration +} + +// Get returns the value associated with the key. +// returns if the item is expired (true if key is expired). +// ok result indicates whether value was found in the map. +// Note that items are not automatically deleted from the map when they expire. They will continue to be +// returned with expired=true. +func (t *ttlCache) Get(key string) (value interface{}, expired bool, ok bool) { + t.mu.RLock() + defer t.mu.RUnlock() + if _, iok := t.cache[key]; !iok { + return nil, false, false + } + entry := t.cache[key] + expired = time.Now().After(entry.expiry) + return entry.value, expired, true +} + +// Set sets the key-value pair in the cache +func (t *ttlCache) Set(key string, value interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + t.cache[key] = &ttlCacheEntry{ + value: value, + expiry: time.Now().Add(t.ttl), + } +} + +// Delete removes the entry associated with the key from cache +func (t *ttlCache) Delete(key string) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.cache, key) +} diff --git a/agent/async/ttl_cache_test.go b/agent/async/ttl_cache_test.go new file mode 100644 index 00000000000..0a319d77b28 --- /dev/null +++ b/agent/async/ttl_cache_test.go @@ -0,0 +1,83 @@ +//go:build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package async + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTTLSimple(t *testing.T) { + ttl := NewTTLCache(time.Minute) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + baz, expired, ok := ttl.Get("fooz") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, baz) + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLSetDelete(t *testing.T) { + ttl := NewTTLCache(time.Minute) + + ttl.Set("foo", "bar") + bar, expired, ok := ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar") + + ttl.Set("foo", "bar2") + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.False(t, expired) + require.Equal(t, bar, "bar2") + + ttl.Delete("foo") + bar, expired, ok = ttl.Get("foo") + require.False(t, ok) + require.False(t, expired) + require.Nil(t, bar) +} + +func TestTTLCache(t *testing.T) { + ttl := NewTTLCache(50 * time.Millisecond) + ttl.Set("foo", "bar") + + bar, expired, ok := ttl.Get("foo") + require.False(t, expired) + require.True(t, ok) + require.Equal(t, bar, "bar") + + time.Sleep(100 * time.Millisecond) + + bar, expired, ok = ttl.Get("foo") + require.True(t, ok) + require.True(t, expired) + require.Equal(t, bar, "bar") +}