Skip to content

Commit

Permalink
periodically disconnect from acs
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Feb 22, 2023
1 parent 7d2b6da commit 67ca497
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 22 deletions.
48 changes: 38 additions & 10 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/aws/amazon-ecs-agent/agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/agent/version"
"github.com/aws/amazon-ecs-agent/agent/wsclient"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/cihub/seelog"
)
Expand All @@ -54,6 +55,10 @@ const (

inactiveInstanceReconnectDelay = 1 * time.Hour

// connectionTime is the base time after which agent closes its connection to ACS,
// before connectionJitter is applied on top of it
connectionTime = 15 * time.Minute
connectionJitter = 30 * time.Minute

connectionBackoffMin = 250 * time.Millisecond
connectionBackoffMax = 2 * time.Minute
connectionBackoffJitter = 0.2
Expand Down Expand Up @@ -100,14 +105,16 @@ type session struct {
doctor *doctor.Doctor
_heartbeatTimeout time.Duration
_heartbeatJitter time.Duration
connectionTime time.Duration
connectionJitter time.Duration
_inactiveInstanceReconnectDelay time.Duration
}

// sessionResources defines the resource creator interface for starting
// a session with ACS. This interface is intended to define methods
// that create resources used to establish the connection to ACS
// It is confined to just the createACSClient() method for now. It can be
// extended to include the acsWsURL() and newDisconnectionTimer() methods
// extended to include the acsWsURL() and newHeartbeatTimer() methods
// when needed
// The goal is to make it easier to test and inject dependencies
type sessionResources interface {
Expand Down Expand Up @@ -182,6 +189,8 @@ func NewSession(
doctor: doctor,
_heartbeatTimeout: heartbeatTimeout,
_heartbeatJitter: heartbeatJitter,
connectionTime: connectionTime,
connectionJitter: connectionJitter,
_inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay,
}
}
Expand Down Expand Up @@ -224,7 +233,7 @@ func (acsSession *session) Start() error {
}
}
if shouldReconnectWithoutBackoff(acsError) {
// If ACS closed the connection, there's no need to backoff,
// If ACS or agent closed the connection, there's no need to backoff,
// reconnect immediately
seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError)
acsSession.backoff.Reset()
Expand Down Expand Up @@ -360,11 +369,15 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
}

seelog.Info("Connected to ACS endpoint")
// Start inactivity timer for closing the connection
timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
// Any message from the server resets the disconnect timeout
client.SetAnyRequestHandler(anyMessageHandler(timer, client))
defer timer.Stop()
// Start a connection timer; agent will close its ACS websocket connection after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
heartbeatTimer := newHeartbeatTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
// Any message from the server resets the heartbeat timer
client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client))
defer heartbeatTimer.Stop()

acsSession.resources.connectedToACS()

Expand Down Expand Up @@ -393,7 +406,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
case err := <-serveErr:
// Stop receiving and sending messages from and to ACS when
// client.Serve returns an error. This can happen when the
// the connection is closed by ACS or the agent
// connection is closed by ACS or the agent
if err == nil || err == io.EOF {
seelog.Info("ACS Websocket connection closed for a valid reason")
} else {
Expand Down Expand Up @@ -478,9 +491,9 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
return acsURL + "?" + query.Encode()
}

