diff --git a/agent/s3/factory/factory.go b/agent/s3/factory/factory.go index b5111439514..716c61ee215 100644 --- a/agent/s3/factory/factory.go +++ b/agent/s3/factory/factory.go @@ -34,7 +34,7 @@ const ( type S3ClientCreator interface { NewS3ManagerClient(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error) - NewS3Client(region string, creds credentials.IAMRoleCredentials) s3client.S3Client + NewS3Client(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3Client, error) } // NewS3ClientCreator provide 2 implementations @@ -65,15 +65,21 @@ func (*s3ClientCreator) NewS3ManagerClient(bucket, region string, } // NewS3Client returns a new S3 client to support s3 operations which are not provided by s3manager. -func (*s3ClientCreator) NewS3Client(region string, - creds credentials.IAMRoleCredentials) s3client.S3Client { +func (*s3ClientCreator) NewS3Client(bucket, region string, + creds credentials.IAMRoleCredentials) (s3client.S3Client, error) { cfg := aws.NewConfig(). WithHTTPClient(httpclient.New(roundtripTimeout, false)). WithCredentials( awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)).WithRegion(region) sess := session.Must(session.NewSession(cfg)) - return s3.New(sess) + svc := s3.New(sess) + bucketRegion, err := getRegionFromBucket(svc, bucket) + if err != nil { + return nil, err + } + sessWithRegion := session.Must(session.NewSession(cfg.WithRegion(bucketRegion))) + return s3.New(sessWithRegion), nil } func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) { input := &s3.GetBucketLocationInput{ diff --git a/agent/s3/factory/mocks/factory_mocks.go b/agent/s3/factory/mocks/factory_mocks.go index 0cd1e64aa35..29e7af44089 100644 --- a/agent/s3/factory/mocks/factory_mocks.go +++ b/agent/s3/factory/mocks/factory_mocks.go @@ -50,17 +50,18 @@ func (m *MockS3ClientCreator) EXPECT() *MockS3ClientCreatorMockRecorder { } // NewS3Client mocks base method. -func (m *MockS3ClientCreator) NewS3Client(arg0 string, arg1 credentials.IAMRoleCredentials) s3.S3Client { +func (m *MockS3ClientCreator) NewS3Client(arg0, arg1 string, arg2 credentials.IAMRoleCredentials) (s3.S3Client, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewS3Client", arg0, arg1) + ret := m.ctrl.Call(m, "NewS3Client", arg0, arg1, arg2) ret0, _ := ret[0].(s3.S3Client) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewS3Client indicates an expected call of NewS3Client. -func (mr *MockS3ClientCreatorMockRecorder) NewS3Client(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockS3ClientCreatorMockRecorder) NewS3Client(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewS3Client", reflect.TypeOf((*MockS3ClientCreator)(nil).NewS3Client), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewS3Client", reflect.TypeOf((*MockS3ClientCreator)(nil).NewS3Client), arg0, arg1, arg2) } // NewS3ManagerClient mocks base method. diff --git a/agent/taskresource/credentialspec/credentialspec_linux.go b/agent/taskresource/credentialspec/credentialspec_linux.go index fc293312a6c..f2c8a305391 100644 --- a/agent/taskresource/credentialspec/credentialspec_linux.go +++ b/agent/taskresource/credentialspec/credentialspec_linux.go @@ -435,7 +435,12 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS return } - s3Client := cs.s3ClientCreator.NewS3Client(cs.region, iamCredentials) + s3Client, err := cs.s3ClientCreator.NewS3Client(bucket, cs.region, iamCredentials) + if err != nil { + cs.setTerminalReason(err.Error()) + errorEvents <- err + return + } credSpecJsonStringUnformatted, err := s3.GetObject(bucket, key, s3Client) diff --git a/agent/taskresource/credentialspec/credentialspec_linux_test.go b/agent/taskresource/credentialspec/credentialspec_linux_test.go index dbe11e5d2b2..a7931a182c6 100644 --- a/agent/taskresource/credentialspec/credentialspec_linux_test.go +++ b/agent/taskresource/credentialspec/credentialspec_linux_test.go @@ -374,7 +374,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValue(t *testing.T) { Body: io.NopCloser(strings.NewReader(testData)), } gomock.InOrder( - s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client), + s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1), ) @@ -439,7 +439,7 @@ func TestHandleS3DomainlessCredentialSpecFileGetS3SecretValue(t *testing.T) { Body: io.NopCloser(strings.NewReader(testData)), } gomock.InOrder( - s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client), + s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1), ) @@ -501,7 +501,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValueErr(t *testing.T) { }, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning) gomock.InOrder( - s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client), + s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil), mockS3Client.EXPECT().GetObject(gomock.Any()).Return(nil, errors.New("test-error")).Times(1), )