Skip to content

Commit

Permalink
send pending acks before closing ACS connection
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Feb 23, 2023
1 parent 6d54c7c commit 6bf8301
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 4 deletions.
55 changes: 51 additions & 4 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
Expand Down Expand Up @@ -75,6 +76,10 @@ const (
// 1: default protocol version
// 2: ACS will proactively close the connection when heartbeat acks are missing
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, payloadHandler and heartbeatHandler
numOfHandlersSendingAcks = 4
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -369,8 +374,10 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
}

seelog.Info("Connected to ACS endpoint")
// Start a connection timer; agent will close its ACS websocket connection after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter)
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler, &heartbeatHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -505,10 +512,50 @@ func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitt
return timer
}

// newConnectionTimer creates a new timer, after which agent closes its ACS websocket connection
func newConnectionTimer(client wsclient.ClientServer, connectionTime time.Duration, connectionJitter time.Duration) ttime.Timer {
// newConnectionTimer creates a new timer, after which agent sends any pending acks and closes its ACS websocket connection
func newConnectionTimer(
client wsclient.ClientServer,
connectionTime time.Duration,
connectionJitter time.Duration,
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
heartbeatHandler *heartbeatHandler,
) ttime.Timer {
expiresAt := retry.AddJitter(connectionTime, connectionJitter)
timer := time.AfterFunc(expiresAt, func() {
seelog.Debugf("Sending pending acks to ACS before closing the connection")
var wg sync.WaitGroup
wg.Add(numOfHandlersSendingAcks)

// send pending creds refresh acks
go func() {
refreshCredsHandler.sendPendingAcks()
wg.Done()
}()

// send pending task manifest acks and task stop verification acks
go func() {
taskManifestHandler.sendPendingTaskManifestMessageAck()
taskManifestHandler.handlePendingTaskStopVerificationAck()
wg.Done()
}()

// send pending payload acks
go func() {
payloadHandler.sendPendingAcks()
wg.Done()
}()

// send pending heartbeat acks
go func() {
heartbeatHandler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for acks from all handlers above to be sent before closing the websocket connection
wg.Wait()

seelog.Infof("Closing ACS websocket connection after %v minutes", expiresAt.Minutes())
// WriteCloseMessage() writes a close message using websocket control messages
// Ref: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages
Expand Down
12 changes: 12 additions & 0 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
}
}

// sendPendingHeartbeatAck sends all pending heartbeat acks to ACS before closing the connection
func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
default:
return
}
}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions agent/acs/handler/payload_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ func (payloadHandler *payloadRequestHandler) sendAcks() {
}
}

// sendPendingAcks sends ack requests to ACS before closing the connection
func (payloadHandler *payloadRequestHandler) sendPendingAcks() {
for {
select {
case mid := <-payloadHandler.ackRequest:
payloadHandler.ackMessageId(mid)
default:
return
}
}
}

// ackMessageId sends an AckRequest for a message id
func (payloadHandler *payloadRequestHandler) ackMessageId(messageID string) {
seelog.Debugf("Acking payload message id: %s", messageID)
Expand Down
12 changes: 12 additions & 0 deletions agent/acs/handler/refresh_credentials_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ func (refreshHandler *refreshCredentialsHandler) sendAcks() {
}
}

// sendPendingAcks sends pending acks to ACS before closing the connection
func (refreshHandler *refreshCredentialsHandler) sendPendingAcks() {
for {
select {
case ack := <-refreshHandler.ackRequest:
refreshHandler.ackMessage(ack)
default:
return
}
}
}

// ackMessageId sends an IAMRoleCredentialsAckRequest to the backend
func (refreshHandler *refreshCredentialsHandler) ackMessage(ack *ecsacs.IAMRoleCredentialsAckRequest) {
err := refreshHandler.acsClient.MakeRequest(ack)
Expand Down
27 changes: 27 additions & 0 deletions agent/acs/handler/task_manifest_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ func (taskManifestHandler *taskManifestHandler) sendTaskManifestMessageAck() {
}
}

// sendPendingTaskManifestMessageAck sends all pending task manifest acks to ACS before closing the connection
func (taskManifestHandler *taskManifestHandler) sendPendingTaskManifestMessageAck() {
for {
select {
case messageBufferTaskManifestAck := <-taskManifestHandler.messageBufferTaskManifestAck:
taskManifestHandler.ackTaskManifestMessage(messageBufferTaskManifestAck)
default:
return
}
}
}

func (taskManifestHandler *taskManifestHandler) handleTaskStopVerificationAck() {
for {
select {
Expand All @@ -130,6 +142,21 @@ func (taskManifestHandler *taskManifestHandler) handleTaskStopVerificationAck()
}
}

// handlePendingTaskStopVerificationAck sends pending task stop verification acks to ACS before closing the connection
func (taskManifestHandler *taskManifestHandler) handlePendingTaskStopVerificationAck() {
for {
select {
case messageBufferTaskStopVerificationAck := <-taskManifestHandler.messageBufferTaskStopVerificationAck:
if err := taskManifestHandler.handleSingleMessageVerificationAck(messageBufferTaskStopVerificationAck); err != nil {
seelog.Warnf("Error handling Verification ack with messageID: %s, error: %v",
messageBufferTaskStopVerificationAck.MessageId, err)
}
default:
return
}
}
}

func (taskManifestHandler *taskManifestHandler) clearAcks() {
for {
select {
Expand Down

0 comments on commit 6bf8301

Please sign in to comment.