Skip to content

Commit

Permalink
ECS agent to acknowledge server heartbeat messages
Browse files Browse the repository at this point in the history
  • Loading branch information
LiangChiAmzn committed Apr 21, 2021
1 parent a6289db commit 9390643
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 11 deletions.
1 change: 1 addition & 0 deletions agent/acs/client/acs_client_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func init() {
// the .json model or the generated struct names.
acsRecognizedTypes = []interface{}{
ecsacs.HeartbeatMessage{},
ecsacs.HeartbeatAckRequest{},
ecsacs.PayloadMessage{},
ecsacs.CloseMessage{},
ecsacs.AckRequest{},
Expand Down
15 changes: 12 additions & 3 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
"github.com/aws/amazon-ecs-agent/agent/api"
"github.com/aws/amazon-ecs-agent/agent/config"
Expand Down Expand Up @@ -65,6 +64,10 @@ const (
// credentials for all tasks on establishing the connection
sendCredentialsURLParameterName = "sendCredentials"
inactiveInstanceExceptionPrefix = "InactiveInstanceException:"
// ACS protocol version spec:
// 1: default protocol version
// 2: ACS will proactively close the connection when heartbeat acks are missing
acsProtocolVersion = 2
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -332,8 +335,13 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

client.AddRequestHandler(payloadHandler.handlerFunc())

// Ignore heartbeat messages; anyMessageHandler gets 'em
client.AddRequestHandler(func(*ecsacs.HeartbeatMessage) {})
// Add HeartbeatHandler to acknowledge ACS heartbeats
heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client)
defer heartbeatHandler.clearAcks()
heartbeatHandler.start()
defer heartbeatHandler.stop()

client.AddRequestHandler(heartbeatHandler.handlerFunc())

updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine)

Expand Down Expand Up @@ -454,6 +462,7 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
query.Set("agentHash", version.GitHashString())
query.Set("agentVersion", version.Version)
query.Set("seqNum", "1")
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
if dockerVersion, err := taskEngine.Version(); err == nil {
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
}
Expand Down
15 changes: 9 additions & 6 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -178,6 +179,8 @@ func TestACSWSURL(t *testing.T) {
assert.Equal(t, "DockerVersion: Docker version result", parsed.Query().Get("dockerVersion"), "wrong docker version")
assert.Equalf(t, "true", parsed.Query().Get(sendCredentialsURLParameterName), "Wrong value set for: %s", sendCredentialsURLParameterName)
assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum")
protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion"))
assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1")
}

// TestHandlerReconnectsOnConnectErrors tests if handler reconnects retries
Expand Down Expand Up @@ -844,12 +847,12 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
ended <- true
}()
// Warm it up
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
serverIn <- samplePayloadMessage

beforeGoroutines := runtime.NumGoroutine()
for i := 0; i < 100; i++ {
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
for i := 0; i < 40; i++ {
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
serverIn <- samplePayloadMessage
closeWS <- true
}
Expand All @@ -859,15 +862,15 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {

// The number of goroutines finishing in the MockACSServer will affect
// the result unless we wait here.
time.Sleep(10 * time.Millisecond)
time.Sleep(1 * time.Second)
afterGoroutines := runtime.NumGoroutine()

t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines)

if timesConnected < 50 {
if timesConnected < 20 {
t.Fatal("Expected times connected to be a large number, was ", timesConnected)
}
if afterGoroutines > beforeGoroutines+5 {
if afterGoroutines > beforeGoroutines+2 {
t.Error("Goroutine leak, oh no!")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
}
Expand Down
119 changes: 119 additions & 0 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"context"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"
)

// heartbeatHandler handles heartbeat messages from ACS
type heartbeatHandler struct {
heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage
heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest
ctx context.Context
cancel context.CancelFunc
acsClient wsclient.ClientServer
}

// newHeartbeatHandler returns an instance of the heartbeatHandler struct
func newHeartbeatHandler(ctx context.Context,
acsClient wsclient.ClientServer) heartbeatHandler {

// Create a cancelable context from the parent context
derivedContext, cancel := context.WithCancel(ctx)
return heartbeatHandler{
heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage),
heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest),
ctx: derivedContext,
cancel: cancel,
acsClient: acsClient,
}
}

