diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go index 32f32eedf9..7bcce84f77 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go @@ -125,7 +125,7 @@ func UpdateTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsM updateTaskProtectionRequestType) return } - ecsClient := factory.newTaskProtectionClient(taskRoleCredential) + ecsClient := factory.NewTaskProtectionClient(taskRoleCredential) ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) defer cancel() @@ -221,7 +221,7 @@ func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsMana return } - ecsClient := factory.newTaskProtectionClient(taskRoleCredential) + ecsClient := factory.NewTaskProtectionClient(taskRoleCredential) ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) defer cancel() @@ -286,7 +286,7 @@ func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsMana } // Helper function for retrieving credential from credentials manager and create ecs client -func (factory TaskProtectionClientFactory) newTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { +func (factory TaskProtectionClientFactory) NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { taskCredential := taskRoleCredential.GetIAMRoleCredentials() cfg := aws.NewConfig(). WithCredentials(awscreds.NewStaticCredentials(taskCredential.AccessKeyID, diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go index 0950bb29b8..d1baf9d293 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go @@ -49,16 +49,16 @@ func (m *MockTaskProtectionClientFactoryInterface) EXPECT() *MockTaskProtectionC return m.recorder } -// newTaskProtectionClient mocks base method. -func (m *MockTaskProtectionClientFactoryInterface) newTaskProtectionClient(arg0 credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { +// NewTaskProtectionClient mocks base method. +func (m *MockTaskProtectionClientFactoryInterface) NewTaskProtectionClient(arg0 credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "newTaskProtectionClient", arg0) + ret := m.ctrl.Call(m, "NewTaskProtectionClient", arg0) ret0, _ := ret[0].(api.ECSTaskProtectionSDK) return ret0 } -// newTaskProtectionClient indicates an expected call of newTaskProtectionClient. -func (mr *MockTaskProtectionClientFactoryInterfaceMockRecorder) newTaskProtectionClient(arg0 interface{}) *gomock.Call { +// NewTaskProtectionClient indicates an expected call of NewTaskProtectionClient. +func (mr *MockTaskProtectionClientFactoryInterfaceMockRecorder) NewTaskProtectionClient(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "newTaskProtectionClient", reflect.TypeOf((*MockTaskProtectionClientFactoryInterface)(nil).newTaskProtectionClient), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewTaskProtectionClient", reflect.TypeOf((*MockTaskProtectionClientFactoryInterface)(nil).NewTaskProtectionClient), arg0) } diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go index 76955f1087..f57d34a8c6 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go @@ -86,7 +86,7 @@ func TestGetECSClientHappyCase(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - ret := factory.newTaskProtectionClient(testIAMRoleCredentials) + ret := factory.NewTaskProtectionClient(testIAMRoleCredentials) _, ok := ret.(api.ECSTaskProtectionSDK) // Assert response @@ -405,7 +405,7 @@ func TestUpdateTaskProtectionHandler_PostCall(t *testing.T) { mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true) - mockFactory.EXPECT().newTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) + mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) mockECSClient.EXPECT(). UpdateTaskProtectionWithContext(gomock.Any(), gomock.Any()). Return(tc.ecsResponse, tc.ecsError) @@ -644,7 +644,7 @@ func TestGetTaskProtectionHandler_PostCall(t *testing.T) { mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true) - mockFactory.EXPECT().newTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) + mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) mockECSClient.EXPECT(). GetTaskProtectionWithContext(gomock.Any(), gomock.Any()). Return(tc.ecsResponse, tc.ecsError) diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/interface.go b/agent/handlers/agentapi/taskprotection/v1/handlers/interface.go index 7d18d8b296..9e0c890864 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/interface.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/interface.go @@ -6,5 +6,5 @@ import ( ) type TaskProtectionClientFactoryInterface interface { - newTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK + NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK } diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index c09ef82616..9bf78310e7 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -54,15 +54,14 @@ func taskServerSetup(credentialsManager credentials.Manager, state dockerstate.TaskEngineState, ecsClient api.ECSClient, cluster string, - region string, statsEngine stats.Engine, steadyStateRate int, burstRate int, availabilityZone string, vpcID string, containerInstanceArn string, - apiEndpoint string, - acceptInsecureCert bool) (*http.Server, error) { + taskProtectionClientFactory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface, +) (*http.Server, error) { muxRouter := mux.NewRouter() @@ -79,7 +78,7 @@ func taskServerSetup(credentialsManager credentials.Manager, v4HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, availabilityZone, vpcID, containerInstanceArn) - agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, region, apiEndpoint, acceptInsecureCert) + agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, taskProtectionClientFactory) return tmds.NewServer(auditLogger, tmds.WithHandler(muxRouter), @@ -152,10 +151,13 @@ func v4HandlersSetup(muxRouter *mux.Router, } // agentAPIV1HandlersSetup adds handlers for Agent API V1 -func agentAPIV1HandlersSetup(muxRouter *mux.Router, state dockerstate.TaskEngineState, credentialsManager credentials.Manager, cluster string, region string, endpoint string, acceptInsecureCert bool) { - factory := agentAPITaskProtectionV1.TaskProtectionClientFactory{ - Region: region, Endpoint: endpoint, AcceptInsecureCert: acceptInsecureCert, - } +func agentAPIV1HandlersSetup( + muxRouter *mux.Router, + state dockerstate.TaskEngineState, + credentialsManager credentials.Manager, + cluster string, + factory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface, +) { muxRouter. HandleFunc( agentAPITaskProtectionV1.TaskProtectionPath(), @@ -190,9 +192,12 @@ func ServeTaskHTTPEndpoint( auditLogger := audit.NewAuditLog(containerInstanceArn, cfg, logger) - server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, cfg.AWSRegion, statsEngine, - cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, availabilityZone, vpcID, containerInstanceArn, cfg.APIEndpoint, - cfg.AcceptInsecureCert) + taskProtectionClientFactory := agentAPITaskProtectionV1.TaskProtectionClientFactory{ + Region: cfg.AWSRegion, Endpoint: cfg.APIEndpoint, AcceptInsecureCert: cfg.AcceptInsecureCert, + } + server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, + statsEngine, cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, + availabilityZone, vpcID, containerInstanceArn, taskProtectionClientFactory) if err != nil { seelog.Criticalf("Failed to set up Task Metadata Server: %v", err) return diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index ac9376926c..a568041d92 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -37,7 +37,9 @@ import ( "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/ecs_client/model/ecs" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" + agentapihandlers "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" task_protection_v1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" + agentapi "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/types" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" "github.com/aws/amazon-ecs-agent/agent/stats" mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" @@ -53,6 +55,8 @@ import ( v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -104,9 +108,7 @@ const ( macAddress = "06:96:9a:ce:a6:ce" privateDNSName = "ip-172-31-47-69.us-west-2.compute.internal" subnetGatewayIpv4Address = "172.31.32.1/20" - region = "us-west-2" - endpoint = "ecsEndpoint" - acceptInsecureCert = true + taskCredentialsID = "taskCredentialsId" ) var ( @@ -129,33 +131,6 @@ var ( Name: associationName, Type: associationType, } - task = &apitask.Task{ - Arn: taskARN, - Associations: []apitask.Association{association}, - Family: family, - Version: version, - DesiredStatusUnsafe: apitaskstatus.TaskRunning, - KnownStatusUnsafe: apitaskstatus.TaskRunning, - NetworkMode: apitask.AWSVPCNetworkMode, - ENIs: []*apieni.ENI{ - { - IPV4Addresses: []*apieni.ENIIPV4Address{ - { - Address: eniIPv4Address, - }, - }, - MacAddress: macAddress, - PrivateDNSName: privateDNSName, - SubnetGatewayIPV4Address: subnetGatewayIpv4Address, - }, - }, - CPU: cpu, - Memory: memory, - PullStartedAtUnsafe: now, - PullStoppedAtUnsafe: now, - ExecutionStoppedAtUnsafe: now, - LaunchType: "EC2", - } pulledTask = &apitask.Task{ Arn: taskARN, Associations: []apitask.Association{pulledAssociation}, @@ -436,6 +411,38 @@ var ( }) ) +func standardTask() *apitask.Task { + task := apitask.Task{ + Arn: taskARN, + Associations: []apitask.Association{association}, + Family: family, + Version: version, + DesiredStatusUnsafe: apitaskstatus.TaskRunning, + KnownStatusUnsafe: apitaskstatus.TaskRunning, + NetworkMode: apitask.AWSVPCNetworkMode, + ENIs: []*apieni.ENI{ + { + IPV4Addresses: []*apieni.ENIIPV4Address{ + { + Address: eniIPv4Address, + }, + }, + MacAddress: macAddress, + PrivateDNSName: privateDNSName, + SubnetGatewayIPV4Address: subnetGatewayIpv4Address, + }, + }, + CPU: cpu, + Memory: memory, + PullStartedAtUnsafe: now, + PullStoppedAtUnsafe: now, + ExecutionStoppedAtUnsafe: now, + LaunchType: "EC2", + } + task.SetCredentialsID(taskCredentialsID) + return &task +} + // Returns a standard v2 task response. This getter function protects against tests mutating the // response. func expectedTaskResponse() v2.TaskResponse { @@ -564,6 +571,19 @@ func expectedV4TaskResponseNoContainers() v4.TaskResponse { return taskResponse } +func taskRoleCredentials() credentials.TaskIAMRoleCredentials { + return credentials.TaskIAMRoleCredentials{ + ARN: taskARN, + IAMRoleCredentials: credentials.IAMRoleCredentials{ + RoleArn: "roleArn", + AccessKeyID: "accessKeyID", + SecretAccessKey: "secretAccessKey", + SessionToken: "sessionToken", + Expiration: "expiration", + }, + } +} + func v4TaskResponseFromV2( v2TaskResponse v2.TaskResponse, containers []v4.ContainerResponse, @@ -763,9 +783,9 @@ func testErrorResponsesFromServer(t *testing.T, path string, expectedErrorMessag credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, + server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, "", true) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -800,9 +820,9 @@ func getResponseForCredentialsRequest(t *testing.T, expectedStatus int, credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", "", nil, + server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, "", true) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -859,9 +879,9 @@ func TestV2ContainerStats(t *testing.T) { state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v2BaseStatsPath+"/"+containerID, nil) @@ -910,9 +930,9 @@ func TestV2TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", tc.path, nil) @@ -954,9 +974,9 @@ func TestV3TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/task/stats", nil) @@ -989,9 +1009,9 @@ func TestV3ContainerStats(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/stats", nil) @@ -1018,11 +1038,11 @@ func TestV3ContainerAssociations(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), - state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().TaskByArn(taskARN).Return(standardTask(), true), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -1041,6 +1061,8 @@ func TestV3ContainerAssociation(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + task := standardTask() + state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) @@ -1050,9 +1072,9 @@ func TestV3ContainerAssociation(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -1087,9 +1109,9 @@ func TestV4TaskStats(t *testing.T) { state.EXPECT().ContainerMapByArn(taskARN).Return(containerMap, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/task/stats", nil) @@ -1122,9 +1144,9 @@ func TestV4ContainerStats(t *testing.T) { state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), statsEngine.EXPECT().ContainerDockerStats(taskARN, containerID).Return(dockerStats, &stats.NetworkStatsPerSec{}, nil), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/stats", nil) @@ -1142,6 +1164,8 @@ func TestV4ContainerAssociations(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + task := standardTask() + state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) @@ -1153,9 +1177,9 @@ func TestV4ContainerAssociations(t *testing.T) { state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -1174,6 +1198,8 @@ func TestV4ContainerAssociation(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + task := standardTask() + state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) @@ -1183,9 +1209,9 @@ func TestV4ContainerAssociation(t *testing.T) { state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -1210,9 +1236,9 @@ func TestTaskHTTPEndpoint301Redirect(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for testPath, expectedPath := range testPathsMap { @@ -1253,9 +1279,9 @@ func TestTaskHTTPEndpointErrorCode404(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1293,9 +1319,9 @@ func TestTaskHTTPEndpointErrorCode400(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1332,9 +1358,9 @@ func TestTaskHTTPEndpointErrorCode500(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1402,9 +1428,9 @@ func TestV4TaskNotFoundError404(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) state.EXPECT().TaskARNByV3EndpointID(gomock.Any()).Return("", tc.taskFound).AnyTimes() @@ -1458,9 +1484,9 @@ func TestV4Unexpected500Error(t *testing.T) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) // Initial lookups succeed @@ -1489,7 +1515,12 @@ func TestV4Unexpected500Error(t *testing.T) { // Types of TMDS responses, add more types as needed type TMDSResponse interface { - v2.ContainerResponse | v2.TaskResponse | v4.ContainerResponse | v4.TaskResponse | string + v2.ContainerResponse | + v2.TaskResponse | + v4.ContainerResponse | + v4.TaskResponse | + agentapi.TaskProtectionResponse | + string } // Represents a test case for TMDS. Supports generic TMDS response body types using type parametesrs. @@ -1500,6 +1531,11 @@ type TMDSTestCase[R TMDSResponse] struct { setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) // Function to set expectations on mock ECS Client setECSClientExpectations func(ecsClient *mock_api.MockECSClient) + // Function to set expectations on mock Task Protection Client Factory + setTaskProtectionClientFactoryExpectations func( + ctrl *gomock.Controller, factory *agentapihandlers.MockTaskProtectionClientFactoryInterface) + // Function to set expectations on mock Credentials Manager + setCredentialsManagerExpectations func(credsManager *mock_credentials.MockManager) // Expected HTTP status code of the response expectedStatusCode int // Expected response body, all JSON compatible types are accepted @@ -1523,6 +1559,8 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) + credsManager := mock_credentials.NewMockManager(ctrl) + taskProtectionClientFactory := agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl) // Set expectations on mocks auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -1530,12 +1568,18 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { if tc.setECSClientExpectations != nil { tc.setECSClientExpectations(ecsClient) } + if tc.setTaskProtectionClientFactoryExpectations != nil { + tc.setTaskProtectionClientFactoryExpectations(ctrl, taskProtectionClientFactory) + } + if tc.setCredentialsManagerExpectations != nil { + tc.setCredentialsManagerExpectations(credsManager) + } // Initialize server - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, - clusterName, region, statsEngine, + server, err := taskServerSetup(credsManager, auditLog, state, ecsClient, + clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, taskProtectionClientFactory) require.NoError(t, err) // Create the request @@ -1559,6 +1603,8 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { // Tests for v2 container metadata endpoint func TestV2ContainerMetadata(t *testing.T) { + task := standardTask() + t.Run("task not found by IP", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v2BaseMetadataPath + "/" + containerID, @@ -1619,6 +1665,8 @@ func TestV2ContainerMetadata(t *testing.T) { } func TestV2TaskMetadata(t *testing.T) { + task := standardTask() + t.Run("task not found by IP", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v2BaseMetadataPath, @@ -1693,6 +1741,8 @@ func TestV2TaskMetadata(t *testing.T) { } func TestV3ContainerMetadata(t *testing.T) { + task := standardTask() + t.Run("v3EndpointID invalid", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v3BasePath + v3EndpointID, @@ -1784,6 +1834,8 @@ func TestV3ContainerMetadata(t *testing.T) { } func TestV3TaskMetadata(t *testing.T) { + task := standardTask() + t.Run("taskARN not found for v3EndpointID", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v3BasePath + v3EndpointID + "/task", @@ -1895,6 +1947,8 @@ func TestV3TaskMetadata(t *testing.T) { } func TestV4ContainerMetadata(t *testing.T) { + task := standardTask() + t.Run("v3EndpointID is invalid", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v4BasePath + v3EndpointID, @@ -2017,6 +2071,8 @@ func TestV4ContainerMetadata(t *testing.T) { } func TestV4TaskMetadata(t *testing.T) { + task := standardTask() + t.Run("taskARN not found for v3EndpointID", func(t *testing.T) { testTMDSRequest(t, TMDSTestCase[string]{ path: v4BasePath + v3EndpointID + "/task", @@ -2172,6 +2228,8 @@ func TestV4TaskMetadata(t *testing.T) { } func TestV2TaskMetadataWithTags(t *testing.T) { + task := standardTask() + containerInstanceTags := standardContainerInstanceTags() taskTags := standardTaskTags() @@ -2301,6 +2359,8 @@ func TestV2TaskMetadataWithTags(t *testing.T) { } func TestV3TaskMetadataWithTags(t *testing.T) { + task := standardTask() + containerInstanceTags := standardContainerInstanceTags() taskTags := standardTaskTags() @@ -2464,6 +2524,8 @@ func TestV3TaskMetadataWithTags(t *testing.T) { } func TestV4TaskMetadataWithTags(t *testing.T) { + task := standardTask() + containerInstanceTags := standardContainerInstanceTags() taskTags := standardTaskTags() @@ -2683,10 +2745,13 @@ func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, ctrl := gomock.NewController(t) defer ctrl.Finish() + task := standardTask() + state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClientFactory := agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl) gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), @@ -2694,9 +2759,9 @@ func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, ) // Set up the server - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) + containerInstanceArn, ecsClientFactory) require.NoError(t, err) // Prepare the request @@ -2717,9 +2782,238 @@ func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, assert.NotNil(t, recorder.Body) } -// Tests that Agent API v1 GetTaskProtection handler is registered correctly -func TestAgentAPIV1GetTaskProtectionHandler(t *testing.T) { - testAgentAPITaskProtectionV1Handler(t, nil, "GET") +func TestGetTaskProtection(t *testing.T) { + path := fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID) + + // Set up some fake data + task := standardTask() + ecsInput := ecs.GetTaskProtectionInput{ + Cluster: aws.String(clusterName), + Tasks: aws.StringSlice([]string{taskARN}), + } + protectedTask := ecs.ProtectedTask{ + ProtectionEnabled: aws.Bool(true), + TaskArn: aws.String(taskARN), + } + ecsOutput := ecs.GetTaskProtectionOutput{ + ProtectedTasks: []*ecs.ProtectedTask{&protectedTask}, + } + ecsRequestID := "reqid" + ecsErrMessage := "ecs error message" + + // Helper functions to set expectation on mocks + happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().TaskByArn(taskARN).Return(task, true), + ) + } + happyCredentialsManagerExpectations := func(credsManager *mock_credentials.MockManager) { + credsManager.EXPECT(). + GetTaskCredentials(task.GetCredentialsID()). + Return(taskRoleCredentials(), true) + } + taskProtectionClientFactoryExpectations := func(output *ecs.GetTaskProtectionOutput, err error) func( + *gomock.Controller, *task_protection_v1.MockTaskProtectionClientFactoryInterface, + ) { + return func( + ctrl *gomock.Controller, + factory *task_protection_v1.MockTaskProtectionClientFactoryInterface, + ) { + client := mock_api.NewMockECSTaskProtectionSDK(ctrl) + client.EXPECT().GetTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) + factory.EXPECT().NewTaskProtectionClient(taskRoleCredentials()).Return(client) + } + } + + // Test cases start here + t.Run("task ARN not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return("", false), + ) + }, + expectedStatusCode: http.StatusNotFound, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Code: ecs.ErrCodeResourceNotFoundException, + Message: "Invalid request: no task was found", + }, + }, + }) + }) + t.Run("task not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().TaskByArn(taskARN).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Code: ecs.ErrCodeServerException, + Message: "Failed to find a task for the request", + }, + }, + }) + }) + t.Run("task credentials not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: func(credsManager *mock_credentials.MockManager) { + credsManager. + EXPECT().GetTaskCredentials(taskCredentialsID). + Return(credentials.TaskIAMRoleCredentials{}, false) + }, + expectedStatusCode: http.StatusForbidden, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeAccessDeniedException, + Message: "Invalid Request: no task IAM role credentials available for task", + }, + }, + }) + }) + t.Run("ecs call server exception", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + nil, + awserr.NewRequestFailure( + awserr.New(ecs.ErrCodeServerException, ecsErrMessage, nil), + http.StatusInternalServerError, + ecsRequestID, + ), + ), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: agentapi.TaskProtectionResponse{ + RequestID: &ecsRequestID, + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeServerException, + Message: ecsErrMessage, + }, + }, + }) + }) + t.Run("ecs call access denied exception", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + nil, + awserr.NewRequestFailure( + awserr.New(ecs.ErrCodeAccessDeniedException, ecsErrMessage, nil), + http.StatusBadRequest, + ecsRequestID, + ), + ), + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: agentapi.TaskProtectionResponse{ + RequestID: &ecsRequestID, + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeAccessDeniedException, + Message: ecsErrMessage, + }, + }, + }) + }) + t.Run("ecs call non-request-failure aws error", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + nil, + awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeInvalidParameterException, + Message: ecsErrMessage, + }, + }, + }) + }) + t.Run("agent timeout", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)), + expectedStatusCode: http.StatusGatewayTimeout, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: request.CanceledErrorCode, + Message: "Timed out calling ECS Task Protection API", + }, + }, + }) + }) + t.Run("non-aws error", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + nil, errors.New("some error")), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Error: &agentapi.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeServerException, + Message: "some error", + }, + }, + }) + }) + t.Run("ecs failure", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( + &ecs.GetTaskProtectionOutput{ + Failures: []*ecs.Failure{{ + Arn: aws.String(taskARN), + Reason: aws.String("ecs failure"), + }}, + }, nil), + expectedStatusCode: http.StatusOK, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Failure: &ecs.Failure{ + Arn: aws.String(taskARN), + Reason: aws.String("ecs failure"), + }, + }, + }) + }) + t.Run("happy case", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + path: path, + setStateExpectations: happyStateExpectations, + setCredentialsManagerExpectations: happyCredentialsManagerExpectations, + setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil), + expectedStatusCode: http.StatusOK, + expectedResponseBody: agentapi.TaskProtectionResponse{ + Protection: &protectedTask, + }, + }) + }) } // Tests that Agent API v1 UpdateTaskProtection handler is registered correctly