Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ECS agent to acknowledge server heartbeat messages #2837

Merged
merged 1 commit into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/acs/client/acs_client_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func init() {
// the .json model or the generated struct names.
acsRecognizedTypes = []interface{}{
ecsacs.HeartbeatMessage{},
ecsacs.HeartbeatAckRequest{},
ecsacs.PayloadMessage{},
ecsacs.CloseMessage{},
ecsacs.AckRequest{},
Expand Down
15 changes: 12 additions & 3 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
"github.com/aws/amazon-ecs-agent/agent/api"
"github.com/aws/amazon-ecs-agent/agent/config"
Expand Down Expand Up @@ -65,6 +64,10 @@ const (
// credentials for all tasks on establishing the connection
sendCredentialsURLParameterName = "sendCredentials"
inactiveInstanceExceptionPrefix = "InactiveInstanceException:"
// ACS protocol version spec:
// 1: default protocol version
// 2: ACS will proactively close the connection when heartbeat acks are missing
acsProtocolVersion = 2
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -332,8 +335,13 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

client.AddRequestHandler(payloadHandler.handlerFunc())

// Ignore heartbeat messages; anyMessageHandler gets 'em
client.AddRequestHandler(func(*ecsacs.HeartbeatMessage) {})
// Add HeartbeatHandler to acknowledge ACS heartbeats
heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client)
defer heartbeatHandler.clearAcks()
sparrc marked this conversation as resolved.
Show resolved Hide resolved
heartbeatHandler.start()
defer heartbeatHandler.stop()

client.AddRequestHandler(heartbeatHandler.handlerFunc())

updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine)

Expand Down Expand Up @@ -454,6 +462,7 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
query.Set("agentHash", version.GitHashString())
query.Set("agentVersion", version.Version)
query.Set("seqNum", "1")
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
if dockerVersion, err := taskEngine.Version(); err == nil {
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
}
Expand Down
15 changes: 9 additions & 6 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -178,6 +179,8 @@ func TestACSWSURL(t *testing.T) {
assert.Equal(t, "DockerVersion: Docker version result", parsed.Query().Get("dockerVersion"), "wrong docker version")
assert.Equalf(t, "true", parsed.Query().Get(sendCredentialsURLParameterName), "Wrong value set for: %s", sendCredentialsURLParameterName)
assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum")
protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion"))
assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1")
}

// TestHandlerReconnectsOnConnectErrors tests if handler reconnects retries
Expand Down Expand Up @@ -844,12 +847,12 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
ended <- true
}()
// Warm it up
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
serverIn <- samplePayloadMessage

beforeGoroutines := runtime.NumGoroutine()
for i := 0; i < 100; i++ {
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
for i := 0; i < 40; i++ {
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
serverIn <- samplePayloadMessage
closeWS <- true
}
Expand All @@ -859,15 +862,15 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {

// The number of goroutines finishing in the MockACSServer will affect
// the result unless we wait here.
time.Sleep(10 * time.Millisecond)
time.Sleep(1 * time.Second)
afterGoroutines := runtime.NumGoroutine()

t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines)

if timesConnected < 50 {
if timesConnected < 20 {
t.Fatal("Expected times connected to be a large number, was ", timesConnected)
}
if afterGoroutines > beforeGoroutines+5 {
if afterGoroutines > beforeGoroutines+2 {
t.Error("Goroutine leak, oh no!")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
}
Expand Down
118 changes: 118 additions & 0 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"context"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"
)

// heartbeatHandler handles heartbeat messages from ACS
type heartbeatHandler struct {
heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage
heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest
ctx context.Context
cancel context.CancelFunc
acsClient wsclient.ClientServer
}

// newHeartbeatHandler returns an instance of the heartbeatHandler struct
func newHeartbeatHandler(ctx context.Context,
acsClient wsclient.ClientServer) heartbeatHandler {

// Create a cancelable context from the parent context
derivedContext, cancel := context.WithCancel(ctx)
return heartbeatHandler{
heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage),
heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest),
ctx: derivedContext,
cancel: cancel,
acsClient: acsClient,
}
}