// handlerFunc returns a function to enqueue requests onto the buffer
func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) {
return func(message *ecsacs.HeartbeatMessage) {
heartbeatHandler.heartbeatMessageBuffer <- message
}
}

// start() invokes go routines to handle receive and respond to heartbeats
func (heartbeatHandler *heartbeatHandler) start() {
go heartbeatHandler.handleHeartbeatMessage()
go heartbeatHandler.sendHeartbeatAck()
}

func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() {
for {
select {
case message := <-heartbeatHandler.heartbeatMessageBuffer:
if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil {
seelog.Warnf("Unable to handle heartbeat message [%s]: %v", message.String(), err)
}
case <-heartbeatHandler.ctx.Done():
return
}
}
}

func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error {
seelog.Tracef("Received server heartbeat message: %s", message.MessageId)
go func() {
response := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
}
heartbeatHandler.heartbeatAckMessageBuffer <- response
}()
return nil
}

func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
case <-heartbeatHandler.ctx.Done():
return
}
}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %v", aws.StringValue(ack.MessageId), err)
}
seelog.Tracef("Acknowledging server heartbeat message: %s", ack.MessageId)
}

// stop() cancels the context being used by this handler, which stops the go routines started by 'start()'
func (heartbeatHandler *heartbeatHandler) stop() {
heartbeatHandler.cancel()
}

// clearAcks drains the ack request channel
func (heartbeatHandler *heartbeatHandler) clearAcks() {
for {
select {
case <-heartbeatHandler.heartbeatAckMessageBuffer:
default:
return
}
}
}
95 changes: 95 additions & 0 deletions agent/acs/handler/heartbeat_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"context"
"reflect"
"testing"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"
"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
)

const (
heartbeatMessageId = "heartbeatMessageId"
heartbeatHealthy = true
)

func TestAckHeartbeatMessage(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{
MessageId: aws.String(heartbeatMessageId),
Healthy: aws.Bool(heartbeatHealthy),
}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: aws.String(heartbeatMessageId),
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func TestAckHeartbeatMessageWithoutMessageId(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{
Healthy: aws.Bool(heartbeatHealthy),
}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: nil,
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func TestAckHeartbeatMessageEmpty(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: nil,
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessage, heartbeatAckExpected *ecsacs.HeartbeatAckRequest) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
var heartbeatAckSent *ecsacs.HeartbeatAckRequest

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) {
heartbeatAckSent = message
cancel()
}).Times(1)

handler := newHeartbeatHandler(ctx, mockWsClient)
go handler.sendHeartbeatAck()

handler.handleSingleHeartbeatMessage(heartbeatReceived)

// wait till we get an ack from heartbeatAckMessageBuffer
select {
case <-ctx.Done():
}

if !reflect.DeepEqual(heartbeatAckExpected, heartbeatAckSent) {
t.Errorf("Message mismatch between expected and sent ack, expected: %v, sent: %v", heartbeatAckExpected, heartbeatAckSent)
}
}
12 changes: 10 additions & 2 deletions agent/acs/model/api/api-2.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
"requestUri":"/"
},
"input":{"shape":"HeartbeatMessage"},
"documentation":"Heartbeat is a periodic message that informs the agent all is well."
"output":{"shape":"HeartbeatAckRequest"},
"documentation":"Heartbeat is a periodic message between the Agent and ECS backend to keep the connection alive."
},
"Payload":{
"name":"Payload",
Expand Down Expand Up @@ -417,7 +418,14 @@
"HeartbeatMessage":{
"type":"structure",
"members":{
"healthy":{"shape":"Boolean"}
"healthy":{"shape":"Boolean"},
"messageId":{"shape":"String"}
}
},
"HeartbeatAckRequest":{
"type":"structure",
"members":{
"messageId":{"shape":"String"}
}
},
"HostVolumeProperties":{
Expand Down
22 changes: 22 additions & 0 deletions agent/acs/model/ecsacs/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9390643

Please sign in to comment.