From 67ca4976560ab33d9d3d0b37273b25e852537050 Mon Sep 17 00:00:00 2001 From: Anuj Singh Date: Tue, 21 Feb 2023 22:00:56 -0800 Subject: [PATCH] periodically disconnect from acs --- agent/acs/handler/acs_handler.go | 48 ++++++++++++---- agent/acs/handler/acs_handler_test.go | 81 ++++++++++++++++++++++++--- agent/wsclient/client.go | 20 +++++-- agent/wsclient/mock/client.go | 14 +++++ agent/wsclient/wsconn/conn.go | 1 + agent/wsclient/wsconn/mock/conn.go | 14 +++++ 6 files changed, 156 insertions(+), 22 deletions(-) diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index f65ccfaba3a..2c518f538b6 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -39,6 +39,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/utils/ttime" "github.com/aws/amazon-ecs-agent/agent/version" "github.com/aws/amazon-ecs-agent/agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/cihub/seelog" ) @@ -54,6 +55,10 @@ const ( inactiveInstanceReconnectDelay = 1 * time.Hour + // connectionTime is the maximum time after which agent closes its connection to ACS + connectionTime = 15 * time.Minute + connectionJitter = 30 * time.Minute + connectionBackoffMin = 250 * time.Millisecond connectionBackoffMax = 2 * time.Minute connectionBackoffJitter = 0.2 @@ -100,6 +105,8 @@ type session struct { doctor *doctor.Doctor _heartbeatTimeout time.Duration _heartbeatJitter time.Duration + connectionTime time.Duration + connectionJitter time.Duration _inactiveInstanceReconnectDelay time.Duration } @@ -107,7 +114,7 @@ type session struct { // a session with ACS. This interface is intended to define methods // that create resources used to establish the connection to ACS // It is confined to just the createACSClient() method for now. It can be -// extended to include the acsWsURL() and newDisconnectionTimer() methods +// extended to include the acsWsURL() and newHeartbeatTimer() methods // when needed // The goal is to make it easier to test and inject dependencies type sessionResources interface { @@ -182,6 +189,8 @@ func NewSession( doctor: doctor, _heartbeatTimeout: heartbeatTimeout, _heartbeatJitter: heartbeatJitter, + connectionTime: connectionTime, + connectionJitter: connectionJitter, _inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, } } @@ -224,7 +233,7 @@ func (acsSession *session) Start() error { } } if shouldReconnectWithoutBackoff(acsError) { - // If ACS closed the connection, there's no need to backoff, + // If ACS or agent closed the connection, there's no need to backoff, // reconnect immediately seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError) acsSession.backoff.Reset() @@ -360,11 +369,15 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { } seelog.Info("Connected to ACS endpoint") - // Start inactivity timer for closing the connection - timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) - // Any message from the server resets the disconnect timeout - client.SetAnyRequestHandler(anyMessageHandler(timer, client)) - defer timer.Stop() + // Start a connection timer; agent will close its ACS websocket connection after this timer expires + connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter) + defer connectionTimer.Stop() + + // Start a heartbeat timer for closing the connection + heartbeatTimer := newHeartbeatTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) + // Any message from the server resets the heartbeat timer + client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client)) + defer heartbeatTimer.Stop() acsSession.resources.connectedToACS() @@ -393,7 +406,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { case err := <-serveErr: // Stop receiving and sending messages from and to ACS when // client.Serve returns an error. This can happen when the - // the connection is closed by ACS or the agent + // connection is closed by ACS or the agent if err == nil || err == io.EOF { seelog.Info("ACS Websocket connection closed for a valid reason") } else { @@ -478,9 +491,9 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine. return acsURL + "?" + query.Encode() } -// newDisconnectionTimer creates a new time object, with a callback to +// newHeartbeatTimer creates a new time object, with a callback to // disconnect from ACS on inactivity -func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { +func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { timer := time.AfterFunc(retry.AddJitter(timeout, jitter), func() { seelog.Warn("ACS Connection hasn't had any activity for too long; closing connection") if err := client.Close(); err != nil { @@ -492,6 +505,21 @@ func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, return timer } +// newConnectionTimer creates a new timer, after which agent closes its ACS websocket connection +func newConnectionTimer(client wsclient.ClientServer, connectionTime time.Duration, connectionJitter time.Duration) ttime.Timer { + expiresAt := retry.AddJitter(connectionTime, connectionJitter) + timer := time.AfterFunc(expiresAt, func() { + seelog.Infof("Closing ACS websocket connection after %v minutes", expiresAt.Minutes()) + // WriteCloseMessage() writes a close message using websocket control messages + // Ref: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages + err := client.WriteCloseMessage() + if err != nil { + seelog.Warnf("Error writing close message: %v", err) + } + }) + return timer +} + // anyMessageHandler handles any server message. Any server message means the // connection is active and thus the heartbeat disconnect should not occur func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) { diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index ed25e268d55..94428076614 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -16,6 +16,7 @@ package handler import ( + "context" "fmt" "io" "net/http" @@ -30,8 +31,6 @@ import ( "testing" "time" - "context" - apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" @@ -45,19 +44,18 @@ import ( mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/eventstream" - "github.com/aws/amazon-ecs-agent/agent/utils/retry" mock_retry "github.com/aws/amazon-ecs-agent/agent/utils/retry/mock" "github.com/aws/amazon-ecs-agent/agent/version" "github.com/aws/amazon-ecs-agent/agent/wsclient" mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/golang/mock/gomock" "github.com/gorilla/websocket" "github.com/pkg/errors" "github.com/stretchr/testify/assert" - - "github.com/golang/mock/gomock" ) const ( @@ -210,7 +208,7 @@ func TestHandlerReconnectsOnConnectErrors(t *testing.T) { mockWsClient.EXPECT().Connect().Return(io.EOF).Times(10), // Cancel trying to connect to ACS on the 11th attempt // Failure to retry on Connect() errors should cause the - // test to time out as the context is never cancelled + // test to time out as the context is never cancelled mockWsClient.EXPECT().Connect().Do(func() { cancel() }).Return(nil).MinTimes(1), @@ -575,7 +573,7 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) { } // TestHandlerReconnectsOnServeErrors tests if the handler retries to -// to establish the session with ACS when ClientServer.Connect() returns errors +// establish the session with ACS when ClientServer.Serve() returns errors func TestHandlerReconnectsOnServeErrors(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -592,6 +590,7 @@ func TestHandlerReconnectsOnServeErrors(t *testing.T) { mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() gomock.InOrder( // Serve fails 10 times @@ -647,6 +646,7 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() gomock.InOrder( mockWsClient.EXPECT().Serve().Return(io.EOF), @@ -668,6 +668,8 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { resources: &mockSessionResources{mockWsClient}, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, + connectionTime: 30 * time.Millisecond, + connectionJitter: 10 * time.Millisecond, } // The session error channel would have an event when the Start() method returns @@ -695,6 +697,7 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) { mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Serve().AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() mockWsClient.EXPECT().Connect().Do(func() { // Serve() cancels the context @@ -721,6 +724,8 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) { resources: &mockSessionResources{mockWsClient}, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, + connectionTime: 30 * time.Millisecond, + connectionJitter: 10 * time.Millisecond, } go func() { acsSession.Start() @@ -771,7 +776,7 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { // been breached while Serving requests time.Sleep(30 * time.Millisecond) }).Return(io.EOF) - + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() connectionClosed := make(chan bool) mockWsClient.EXPECT().Close().Do(func() { wait.Wait() @@ -791,6 +796,8 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { resources: &mockSessionResources{}, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, + connectionTime: 30 * time.Millisecond, + connectionJitter: 10 * time.Millisecond, } go acsSession.startACSSession(mockWsClient) @@ -799,6 +806,61 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { <-connectionClosed } +// TestConnectionIsClosedAfterTimeIsUp tests if the connection to ACS is closed +// when the session's connection time is expired. +func TestConnectionIsClosedAfterTimeIsUp(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + taskEngine := mock_engine.NewMockTaskEngine(ctrl) + taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() + + ecsClient := mock_api.NewMockECSClient(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) + defer cancel() + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() + mockWsClient.EXPECT().Connect().Return(nil) + mockWsClient.EXPECT().Serve().Do(func() { + // pretend as if the connectionTime has elapsed + time.Sleep(30 * time.Millisecond) + cancel() + }).Return(io.EOF) + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + + // set connectionTime to a value lower than the heartbeatTimeout to avoid + // closing the connection due to the heartbeatTimer's callback func + acsSession := session{ + containerInstanceARN: "myArn", + credentialsProvider: testCreds, + agentConfig: testConfig, + taskEngine: taskEngine, + ecsClient: ecsClient, + dataClient: data.NewNoopClient(), + taskHandler: taskHandler, + ctx: context.Background(), + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), + resources: &mockSessionResources{}, + _heartbeatTimeout: 50 * time.Millisecond, + _heartbeatJitter: 10 * time.Millisecond, + connectionTime: 20 * time.Millisecond, + connectionJitter: 10 * time.Millisecond, + } + + go func() { + messageError := make(chan error) + messageError <- acsSession.startACSSession(mockWsClient) + assert.EqualError(t, <-messageError, io.EOF.Error()) + }() + + // Wait for context to be cancelled + select { + case <-ctx.Done(): + } +} + func TestHandlerDoesntLeakGoroutines(t *testing.T) { // Skip this test on "windows" platform as we have observed this to // fail often after upgrading the windows builds to golang v1.17. @@ -1036,6 +1098,7 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() mockWsClient.EXPECT().Serve().Return(io.EOF).AnyTimes() @@ -1068,6 +1131,8 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, + connectionTime: 30 * time.Millisecond, + connectionJitter: 10 * time.Millisecond, } go func() { for i := 0; i < 10; i++ { diff --git a/agent/wsclient/client.go b/agent/wsclient/client.go index 04476054451..3cfe9ee8fae 100644 --- a/agent/wsclient/client.go +++ b/agent/wsclient/client.go @@ -20,6 +20,7 @@ package wsclient import ( "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -33,14 +34,12 @@ import ( "sync" "time" - "github.com/aws/amazon-ecs-agent/agent/logger" - - "crypto/tls" - "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/logger" "github.com/aws/amazon-ecs-agent/agent/utils" "github.com/aws/amazon-ecs-agent/agent/utils/cipher" "github.com/aws/amazon-ecs-agent/agent/wsclient/wsconn" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" "github.com/cihub/seelog" @@ -98,6 +97,7 @@ type ClientServer interface { SetAnyRequestHandler(RequestHandler) MakeRequest(input interface{}) error WriteMessage(input []byte) error + WriteCloseMessage() error Connect() error IsConnected() bool SetConnection(conn wsconn.WebsocketConn) @@ -370,6 +370,18 @@ func (cs *ClientServerImpl) WriteMessage(send []byte) error { return cs.conn.WriteMessage(websocket.TextMessage, send) } +// WriteCloseMessage wraps the low level websocket WriteControl method with a lock, and sends a message of type +// CloseMessage (Ref: https://github.com/gorilla/websocket/blob/9111bb834a68b893cebbbaed5060bdbc1d9ab7d2/conn.go#L74) +func (cs *ClientServerImpl) WriteCloseMessage() error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + + send := websocket.FormatCloseMessage(websocket.CloseNormalClosure, + "ConnectionExpired: Reconnect to continue") + + return cs.conn.WriteControl(websocket.CloseMessage, send, time.Now().Add(cs.RWTimeout)) +} + // ConsumeMessages reads messages from the websocket connection and handles read // messages from an active connection. func (cs *ClientServerImpl) ConsumeMessages() error { diff --git a/agent/wsclient/mock/client.go b/agent/wsclient/mock/client.go index c56b6cbfce6..58ef97acc5b 100644 --- a/agent/wsclient/mock/client.go +++ b/agent/wsclient/mock/client.go @@ -188,6 +188,20 @@ func (mr *MockClientServerMockRecorder) SetReadDeadline(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockClientServer)(nil).SetReadDeadline), arg0) } +// WriteCloseMessage mocks base method +func (m *MockClientServer) WriteCloseMessage() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteCloseMessage") + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteCloseMessage indicates an expected call of WriteCloseMessage +func (mr *MockClientServerMockRecorder) WriteCloseMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteCloseMessage", reflect.TypeOf((*MockClientServer)(nil).WriteCloseMessage)) +} + // WriteMessage mocks base method func (m *MockClientServer) WriteMessage(arg0 []byte) error { m.ctrl.T.Helper() diff --git a/agent/wsclient/wsconn/conn.go b/agent/wsclient/wsconn/conn.go index cb99691ca07..af29a55c93c 100644 --- a/agent/wsclient/wsconn/conn.go +++ b/agent/wsclient/wsconn/conn.go @@ -19,6 +19,7 @@ import "time" // connection's methods that this client uses. type WebsocketConn interface { WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error ReadMessage() (messageType int, data []byte, err error) Close() error SetWriteDeadline(t time.Time) error diff --git a/agent/wsclient/wsconn/mock/conn.go b/agent/wsclient/wsconn/mock/conn.go index abbcb63174d..6c600d2cf9a 100644 --- a/agent/wsclient/wsconn/mock/conn.go +++ b/agent/wsclient/wsconn/mock/conn.go @@ -106,6 +106,20 @@ func (mr *MockWebsocketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockWebsocketConn)(nil).SetWriteDeadline), arg0) } +// WriteControl mocks base method +func (m *MockWebsocketConn) WriteControl(arg0 int, arg1 []byte, arg2 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteControl", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteControl indicates an expected call of WriteControl +func (mr *MockWebsocketConnMockRecorder) WriteControl(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteControl", reflect.TypeOf((*MockWebsocketConn)(nil).WriteControl), arg0, arg1, arg2) +} + // WriteMessage mocks base method func (m *MockWebsocketConn) WriteMessage(arg0 int, arg1 []byte) error { m.ctrl.T.Helper()