diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index fdad96c409c..7a142f4dc3b 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -40,6 +40,7 @@ import ( task_protection_v1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" agentapi "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/types" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" + v4stats "github.com/aws/amazon-ecs-agent/agent/handlers/v4" "github.com/aws/amazon-ecs-agent/agent/stats" mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" @@ -863,38 +864,6 @@ func parseResponseBody(body *bytes.Buffer) (*credentials.IAMRoleCredentials, err return &creds, nil } -func TestV2ContainerStats(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - dockerStats := &types.StatsJSON{} - dockerStats.NumProcs = 2 - gomock.InOrder( - state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), - statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v2BaseStatsPath+"/"+containerID, nil) - req.RemoteAddr = remoteIP + ":" + remotePort - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var statsFromResult *types.StatsJSON - err = json.Unmarshal(res, &statsFromResult) - assert.NoError(t, err) - assert.Equal(t, dockerStats.NumProcs, statsFromResult.NumProcs) -} - func TestV2TaskStats(t *testing.T) { testCases := []struct { path string @@ -991,39 +960,6 @@ func TestV3TaskStats(t *testing.T) { assert.Equal(t, dockerStats.NumProcs, containerStats.NumProcs) } -func TestV3ContainerStats(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - dockerStats := &types.StatsJSON{} - dockerStats.NumProcs = 2 - - gomock.InOrder( - state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/stats", nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var statsFromResult *types.StatsJSON - err = json.Unmarshal(res, &statsFromResult) - assert.NoError(t, err) - assert.Equal(t, dockerStats.NumProcs, statsFromResult.NumProcs) -} - func TestV3ContainerAssociations(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1126,39 +1062,6 @@ func TestV4TaskStats(t *testing.T) { assert.Equal(t, dockerStats.NumProcs, containerStats.NumProcs) } -func TestV4ContainerStats(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - dockerStats := &types.StatsJSON{} - dockerStats.NumProcs = 2 - - gomock.InOrder( - state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/stats", nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var statsFromResult *types.StatsJSON - err = json.Unmarshal(res, &statsFromResult) - assert.NoError(t, err) - assert.Equal(t, dockerStats.NumProcs, statsFromResult.NumProcs) -} - func TestV4ContainerAssociations(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1519,6 +1422,8 @@ type TMDSResponse interface { v4.ContainerResponse | v4.TaskResponse | agentapi.TaskProtectionResponse | + types.StatsJSON | + v4stats.StatsResponse | string } @@ -1532,6 +1437,8 @@ type TMDSTestCase[R TMDSResponse] struct { requestBody interface{} // Function to set expectations on mock task engine state setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + // Function to set expectations on mock stats engine + setStatsEngineExpectations func(engine *mock_stats.MockEngine) // Function to set expectations on mock ECS Client setECSClientExpectations func(ecsClient *mock_api.MockECSClient) // Function to set expectations on mock Task Protection Client Factory @@ -1570,6 +1477,9 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { if tc.setStateExpectations != nil { tc.setStateExpectations(state) } + if tc.setStatsEngineExpectations != nil { + tc.setStatsEngineExpectations(statsEngine) + } if tc.setECSClientExpectations != nil { tc.setECSClientExpectations(ecsClient) } @@ -2753,6 +2663,189 @@ func TestV4TaskMetadataWithTags(t *testing.T) { }) } +func TestV2ContainerStats(t *testing.T) { + path := v2BaseStatsPath + "/" + containerID + t.Run("task not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + state.EXPECT().GetTaskByIPAddress(remoteIP).Return("", false) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "Unable to get task arn from request: unable to associate '%s' with task", remoteIP), + }) + }) + t.Run("stats not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(nil, nil, errors.New("some error")) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf("Unable to get container stats for: %s", containerID), + }) + }) + t.Run("happy case", func(t *testing.T) { + dockerStats := types.StatsJSON{Stats: types.Stats{NumProcs: 2}} + testTMDSRequest(t, TMDSTestCase[types.StatsJSON]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(&dockerStats, &stats.NetworkStatsPerSec{}, nil) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: dockerStats, + }) + }) +} + +func TestV3ContainerStats(t *testing.T) { + path := v3BasePath + v3EndpointID + "/stats" + t.Run("task not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return("", false) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "V3 container stats handler: unable to get task arn from request: unable to get task Arn from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("Docker ID not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return("", false), + ) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "V3 container stats handler: unable to get container ID from request: unable to get docker ID from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("stats not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + ) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(nil, nil, errors.New("some error")) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf("Unable to get container stats for: %s", containerID), + }) + }) + t.Run("happy case", func(t *testing.T) { + dockerStats := types.StatsJSON{Stats: types.Stats{NumProcs: 2}} + testTMDSRequest(t, TMDSTestCase[types.StatsJSON]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + ) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(&dockerStats, &stats.NetworkStatsPerSec{}, nil) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: dockerStats, + }) + }) +} + +func TestV4ContainerStats(t *testing.T) { + path := v4BasePath + v3EndpointID + "/stats" + t.Run("task not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return("", false) + }, + expectedStatusCode: http.StatusNotFound, + expectedResponseBody: fmt.Sprintf( + "V4 container handler: unable to get task arn from request: unable to get task Arn from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("Docker ID not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return("", false), + ) + }, + expectedStatusCode: http.StatusNotFound, + expectedResponseBody: fmt.Sprintf( + "V4 container stats handler: unable to get container ID from request: unable to get docker ID from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("stats not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + ) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(nil, nil, errors.New("some error")) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: "Unable to get container stats for: " + containerID, + }) + }) + t.Run("happy case", func(t *testing.T) { + dockerStats := types.StatsJSON{Stats: types.Stats{NumProcs: 2}} + networkStats := stats.NetworkStatsPerSec{ + RxBytesPerSecond: 52, + TxBytesPerSecond: 84, + } + testTMDSRequest(t, TMDSTestCase[v4stats.StatsResponse]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + ) + }, + setStatsEngineExpectations: func(engine *mock_stats.MockEngine) { + engine.EXPECT().ContainerDockerStats(taskARN, containerID). + Return(&dockerStats, &networkStats, nil) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: v4stats.StatsResponse{ + StatsJSON: &dockerStats, + Network_rate_stats: &networkStats, + }, + }) + }) +} + func TestGetTaskProtection(t *testing.T) { path := fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID)