diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go index 19e2f270ccb..cfd8299bc2c 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go @@ -64,6 +64,7 @@ const ( // The Session.Start() method can be used to start processing messages from ACS. type Session interface { Start(context.Context) error + GetLastConnectedTime() time.Time } // session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS. @@ -97,6 +98,7 @@ type session struct { disconnectTimeout time.Duration disconnectJitter time.Duration inactiveInstanceReconnectDelay time.Duration + lastConnectedTime time.Time } // NewSession creates a new Session. @@ -155,6 +157,7 @@ func NewSession(containerInstanceARN string, disconnectTimeout: wsclient.DisconnectTimeout, disconnectJitter: wsclient.DisconnectJitterMax, inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + lastConnectedTime: time.Time{}, } } @@ -245,6 +248,9 @@ func (s *session) startSessionOnce(ctx context.Context) error { } defer disconnectTimer.Stop() + // Record the timestamp of the last connection to ACS. + s.lastConnectedTime = time.Now() + // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. logger.Info("Connected to ACS endpoint") @@ -425,3 +431,8 @@ func formatDockerVersion(dockerVersionValue string) string { } return dockerVersionValue } + +// GetLastConnectedTime returns the timestamp that the last connection was established to ACS. +func (s *session) GetLastConnectedTime() time.Time { + return s.lastConnectedTime +} diff --git a/ecs-agent/acs/session/session.go b/ecs-agent/acs/session/session.go index 19e2f270ccb..cfd8299bc2c 100644 --- a/ecs-agent/acs/session/session.go +++ b/ecs-agent/acs/session/session.go @@ -64,6 +64,7 @@ const ( // The Session.Start() method can be used to start processing messages from ACS. type Session interface { Start(context.Context) error + GetLastConnectedTime() time.Time } // session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS. @@ -97,6 +98,7 @@ type session struct { disconnectTimeout time.Duration disconnectJitter time.Duration inactiveInstanceReconnectDelay time.Duration + lastConnectedTime time.Time } // NewSession creates a new Session. @@ -155,6 +157,7 @@ func NewSession(containerInstanceARN string, disconnectTimeout: wsclient.DisconnectTimeout, disconnectJitter: wsclient.DisconnectJitterMax, inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + lastConnectedTime: time.Time{}, } } @@ -245,6 +248,9 @@ func (s *session) startSessionOnce(ctx context.Context) error { } defer disconnectTimer.Stop() + // Record the timestamp of the last connection to ACS. + s.lastConnectedTime = time.Now() + // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. logger.Info("Connected to ACS endpoint") @@ -425,3 +431,8 @@ func formatDockerVersion(dockerVersionValue string) string { } return dockerVersionValue } + +// GetLastConnectedTime returns the timestamp that the last connection was established to ACS. +func (s *session) GetLastConnectedTime() time.Time { + return s.lastConnectedTime +} diff --git a/ecs-agent/acs/session/session_test.go b/ecs-agent/acs/session/session_test.go index 5b6f7d48e8a..82e3a3c79aa 100644 --- a/ecs-agent/acs/session/session_test.go +++ b/ecs-agent/acs/session/session_test.go @@ -1286,6 +1286,98 @@ func TestSessionCallsAddUpdateRequestHandlers(t *testing.T) { assert.True(t, addUpdateRequestHandlersCalled) } +// TestGetLastConnectedTime tests that the Session's 'lastConnectedTime' field is updated correctly for successive +// invocations of startSessionOnce. Also tests that the Session's GetLastConnectedTime() API call works as expected. +func TestGetLastConnectedTime(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + const numInvocations = 10 + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() + + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterARN, + discoverEndpointClient, + nil, + noopFunc, + mockClientFactory, + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + dockerVersion, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + acsSession.(*session).heartbeatTimeout = 20 * time.Millisecond + acsSession.(*session).heartbeatJitter = 10 * time.Millisecond + acsSession.(*session).disconnectTimeout = 30 * time.Millisecond + acsSession.(*session).disconnectJitter = 10 * time.Millisecond + gomock.InOrder( + // When the websocket client connects to ACS for the first time, 'sendCredentials' should be set to true. + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + assert.Equal(t, true, acsSession.(*session).sendCredentials) + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), + // For all subsequent connections to ACS, 'sendCredentials' should be set to false. + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + assert.Equal(t, false, acsSession.(*session).sendCredentials) + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).Times(numInvocations-1), + ) + + // The Session's lastConnectedTime field was initialized with time.Time{}, which is the default zero value for time.Time. + // At this point, since the Session has not connected to ACS yet, the Session's lastConnectedTime should still be zero. + assert.True(t, acsSession.GetLastConnectedTime().IsZero()) + + go func() { + for i := 0; i < numInvocations; i++ { + // Record the current time. + currentTime := time.Now() + // Invoke startSessionOnce() to connect to ACS. + acsSession.(*session).startSessionOnce(ctx) + // Get the timestamp recorded in Session's lastConnectedTime field. + acsSessionActualConnectedTime := acsSession.GetLastConnectedTime() + // Compare the two timestamps. + // Since the connection was started right after the first timestamp was recorded, the two timestamps should + // be very close. Allow an 1 ms to account for jitters. + assert.WithinDuration(t, currentTime, acsSessionActualConnectedTime, 1*time.Millisecond) + // Sleep for 2 ms before proceeding to the next test iteration, so that if the Session's lastConnectedTime + // field is not correctly updated, it would be caught since the allowed delta is 1 ms. + time.Sleep(2 * time.Millisecond) + } + cancel() + }() + + // Wait for context to be canceled. + select { + case <-ctx.Done(): + cancel() + } +} + func startFakeACSServer(closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { serverChan := make(chan string, 1) requestsChan := make(chan string, 1)