Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for GetTaskProtection API to high-level TMDS tests #3739

Merged
merged 3 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func UpdateTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsM
updateTaskProtectionRequestType)
return
}
ecsClient := factory.newTaskProtectionClient(taskRoleCredential)
ecsClient := factory.NewTaskProtectionClient(taskRoleCredential)

ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout)
defer cancel()
Expand Down Expand Up @@ -221,7 +221,7 @@ func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsMana
return
}

ecsClient := factory.newTaskProtectionClient(taskRoleCredential)
ecsClient := factory.NewTaskProtectionClient(taskRoleCredential)

ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout)
defer cancel()
Expand Down Expand Up @@ -286,7 +286,7 @@ func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsMana
}

// Helper function for retrieving credential from credentials manager and create ecs client
func (factory TaskProtectionClientFactory) newTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK {
func (factory TaskProtectionClientFactory) NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK {
taskCredential := taskRoleCredential.GetIAMRoleCredentials()
cfg := aws.NewConfig().
WithCredentials(awscreds.NewStaticCredentials(taskCredential.AccessKeyID,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestGetECSClientHappyCase(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ret := factory.newTaskProtectionClient(testIAMRoleCredentials)
ret := factory.NewTaskProtectionClient(testIAMRoleCredentials)
_, ok := ret.(api.ECSTaskProtectionSDK)

// Assert response
Expand Down Expand Up @@ -405,7 +405,7 @@ func TestUpdateTaskProtectionHandler_PostCall(t *testing.T) {
mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true)
mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true)
mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true)
mockFactory.EXPECT().newTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient)
mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient)
mockECSClient.EXPECT().
UpdateTaskProtectionWithContext(gomock.Any(), gomock.Any()).
Return(tc.ecsResponse, tc.ecsError)
Expand Down Expand Up @@ -644,7 +644,7 @@ func TestGetTaskProtectionHandler_PostCall(t *testing.T) {
mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true)
mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true)
mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true)
mockFactory.EXPECT().newTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient)
mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient)
mockECSClient.EXPECT().
GetTaskProtectionWithContext(gomock.Any(), gomock.Any()).
Return(tc.ecsResponse, tc.ecsError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ import (
)

type TaskProtectionClientFactoryInterface interface {
newTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK
NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK
}
27 changes: 16 additions & 11 deletions agent/handlers/task_server_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,14 @@ func taskServerSetup(credentialsManager credentials.Manager,
state dockerstate.TaskEngineState,
ecsClient api.ECSClient,
cluster string,
region string,
statsEngine stats.Engine,
steadyStateRate int,
burstRate int,
availabilityZone string,
vpcID string,
containerInstanceArn string,
apiEndpoint string,
acceptInsecureCert bool) (*http.Server, error) {
taskProtectionClientFactory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface,
) (*http.Server, error) {

muxRouter := mux.NewRouter()

Expand All @@ -79,7 +78,7 @@ func taskServerSetup(credentialsManager credentials.Manager,

v4HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, availabilityZone, vpcID, containerInstanceArn)

agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, region, apiEndpoint, acceptInsecureCert)
agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, taskProtectionClientFactory)

return tmds.NewServer(auditLogger,
tmds.WithHandler(muxRouter),
Expand Down Expand Up @@ -152,10 +151,13 @@ func v4HandlersSetup(muxRouter *mux.Router,
}

// agentAPIV1HandlersSetup adds handlers for Agent API V1
func agentAPIV1HandlersSetup(muxRouter *mux.Router, state dockerstate.TaskEngineState, credentialsManager credentials.Manager, cluster string, region string, endpoint string, acceptInsecureCert bool) {
factory := agentAPITaskProtectionV1.TaskProtectionClientFactory{
Region: region, Endpoint: endpoint, AcceptInsecureCert: acceptInsecureCert,
}
func agentAPIV1HandlersSetup(
muxRouter *mux.Router,
state dockerstate.TaskEngineState,
credentialsManager credentials.Manager,
cluster string,
factory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface,
) {
muxRouter.
HandleFunc(
agentAPITaskProtectionV1.TaskProtectionPath(),
Expand Down Expand Up @@ -190,9 +192,12 @@ func ServeTaskHTTPEndpoint(

auditLogger := audit.NewAuditLog(containerInstanceArn, cfg, logger)

server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, cfg.AWSRegion, statsEngine,
cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, availabilityZone, vpcID, containerInstanceArn, cfg.APIEndpoint,
cfg.AcceptInsecureCert)
taskProtectionClientFactory := agentAPITaskProtectionV1.TaskProtectionClientFactory{
Region: cfg.AWSRegion, Endpoint: cfg.APIEndpoint, AcceptInsecureCert: cfg.AcceptInsecureCert,
}
server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster,
statsEngine, cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate,
availabilityZone, vpcID, containerInstanceArn, taskProtectionClientFactory)
if err != nil {
seelog.Criticalf("Failed to set up Task Metadata Server: %v", err)
return
Expand Down
Loading