// handlerFunc returns a function to enqueue requests onto the buffer
func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) {
return func(message *ecsacs.HeartbeatMessage) {
heartbeatHandler.heartbeatMessageBuffer <- message
}
}

// start() invokes go routines to handle receive and respond to heartbeats
func (heartbeatHandler *heartbeatHandler) start() {
go heartbeatHandler.handleHeartbeatMessage()
go heartbeatHandler.sendHeartbeatAck()
}

func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() {
for {
select {
case message := <-heartbeatHandler.heartbeatMessageBuffer:
if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil {
seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err)
}
case <-heartbeatHandler.ctx.Done():
return
}
}
}

func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error {
// Agent currently has no other action hooked to heartbeat messages, except simple ack
go func() {
response := &ecsacs.HeartbeatAckRequest{
MessageId: message.MessageId,
}
heartbeatHandler.heartbeatAckMessageBuffer <- response
}()
return nil
}

func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
for {
select {
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
heartbeatHandler.sendSingleHeartbeatAck(ack)
case <-heartbeatHandler.ctx.Done():
return
}
}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
sparrc marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
}
}

// stop() cancels the context being used by this handler, which stops the go routines started by 'start()'
func (heartbeatHandler *heartbeatHandler) stop() {
heartbeatHandler.cancel()
}

// clearAcks drains the ack request channel
func (heartbeatHandler *heartbeatHandler) clearAcks() {
for {
select {
case <-heartbeatHandler.heartbeatAckMessageBuffer:
default:
return
}
}
}
104 changes: 104 additions & 0 deletions agent/acs/handler/heartbeat_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"context"
"testing"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"
"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)

const (
heartbeatMessageId = "heartbeatMessageId"
)

func TestAckHeartbeatMessage(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{
MessageId: aws.String(heartbeatMessageId),
Healthy: aws.Bool(true),
}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: aws.String(heartbeatMessageId),
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func TestAckHeartbeatMessageNotHealthy(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{
MessageId: aws.String(heartbeatMessageId),
// ECS Agent currently ignores this field so we expect no behavior change
Healthy: aws.Bool(false),
}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: aws.String(heartbeatMessageId),
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func TestAckHeartbeatMessageWithoutMessageId(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{
Healthy: aws.Bool(true),
}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: nil,
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func TestAckHeartbeatMessageEmpty(t *testing.T) {
heartbeatReceived := &ecsacs.HeartbeatMessage{}

heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
MessageId: nil,
}

validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
}

func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessage, heartbeatAckExpected *ecsacs.HeartbeatAckRequest) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx, cancel := context.WithCancel(context.Background())
var heartbeatAckSent *ecsacs.HeartbeatAckRequest

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) {
heartbeatAckSent = message
cancel()
}).Times(1)

handler := newHeartbeatHandler(ctx, mockWsClient)
go handler.sendHeartbeatAck()

handler.handleSingleHeartbeatMessage(heartbeatReceived)

// wait till we get an ack from heartbeatAckMessageBuffer
<-ctx.Done()

require.Equal(t, heartbeatAckExpected, heartbeatAckSent)
}
12 changes: 10 additions & 2 deletions agent/acs/model/api/api-2.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
"requestUri":"/"
},
"input":{"shape":"HeartbeatMessage"},
"documentation":"Heartbeat is a periodic message that informs the agent all is well."
"output":{"shape":"HeartbeatAckRequest"},
"documentation":"Heartbeat is a periodic message between the Agent and ECS backend to keep the connection alive."
},
"Payload":{
"name":"Payload",
Expand Down Expand Up @@ -417,7 +418,14 @@
"HeartbeatMessage":{
"type":"structure",
"members":{
"healthy":{"shape":"Boolean"}
"healthy":{"shape":"Boolean"},
"messageId":{"shape":"String"}
}
},
"HeartbeatAckRequest":{
"type":"structure",
"members":{
"messageId":{"shape":"String"}
}
},
"HostVolumeProperties":{
Expand Down
22 changes: 22 additions & 0 deletions agent/acs/model/ecsacs/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.