diff --git a/agent/acs/client/acs_client_types.go b/agent/acs/client/acs_client_types.go index 046f3ce119..eaa6e2a40e 100644 --- a/agent/acs/client/acs_client_types.go +++ b/agent/acs/client/acs_client_types.go @@ -30,6 +30,7 @@ func init() { // the .json model or the generated struct names. acsRecognizedTypes = []interface{}{ ecsacs.HeartbeatMessage{}, + ecsacs.HeartbeatAckRequest{}, ecsacs.PayloadMessage{}, ecsacs.CloseMessage{}, ecsacs.AckRequest{}, diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index b3b7e18132..a1c439be55 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -24,7 +24,6 @@ import ( "time" acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client" - "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler" "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/config" @@ -65,6 +64,10 @@ const ( // credentials for all tasks on establishing the connection sendCredentialsURLParameterName = "sendCredentials" inactiveInstanceExceptionPrefix = "InactiveInstanceException:" + // ACS protocol version spec: + // 1: default protocol version + // 2: ACS will proactively close the connection when heartbeat acks are missing + acsProtocolVersion = 2 ) // Session defines an interface for handler's long-lived connection with ACS. @@ -332,8 +335,13 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { client.AddRequestHandler(payloadHandler.handlerFunc()) - // Ignore heartbeat messages; anyMessageHandler gets 'em - client.AddRequestHandler(func(*ecsacs.HeartbeatMessage) {}) + // Add HeartbeatHandler to acknowledge ACS heartbeats + heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client) + defer heartbeatHandler.clearAcks() + heartbeatHandler.start() + defer heartbeatHandler.stop() + + client.AddRequestHandler(heartbeatHandler.handlerFunc()) updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine) @@ -454,6 +462,7 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine. query.Set("agentHash", version.GitHashString()) query.Set("agentVersion", version.Version) query.Set("seqNum", "1") + query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion)) if dockerVersion, err := taskEngine.Version(); err == nil { query.Set("dockerVersion", "DockerVersion: "+dockerVersion) } diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 9cad5a4d02..3cbd383c7b 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -24,6 +24,7 @@ import ( "reflect" "runtime" "runtime/pprof" + "strconv" "sync" "testing" "time" @@ -178,6 +179,8 @@ func TestACSWSURL(t *testing.T) { assert.Equal(t, "DockerVersion: Docker version result", parsed.Query().Get("dockerVersion"), "wrong docker version") assert.Equalf(t, "true", parsed.Query().Get(sendCredentialsURLParameterName), "Wrong value set for: %s", sendCredentialsURLParameterName) assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum") + protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion")) + assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1") } // TestHandlerReconnectsOnConnectErrors tests if handler reconnects retries @@ -844,12 +847,12 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) { ended <- true }() // Warm it up - serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` + serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` serverIn <- samplePayloadMessage beforeGoroutines := runtime.NumGoroutine() - for i := 0; i < 100; i++ { - serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}` + for i := 0; i < 40; i++ { + serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` serverIn <- samplePayloadMessage closeWS <- true } @@ -859,15 +862,15 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) { // The number of goroutines finishing in the MockACSServer will affect // the result unless we wait here. - time.Sleep(10 * time.Millisecond) + time.Sleep(1 * time.Second) afterGoroutines := runtime.NumGoroutine() t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines) - if timesConnected < 50 { + if timesConnected < 20 { t.Fatal("Expected times connected to be a large number, was ", timesConnected) } - if afterGoroutines > beforeGoroutines+5 { + if afterGoroutines > beforeGoroutines+2 { t.Error("Goroutine leak, oh no!") pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) } diff --git a/agent/acs/handler/heartbeat_handler.go b/agent/acs/handler/heartbeat_handler.go new file mode 100644 index 0000000000..dc0385a4a2 --- /dev/null +++ b/agent/acs/handler/heartbeat_handler.go @@ -0,0 +1,119 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handler + +import ( + "context" + + "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/agent/wsclient" + "github.com/aws/aws-sdk-go/aws" + "github.com/cihub/seelog" +) + +// heartbeatHandler handles heartbeat messages from ACS +type heartbeatHandler struct { + heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage + heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest + ctx context.Context + cancel context.CancelFunc + acsClient wsclient.ClientServer +} + +// newHeartbeatHandler returns an instance of the heartbeatHandler struct +func newHeartbeatHandler(ctx context.Context, + acsClient wsclient.ClientServer) heartbeatHandler { + + // Create a cancelable context from the parent context + derivedContext, cancel := context.WithCancel(ctx) + return heartbeatHandler{ + heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage), + heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest), + ctx: derivedContext, + cancel: cancel, + acsClient: acsClient, + } +} + +// handlerFunc returns a function to enqueue requests onto the buffer +func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) { + return func(message *ecsacs.HeartbeatMessage) { + heartbeatHandler.heartbeatMessageBuffer <- message + } +} + +// start() invokes go routines to handle receive and respond to heartbeats +func (heartbeatHandler *heartbeatHandler) start() { + go heartbeatHandler.handleHeartbeatMessage() + go heartbeatHandler.sendHeartbeatAck() +} + +func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() { + for { + select { + case message := <-heartbeatHandler.heartbeatMessageBuffer: + if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil { + seelog.Warnf("Unable to handle heartbeat message [%s]: %v", message.String(), err) + } + case <-heartbeatHandler.ctx.Done(): + return + } + } +} + +func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error { + seelog.Tracef("Received server heartbeat message: %s", message.MessageId) + go func() { + response := &ecsacs.HeartbeatAckRequest{ + MessageId: message.MessageId, + } + heartbeatHandler.heartbeatAckMessageBuffer <- response + }() + return nil +} + +func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() { + for { + select { + case ack := <-heartbeatHandler.heartbeatAckMessageBuffer: + heartbeatHandler.sendSingleHeartbeatAck(ack) + case <-heartbeatHandler.ctx.Done(): + return + } + } +} + +func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) { + err := heartbeatHandler.acsClient.MakeRequest(ack) + if err != nil { + seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %v", aws.StringValue(ack.MessageId), err) + } + seelog.Tracef("Acknowledging server heartbeat message: %s", ack.MessageId) +} + +// stop() cancels the context being used by this handler, which stops the go routines started by 'start()' +func (heartbeatHandler *heartbeatHandler) stop() { + heartbeatHandler.cancel() +} + +// clearAcks drains the ack request channel +func (heartbeatHandler *heartbeatHandler) clearAcks() { + for { + select { + case <-heartbeatHandler.heartbeatAckMessageBuffer: + default: + return + } + } +} diff --git a/agent/acs/handler/heartbeat_handler_test.go b/agent/acs/handler/heartbeat_handler_test.go new file mode 100644 index 0000000000..5da37d5970 --- /dev/null +++ b/agent/acs/handler/heartbeat_handler_test.go @@ -0,0 +1,95 @@ +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handler + +import ( + "context" + "reflect" + "testing" + + "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" + mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" + "github.com/aws/aws-sdk-go/aws" + "github.com/golang/mock/gomock" +) + +const ( + heartbeatMessageId = "heartbeatMessageId" + heartbeatHealthy = true +) + +func TestAckHeartbeatMessage(t *testing.T) { + heartbeatReceived := &ecsacs.HeartbeatMessage{ + MessageId: aws.String(heartbeatMessageId), + Healthy: aws.Bool(heartbeatHealthy), + } + + heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{ + MessageId: aws.String(heartbeatMessageId), + } + + validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected) +} + +func TestAckHeartbeatMessageWithoutMessageId(t *testing.T) { + heartbeatReceived := &ecsacs.HeartbeatMessage{ + Healthy: aws.Bool(heartbeatHealthy), + } + + heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{ + MessageId: nil, + } + + validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected) +} + +func TestAckHeartbeatMessageEmpty(t *testing.T) { + heartbeatReceived := &ecsacs.HeartbeatMessage{} + + heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{ + MessageId: nil, + } + + validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected) +} + +func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessage, heartbeatAckExpected *ecsacs.HeartbeatAckRequest) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithCancel(context.Background()) + var heartbeatAckSent *ecsacs.HeartbeatAckRequest + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) { + heartbeatAckSent = message + cancel() + }).Times(1) + + handler := newHeartbeatHandler(ctx, mockWsClient) + go handler.sendHeartbeatAck() + + handler.handleSingleHeartbeatMessage(heartbeatReceived) + + // wait till we get an ack from heartbeatAckMessageBuffer + select { + case <-ctx.Done(): + } + + if !reflect.DeepEqual(heartbeatAckExpected, heartbeatAckSent) { + t.Errorf("Message mismatch between expected and sent ack, expected: %v, sent: %v", heartbeatAckExpected, heartbeatAckSent) + } +} diff --git a/agent/acs/model/api/api-2.json b/agent/acs/model/api/api-2.json index 3036c4f3f5..3f873aafe7 100644 --- a/agent/acs/model/api/api-2.json +++ b/agent/acs/model/api/api-2.json @@ -47,7 +47,8 @@ "requestUri":"/" }, "input":{"shape":"HeartbeatMessage"}, - "documentation":"Heartbeat is a periodic message that informs the agent all is well." + "output":{"shape":"HeartbeatAckRequest"}, + "documentation":"Heartbeat is a periodic message between the Agent and ECS backend to keep the connection alive." }, "Payload":{ "name":"Payload", @@ -417,7 +418,14 @@ "HeartbeatMessage":{ "type":"structure", "members":{ - "healthy":{"shape":"Boolean"} + "healthy":{"shape":"Boolean"}, + "messageId":{"shape":"String"} + } + }, + "HeartbeatAckRequest":{ + "type":"structure", + "members":{ + "messageId":{"shape":"String"} } }, "HostVolumeProperties":{ diff --git a/agent/acs/model/ecsacs/api.go b/agent/acs/model/ecsacs/api.go index 2d45dfc88a..622d197b96 100644 --- a/agent/acs/model/ecsacs/api.go +++ b/agent/acs/model/ecsacs/api.go @@ -720,10 +720,28 @@ func (s FirelensConfiguration) GoString() string { return s.String() } +type HeartbeatAckRequest struct { + _ struct{} `type:"structure"` + + MessageId *string `locationName:"messageId" type:"string"` +} + +// String returns the string representation +func (s HeartbeatAckRequest) String() string { + return awsutil.Prettify(s) +} + +// GoString returns the string representation +func (s HeartbeatAckRequest) GoString() string { + return s.String() +} + type HeartbeatInput struct { _ struct{} `type:"structure"` Healthy *bool `locationName:"healthy" type:"boolean"` + + MessageId *string `locationName:"messageId" type:"string"` } // String returns the string representation @@ -740,6 +758,8 @@ type HeartbeatMessage struct { _ struct{} `type:"structure"` Healthy *bool `locationName:"healthy" type:"boolean"` + + MessageId *string `locationName:"messageId" type:"string"` } // String returns the string representation @@ -754,6 +774,8 @@ func (s HeartbeatMessage) GoString() string { type HeartbeatOutput struct { _ struct{} `type:"structure"` + + MessageId *string `locationName:"messageId" type:"string"` } // String returns the string representation