Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for gMSA s3 test #3886

Merged
merged 6 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions agent/s3/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (

type S3ClientCreator interface {
NewS3ManagerClient(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error)
NewS3Client(region string, creds credentials.IAMRoleCredentials) s3client.S3Client
NewS3Client(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3Client, error)
}

// NewS3ClientCreator provide 2 implementations
Expand Down Expand Up @@ -65,15 +65,21 @@ func (*s3ClientCreator) NewS3ManagerClient(bucket, region string,
}

// NewS3Client returns a new S3 client to support s3 operations which are not provided by s3manager.
func (*s3ClientCreator) NewS3Client(region string,
creds credentials.IAMRoleCredentials) s3client.S3Client {
func (*s3ClientCreator) NewS3Client(bucket, region string,
creds credentials.IAMRoleCredentials) (s3client.S3Client, error) {
cfg := aws.NewConfig().
WithHTTPClient(httpclient.New(roundtripTimeout, false)).
WithCredentials(
awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey,
creds.SessionToken)).WithRegion(region)
sess := session.Must(session.NewSession(cfg))
return s3.New(sess)
svc := s3.New(sess)
bucketRegion, err := getRegionFromBucket(svc, bucket)
if err != nil {
return nil, err
}
sessWithRegion := session.Must(session.NewSession(cfg.WithRegion(bucketRegion)))
return s3.New(sessWithRegion), nil
}
func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) {
input := &s3.GetBucketLocationInput{
Expand Down
11 changes: 6 additions & 5 deletions agent/s3/factory/mocks/factory_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS
return
}

s3Client := cs.s3ClientCreator.NewS3Client(cs.region, iamCredentials)
s3Client, err := cs.s3ClientCreator.NewS3Client(bucket, cs.region, iamCredentials)
as14692 marked this conversation as resolved.
Show resolved Hide resolved

credSpecJsonStringUnformatted, err := s3.GetObject(bucket, key, s3Client)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValue(t *testing.T) {
Body: io.NopCloser(strings.NewReader(testData)),
}
gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
)

Expand Down Expand Up @@ -439,7 +439,7 @@ func TestHandleS3DomainlessCredentialSpecFileGetS3SecretValue(t *testing.T) {
Body: io.NopCloser(strings.NewReader(testData)),
}
gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
)

Expand Down Expand Up @@ -501,7 +501,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValueErr(t *testing.T) {
}, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning)

gomock.InOrder(
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(nil, errors.New("test-error")).Times(1),
)

Expand Down