// newDisconnectionTimer creates a new time object, with a callback to
// newHeartbeatTimer creates a new timer object, with a callback to
// disconnect from ACS on inactivity
func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer {
func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer {
timer := time.AfterFunc(retry.AddJitter(timeout, jitter), func() {
seelog.Warn("ACS Connection hasn't had any activity for too long; closing connection")
if err := client.Close(); err != nil {
Expand All @@ -492,6 +505,21 @@ func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration,
return timer
}

// newConnectionTimer schedules the agent-initiated close of the ACS websocket
// connection. The delay is computed by passing connectionTime and
// connectionJitter to retry.AddJitter; when it elapses, a websocket close
// control message is written through the client. The timer is returned so the
// caller can stop it.
func newConnectionTimer(client wsclient.ClientServer, connectionTime time.Duration, connectionJitter time.Duration) ttime.Timer {
	delay := retry.AddJitter(connectionTime, connectionJitter)
	sendClose := func() {
		seelog.Infof("Closing ACS websocket connection after %v minutes", delay.Minutes())
		// WriteCloseMessage() writes a close message using websocket control messages
		// Ref: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages
		if err := client.WriteCloseMessage(); err != nil {
			seelog.Warnf("Error writing close message: %v", err)
		}
	}
	return time.AfterFunc(delay, sendClose)
}

// anyMessageHandler handles any server message. Any server message means the
// connection is active and thus the heartbeat disconnect should not occur
func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) {
Expand Down
81 changes: 73 additions & 8 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package handler

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -30,8 +31,6 @@ import (
"testing"
"time"

"context"

apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container"
mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks"
apitask "github.com/aws/amazon-ecs-agent/agent/api/task"
Expand All @@ -45,19 +44,18 @@ import (
mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks"
"github.com/aws/amazon-ecs-agent/agent/eventhandler"
"github.com/aws/amazon-ecs-agent/agent/eventstream"

"github.com/aws/amazon-ecs-agent/agent/utils/retry"
mock_retry "github.com/aws/amazon-ecs-agent/agent/utils/retry/mock"
"github.com/aws/amazon-ecs-agent/agent/version"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/golang/mock/gomock"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"

"github.com/golang/mock/gomock"
)

const (
Expand Down Expand Up @@ -210,7 +208,7 @@ func TestHandlerReconnectsOnConnectErrors(t *testing.T) {
mockWsClient.EXPECT().Connect().Return(io.EOF).Times(10),
// Cancel trying to connect to ACS on the 11th attempt
// Failure to retry on Connect() errors should cause the
// test to time out as the context is never cancelled
// test to time out as the context is never cancelled
mockWsClient.EXPECT().Connect().Do(func() {
cancel()
}).Return(nil).MinTimes(1),
Expand Down Expand Up @@ -575,7 +573,7 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) {
}

// TestHandlerReconnectsOnServeErrors tests if the handler retries to
// to establish the session with ACS when ClientServer.Connect() returns errors
// establish the session with ACS when ClientServer.Serve() returns errors
func TestHandlerReconnectsOnServeErrors(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand All @@ -592,6 +590,7 @@ func TestHandlerReconnectsOnServeErrors(t *testing.T) {
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
gomock.InOrder(
// Serve fails 10 times
Expand Down Expand Up @@ -647,6 +646,7 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) {
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().Connect().Return(nil).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
gomock.InOrder(
mockWsClient.EXPECT().Serve().Return(io.EOF),
Expand All @@ -668,6 +668,8 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) {
resources: &mockSessionResources{mockWsClient},
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
connectionTime: 30 * time.Millisecond,
connectionJitter: 10 * time.Millisecond,
}

// The session error channel would have an event when the Start() method returns
Expand Down Expand Up @@ -695,6 +697,7 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) {
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().Serve().AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
mockWsClient.EXPECT().Connect().Do(func() {
// Serve() cancels the context
Expand All @@ -721,6 +724,8 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) {
resources: &mockSessionResources{mockWsClient},
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
connectionTime: 30 * time.Millisecond,
connectionJitter: 10 * time.Millisecond,
}
go func() {
acsSession.Start()
Expand Down Expand Up @@ -771,7 +776,7 @@ func TestConnectionIsClosedOnIdle(t *testing.T) {
// been breached while Serving requests
time.Sleep(30 * time.Millisecond)
}).Return(io.EOF)

mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()
connectionClosed := make(chan bool)
mockWsClient.EXPECT().Close().Do(func() {
wait.Wait()
Expand All @@ -791,6 +796,8 @@ func TestConnectionIsClosedOnIdle(t *testing.T) {
resources: &mockSessionResources{},
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
connectionTime: 30 * time.Millisecond,
connectionJitter: 10 * time.Millisecond,
}
go acsSession.startACSSession(mockWsClient)

Expand All @@ -799,6 +806,61 @@ func TestConnectionIsClosedOnIdle(t *testing.T) {
<-connectionClosed
}

// TestConnectionIsClosedAfterTimeIsUp tests if the connection to ACS is closed
// when the session's connection time is expired.
//
// The mocked Serve() sleeps past the configured connection time, cancels the
// test context and returns io.EOF. The test then checks that
// startACSSession() hands that io.EOF back to its caller.
func TestConnectionIsClosedAfterTimeIsUp(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	taskEngine := mock_engine.NewMockTaskEngine(ctrl)
	taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes()

	ecsClient := mock_api.NewMockECSClient(ctrl)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil)

	mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes()
	mockWsClient.EXPECT().Connect().Return(nil)
	mockWsClient.EXPECT().Serve().Do(func() {
		// pretend as if the connectionTime has elapsed
		time.Sleep(30 * time.Millisecond)
		cancel()
	}).Return(io.EOF)
	mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes()

	// set connectionTime to a value lower than the heartbeatTimeout to avoid
	// closing the connection due to the heartbeatTimer's callback func
	acsSession := session{
		containerInstanceARN: "myArn",
		credentialsProvider:  testCreds,
		agentConfig:          testConfig,
		taskEngine:           taskEngine,
		ecsClient:            ecsClient,
		dataClient:           data.NewNoopClient(),
		taskHandler:          taskHandler,
		ctx:                  context.Background(),
		backoff:              retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
		resources:            &mockSessionResources{},
		_heartbeatTimeout:    50 * time.Millisecond,
		_heartbeatJitter:     10 * time.Millisecond,
		connectionTime:       20 * time.Millisecond,
		connectionJitter:     10 * time.Millisecond,
	}

	// The channel is buffered so the session goroutine can always deliver its
	// result and exit. Sending on an unbuffered channel from the same
	// goroutine that is supposed to receive from it would block forever, and
	// the assertion below would never run.
	sessionErr := make(chan error, 1)
	go func() {
		sessionErr <- acsSession.startACSSession(mockWsClient)
	}()

	// Wait for Serve() to cancel the context
	<-ctx.Done()

	// Assert on the test goroutine, before the test function returns, so the
	// result is attributed to this test.
	assert.EqualError(t, <-sessionErr, io.EOF.Error())
}

func TestHandlerDoesntLeakGoroutines(t *testing.T) {
// Skip this test on "windows" platform as we have observed this to
// fail often after upgrading the windows builds to golang v1.17.
Expand Down Expand Up @@ -1036,6 +1098,7 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T)
mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes()
mockWsClient.EXPECT().WriteCloseMessage().AnyTimes()
mockWsClient.EXPECT().Close().Return(nil).AnyTimes()
mockWsClient.EXPECT().Serve().Return(io.EOF).AnyTimes()

Expand Down Expand Up @@ -1068,6 +1131,8 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T)
backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier),
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
connectionTime: 30 * time.Millisecond,
connectionJitter: 10 * time.Millisecond,
}
go func() {
for i := 0; i < 10; i++ {
Expand Down
20 changes: 16 additions & 4 deletions agent/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package wsclient

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
Expand All @@ -33,14 +34,12 @@ import (
"sync"
"time"

"github.com/aws/amazon-ecs-agent/agent/logger"

"crypto/tls"

"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/logger"
"github.com/aws/amazon-ecs-agent/agent/utils"
"github.com/aws/amazon-ecs-agent/agent/utils/cipher"
"github.com/aws/amazon-ecs-agent/agent/wsclient/wsconn"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
"github.com/cihub/seelog"
Expand Down Expand Up @@ -98,6 +97,7 @@ type ClientServer interface {
SetAnyRequestHandler(RequestHandler)
MakeRequest(input interface{}) error
WriteMessage(input []byte) error
WriteCloseMessage() error
Connect() error
IsConnected() bool
SetConnection(conn wsconn.WebsocketConn)
Expand Down Expand Up @@ -370,6 +370,18 @@ func (cs *ClientServerImpl) WriteMessage(send []byte) error {
return cs.conn.WriteMessage(websocket.TextMessage, send)
}

// WriteCloseMessage sends a websocket control message of type CloseMessage
// (Ref: https://github.com/gorilla/websocket/blob/9111bb834a68b893cebbbaed5060bdbc1d9ab7d2/conn.go#L74)
// with a normal-closure status through the connection's WriteControl method.
// The write lock is held for the duration of the call, and the write deadline
// is set to RWTimeout from now.
func (cs *ClientServerImpl) WriteCloseMessage() error {
	cs.writeLock.Lock()
	defer cs.writeLock.Unlock()

	payload := websocket.FormatCloseMessage(
		websocket.CloseNormalClosure,
		"ConnectionExpired: Reconnect to continue",
	)
	deadline := time.Now().Add(cs.RWTimeout)

	return cs.conn.WriteControl(websocket.CloseMessage, payload, deadline)
}

// ConsumeMessages reads messages from the websocket connection and handles read
// messages from an active connection.
func (cs *ClientServerImpl) ConsumeMessages() error {
Expand Down
14 changes: 14 additions & 0 deletions agent/wsclient/mock/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions agent/wsclient/wsconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import "time"
// connection's methods that this client uses.
type WebsocketConn interface {
WriteMessage(messageType int, data []byte) error
WriteControl(messageType int, data []byte, deadline time.Time) error
ReadMessage() (messageType int, data []byte, err error)
Close() error
SetWriteDeadline(t time.Time) error
Expand Down
Loading

0 comments on commit 67ca497

Please sign in to comment.