Skip to content

Commit

Permalink
feat: Use AWS version 2 to get the AWS credentials for AWS MSK
Browse files Browse the repository at this point in the history
  • Loading branch information
Neurostep committed Apr 22, 2024
1 parent 18eabbe commit 648a4c0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 32 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/config v1.27.11
github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6
github.com/getsentry/sentry-go v0.12.0
github.com/go-kit/kit v0.9.0
github.com/google/uuid v1.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
Expand Down
117 changes: 85 additions & 32 deletions pkg/pubsub/kafka/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/twmb/franz-go/pkg/kgo"

"github.com/aws/aws-sdk-go-v2/aws"
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"

awssasl "github.com/twmb/franz-go/pkg/sasl/aws"
"github.com/twmb/franz-go/pkg/sasl/plain"

Expand All @@ -24,10 +28,14 @@ type Config struct {
// Kafka configuration provided by go-sdk
KafkaConfig pubsub.Kafka
// AWS session reference, it will be used in case AWS MSK IAM authentication mechanism is used
//
// Deprecated: Use AwsConfig instead
AwsSession *session.Session
// MsgHandler is a function that will be called when a message is received
MsgHandler MsgHandler
Logger sdklogger.Logger
// AWS configuration reference, it will be used in case AWS MSK IAM authentication mechanism is used
AwsConfig *aws.Config
Logger sdklogger.Logger
}

const tlsConnectionTimeout = 10 * time.Second
Expand All @@ -43,7 +51,7 @@ func newConfig(c Config, opts ...kgo.Opt) ([]kgo.Opt, error) {
case pubsub.Plain:
options = append(options, getPlainSaslOption(c.KafkaConfig.SASL))
case pubsub.AWSMskIam:
options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession))
options = append(options, getAwsMskIamSaslOption(c.KafkaConfig.SASL.AWSMskIam, c.AwsSession, c.AwsConfig))
}
}

Expand Down Expand Up @@ -101,7 +109,7 @@ func getPlainSaslOption(saslConf pubsub.SASL) kgo.Opt {
}.AsMechanism())
}

func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kgo.Opt {
func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session, awsCfg *aws.Config) kgo.Opt {
var opt kgo.Opt

// no AWS session provided
Expand All @@ -115,40 +123,85 @@ func getAwsMskIamSaslOption(iamConf pubsub.SASLAwsMskIam, s *session.Session) kg
} else {
opt = kgo.SASL(
awssasl.ManagedStreamingIAM(func(ctx context.Context) (awssasl.Auth, error) {
// If assumable role is not provided, we try to get credentials from the provided AWS session
if iamConf.AssumableRole == "" {
val, err := s.Config.Credentials.Get()
if err != nil {
return awssasl.Auth{}, err
}

return awssasl.Auth{
AccessKey: val.AccessKeyID,
SecretKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
if s != nil {
return getAwsSaslAuthFromSession(iamConf, s)
}

svc := sts.New(s)

res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{
RoleArn: &iamConf.AssumableRole,
RoleSessionName: &iamConf.SessionName,
})
if stsErr != nil {
return awssasl.Auth{}, stsErr
}

return awssasl.Auth{
AccessKey: *res.Credentials.AccessKeyId,
SecretKey: *res.Credentials.SecretAccessKey,
SessionToken: *res.Credentials.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
return getAwsSaslAuthFromConfig(ctx, iamConf, awsCfg)
}),
)
}

return opt
}

func getAwsSaslAuthFromSession(iamConf pubsub.SASLAwsMskIam, s *session.Session) (awssasl.Auth, error) {
// If assumable role is not provided, we try to get credentials from the provided AWS session
if iamConf.AssumableRole == "" {
val, err := s.Config.Credentials.Get()
if err != nil {
return awssasl.Auth{}, err
}

return awssasl.Auth{
AccessKey: val.AccessKeyID,
SecretKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
}

svc := sts.New(s)

res, stsErr := svc.AssumeRole(&sts.AssumeRoleInput{
RoleArn: &iamConf.AssumableRole,
RoleSessionName: &iamConf.SessionName,
})
if stsErr != nil {
return awssasl.Auth{}, stsErr
}

return awssasl.Auth{
AccessKey: *res.Credentials.AccessKeyId,
SecretKey: *res.Credentials.SecretAccessKey,
SessionToken: *res.Credentials.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
}

func getAwsSaslAuthFromConfig(
ctx context.Context,
iamConf pubsub.SASLAwsMskIam,
awsCfg *aws.Config) (awssasl.Auth, error) {
// If assumable role is not provided, we try to get credentials from the provided AWS config
if iamConf.AssumableRole == "" {
val, err := awsCfg.Credentials.Retrieve(ctx)
if err != nil {
return awssasl.Auth{}, err
}

return awssasl.Auth{
AccessKey: val.AccessKeyID,
SecretKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
}

client := stsv2.NewFromConfig(*awsCfg)

res, stsErr := client.AssumeRole(ctx, &stsv2.AssumeRoleInput{
RoleArn: &iamConf.AssumableRole,
RoleSessionName: &iamConf.SessionName,
})
if stsErr != nil {
return awssasl.Auth{}, stsErr
}

return awssasl.Auth{
AccessKey: *res.Credentials.AccessKeyId,
SecretKey: *res.Credentials.SecretAccessKey,
SessionToken: *res.Credentials.SessionToken,
UserAgent: iamConf.UserAgent,
}, nil
}

0 comments on commit 648a4c0

Please sign in to comment.