From 978fdcee975ceb48811252c4df0144cda947b23c Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 13 Dec 2024 13:44:17 -0800 Subject: [PATCH] Migrate dynamodb engine to AWS SDK v2 This migrates the Database Service engine for DynamoDB to use AWS SDK v2. FIPS endpoint resolution has also been enabled. Finally, the engine will now resolve to the AWS-account-based endpoint for DynamoDB operations in supported regions. --- go.mod | 1 + go.sum | 2 + integrations/terraform/go.sum | 2 + lib/srv/db/dynamodb/engine.go | 182 +++++++++++++++++++---------- lib/srv/db/dynamodb/engine_test.go | 67 ++++++++--- lib/srv/db/dynamodb/test.go | 38 +++--- lib/srv/db/dynamodb_test.go | 22 ++-- 7 files changed, 204 insertions(+), 110 deletions(-) diff --git a/go.mod b/go.mod index fa63fe6eb8695..3778726f94195 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 + github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 github.com/aws/aws-sdk-go-v2/service/ec2 v1.195.0 diff --git a/go.sum b/go.sum index 6a73eda617b84..aadfbf4b65238 100644 --- a/go.sum +++ b/go.sum @@ -883,6 +883,8 @@ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 h1:8EwNbY+A/ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1/go.mod h1:2mMP2R86zLPAUz0TpJdsKW8XawHgs9Nk97fYJomO3o8= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 h1:D+iatX9gV6gCuNd6BnUkfwfZJw/cXlEk+LwwDdSMdtw= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0/go.mod h1:27ljwDsnZvfrZKsLzWD4WFjI4OZutEFIjvVtYfj9gHc= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 h1:hZg1sHhWXGZShzHGpwcaOT8HZfx26kkbRDNZgZda4xI= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7/go.mod h1:fYBjETTq8hZfirBEgXM1xIMy+tvCGYZTeWpjeKKp0bU= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 h1:isKhHsjpQR3CypQJ4G1g8QWx7zNpiC/xKw1zjgJYVno= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0/go.mod h1:xDvUyIkwBwNtVZJdHEwAuhFly3mezwdEWkbJ5oNYwIw= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 h1:ntqHwZb+ZyVz0CFYUG0sQ02KMMJh+iXeV3bXoba+s4A= diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 1b7cf7ecfecec..d16422a0120fe 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -812,6 +812,8 @@ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 h1:8EwNbY+A/ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1/go.mod h1:2mMP2R86zLPAUz0TpJdsKW8XawHgs9Nk97fYJomO3o8= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 h1:D+iatX9gV6gCuNd6BnUkfwfZJw/cXlEk+LwwDdSMdtw= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0/go.mod h1:27ljwDsnZvfrZKsLzWD4WFjI4OZutEFIjvVtYfj9gHc= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 h1:hZg1sHhWXGZShzHGpwcaOT8HZfx26kkbRDNZgZda4xI= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7/go.mod h1:fYBjETTq8hZfirBEgXM1xIMy+tvCGYZTeWpjeKKp0bU= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 h1:isKhHsjpQR3CypQJ4G1g8QWx7zNpiC/xKw1zjgJYVno= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0/go.mod h1:xDvUyIkwBwNtVZJdHEwAuhFly3mezwdEWkbJ5oNYwIw= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 h1:ntqHwZb+ZyVz0CFYUG0sQ02KMMJh+iXeV3bXoba+s4A= diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 940b0b315a724..d877741dc628b 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -30,10 +30,10 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/service/dax" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodbstreams" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dax" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/utils" @@ -54,6 +55,7 @@ func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, RoundTrippers: make(map[string]http.RoundTripper), + UseFIPS: modules.GetModules().IsBoringBinary(), } } @@ -71,6 +73,8 @@ type Engine struct { RoundTrippers map[string]http.RoundTripper // CredentialsGetter is used to obtain STS credentials. CredentialsGetter libaws.CredentialsGetter + // UseFIPS will ensure FIPS endpoint resolution. + UseFIPS bool } var _ common.Engine = (*Engine)(nil) @@ -194,7 +198,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws. // emit an audit event regardless of failure, but using the resolved endpoint. var responseStatusCode uint32 defer func() { - e.emitAuditEvent(req, re.URL, responseStatusCode, err) + e.emitAuditEvent(req, re.URL.String(), responseStatusCode, err) }() // try to read, close, and replace the incoming request body. @@ -319,8 +323,8 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er } // getRoundTripper makes an HTTP round tripper with TLS config based on the given URL. -func (e *Engine) getRoundTripper(ctx context.Context, URL string) (http.RoundTripper, error) { - if rt, ok := e.RoundTrippers[URL]; ok { +func (e *Engine) getRoundTripper(ctx context.Context, u *url.URL) (http.RoundTripper, error) { + if rt, ok := e.RoundTrippers[u.String()]; ok { return rt, nil } tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) @@ -329,55 +333,136 @@ func (e *Engine) getRoundTripper(ctx context.Context, URL string) (http.RoundTri } // We need to set the ServerName here because the AWS endpoint service prefix is not known in advance, // and the TLS config we got does not set it. - host, err := getURLHostname(URL) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConfig.ServerName = host + tlsConfig.ServerName = u.Hostname() out, err := defaults.Transport() if err != nil { return nil, trace.Wrap(err) } out.TLSClientConfig = tlsConfig - e.RoundTrippers[URL] = out + e.RoundTrippers[u.String()] = out return out, nil } -// resolveEndpoint returns a resolved endpoint for either the configured URI or the AWS target service and region. -func (e *Engine) resolveEndpoint(req *http.Request) (*endpoints.ResolvedEndpoint, error) { - endpointID, err := extractEndpointID(req) +type endpoint struct { + URL *url.URL + SigningName string + SigningRegion string +} + +// resolveEndpoint returns a resolved endpoint for either the configured URI or +// the AWS target service and region. +// For a target operation, the appropriate AWS service resolver is used. +// Targets look like one of DynamoDB_$version.$operation, +// DynamoDBStreams_$version.$operation, or AmazonDAX$version.$operation. +// For example: DynamoDBStreams_20120810.ListStreams +func (e *Engine) resolveEndpoint(req *http.Request) (*endpoint, error) { + target, err := getTargetHeader(req) if err != nil { return nil, trace.Wrap(err) } - opts := func(opts *endpoints.Options) { - opts.ResolveUnknownService = true + + awsMeta := e.sessionCtx.Database.GetAWS() + + var re *endpoint + switch target := strings.ToLower(target); { + case strings.HasPrefix(target, "dynamodbstreams"): + re, err = resolveDynamoDBStreamsEndpoint(req.Context(), awsMeta.Region, e.UseFIPS) + case strings.HasPrefix(target, "dynamodb"): + re, err = resolveDynamoDBEndpoint(req.Context(), awsMeta.Region, awsMeta.AccountID, e.UseFIPS) + case strings.HasPrefix(target, "amazondax"): + re, err = resolveDaxEndpoint(req.Context(), awsMeta.Region, e.UseFIPS) + default: + return nil, trace.BadParameter("DynamoDB API target %q is not recognized", target) } - re, err := endpoints.DefaultResolver().EndpointFor(endpointID, e.sessionCtx.Database.GetAWS().Region, opts) if err != nil { return nil, trace.Wrap(err) } uri := e.sessionCtx.Database.GetURI() - if uri != "" && uri != apiaws.DynamoDBURIForRegion(e.sessionCtx.Database.GetAWS().Region) { + if uri != "" && uri != apiaws.DynamoDBURIForRegion(awsMeta.Region) { + // Add a temporary schema to make a valid URL for url.Parse. + if !strings.Contains(uri, "://") { + uri = "schema://" + uri + } + u, err := url.Parse(uri) + if err != nil { + return nil, trace.Wrap(err) + } // override the resolved endpoint URL with the user-configured URI. - re.URL = uri + re.URL = u } - if !strings.Contains(re.URL, "://") { - re.URL = "https://" + re.URL + // Force HTTPS + re.URL.Scheme = "https" + return re, nil +} + +func resolveDynamoDBStreamsEndpoint(ctx context.Context, region string, useFIPS bool) (*endpoint, error) { + params := dynamodbstreams.EndpointParameters{ + Region: aws.String(region), + UseFIPS: aws.Bool(useFIPS), } - return &re, nil + ep, err := dynamodbstreams.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + if err != nil { + return nil, trace.Wrap(err) + } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + // DynamoDB Streams uses the same signing name as DynamoDB. + SigningName: "dynamodb", + }, nil } -// rewriteRequest clones a request, modifies the clone to rewrite its URL, and returns the modified request clone. -func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.ResolvedEndpoint, body []byte) (*http.Request, error) { - resolvedURL, err := url.Parse(re.URL) +func resolveDynamoDBEndpoint(ctx context.Context, region, accountID string, useFIPS bool) (*endpoint, error) { + params := dynamodb.EndpointParameters{ + Region: aws.String(region), + // Preferred means if we have an account ID available, then use an + // account ID based endpoint. + // We should always have the account ID available anyway. + // If we didn't then it would just resolve the regional endpoint like + // dynamodb..amazonaws.com. + // AWS documents that account-based routing provides better request + // performance for some services. + // See: https://docs.aws.amazon.com/sdkref/latest/guide/feature-account-endpoints.html + AccountIdEndpointMode: aws.String(aws.AccountIDEndpointModePreferred), + UseFIPS: aws.Bool(useFIPS), + } + if accountID != "" { + params.AccountId = aws.String(accountID) + } + ep, err := dynamodb.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) if err != nil { return nil, trace.Wrap(err) } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + SigningName: "dynamodb", + }, nil +} + +func resolveDaxEndpoint(ctx context.Context, region string, useFIPS bool) (*endpoint, error) { + params := dax.EndpointParameters{ + Region: aws.String(region), + UseFIPS: aws.Bool(useFIPS), + } + ep, err := dax.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + if err != nil { + return nil, trace.Wrap(err) + } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + SigningName: "dax", + }, nil +} + +// rewriteRequest clones a request, modifies the clone to rewrite its URL, and returns the modified request clone. +func rewriteRequest(ctx context.Context, r *http.Request, re *endpoint, body []byte) (*http.Request, error) { reqCopy := r.Clone(ctx) // set url and host header to match the resolved endpoint. - reqCopy.URL = resolvedURL - reqCopy.Host = resolvedURL.Host + reqCopy.URL = re.URL + reqCopy.Host = re.URL.Host if body == nil { // no body is fine, skip copying it. return reqCopy, nil @@ -388,42 +473,13 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved return reqCopy, nil } -// extractEndpointID extracts the AWS endpoint ID from the request header X-Amz-Target. -func extractEndpointID(req *http.Request) (string, error) { +// getTargetHeader gets the X-Amz-Target header or returns an error if it is not +// present, as we rely on this header for endpoint resolution. +// See X-Amz-Target: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html +func getTargetHeader(req *http.Request) (string, error) { target := req.Header.Get(libaws.AmzTargetHeader) if target == "" { return "", trace.BadParameter("missing %q header in http request", libaws.AmzTargetHeader) } - endpointID, err := endpointIDForTarget(target) - return endpointID, trace.Wrap(err) -} - -// endpointIDForTarget converts a target operation into the appropriate the AWS endpoint ID. -// Target looks like one of DynamoDB_$version.$operation, DynamoDBStreams_$version.$operation, AmazonDAX$version.$operation, -// for example: DynamoDBStreams_20120810.ListStreams -// See X-Amz-Target: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html -func endpointIDForTarget(target string) (string, error) { - t := strings.ToLower(target) - switch { - case strings.HasPrefix(t, "dynamodbstreams"): - return dynamodbstreams.EndpointsID, nil - case strings.HasPrefix(t, "dynamodb"): - return dynamodb.EndpointsID, nil - case strings.HasPrefix(t, "amazondax"): - return dax.EndpointsID, nil - default: - return "", trace.BadParameter("DynamoDB API target %q is not recognized", target) - } -} - -// getURLHostname parses a URL to extract its hostname. -func getURLHostname(uri string) (string, error) { - if !strings.Contains(uri, "://") { - uri = "schema://" + uri - } - parsed, err := url.Parse(uri) - if err != nil { - return "", trace.Wrap(err) - } - return parsed.Hostname(), nil + return target, nil } diff --git a/lib/srv/db/dynamodb/engine_test.go b/lib/srv/db/dynamodb/engine_test.go index faeaf0536ed1d..57b1fe170a409 100644 --- a/lib/srv/db/dynamodb/engine_test.go +++ b/lib/srv/db/dynamodb/engine_test.go @@ -38,7 +38,8 @@ func TestResolveEndpoint(t *testing.T) { desc string target string // from X-Amz-Target in requests region string - wantEndpointID string + useFIPS bool + unsetAccountID bool wantSigningName string wantURL string wantErrMsg string @@ -47,15 +48,21 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb target in us west", target: "DynamoDB_20120810.Scan", region: "us-west-1", - wantEndpointID: "dynamodb", + wantSigningName: "dynamodb", + wantURL: "https://123456789012.ddb.us-west-1.amazonaws.com", + }, + { + desc: "dynamodb target in us west with no account id", + target: "DynamoDB_20120810.Scan", + region: "us-west-1", wantSigningName: "dynamodb", wantURL: "https://dynamodb.us-west-1.amazonaws.com", + unsetAccountID: true, }, { desc: "dynamodb target in china", target: "DynamoDB_20120810.Scan", region: "cn-north-1", - wantEndpointID: "dynamodb", wantSigningName: "dynamodb", wantURL: "https://dynamodb.cn-north-1.amazonaws.com.cn", }, @@ -63,7 +70,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb streams target in us west", target: "DynamoDBStreams_20120810.ListStreams", region: "us-west-1", - wantEndpointID: "streams.dynamodb", wantSigningName: "dynamodb", wantURL: "https://streams.dynamodb.us-west-1.amazonaws.com", }, @@ -71,7 +77,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb streams target in china", target: "DynamoDBStreams_20120810.ListStreams", region: "cn-north-1", - wantEndpointID: "streams.dynamodb", wantSigningName: "dynamodb", wantURL: "https://streams.dynamodb.cn-north-1.amazonaws.com.cn", }, @@ -79,7 +84,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dax target in us west", target: "AmazonDAXV3.ListTags", region: "us-west-1", - wantEndpointID: "dax", wantSigningName: "dax", wantURL: "https://dax.us-west-1.amazonaws.com", }, @@ -87,10 +91,33 @@ func TestResolveEndpoint(t *testing.T) { desc: "dax target in china", target: "AmazonDAXV3.ListTags", region: "cn-north-1", - wantEndpointID: "dax", wantSigningName: "dax", wantURL: "https://dax.cn-north-1.amazonaws.com.cn", }, + { + desc: "dynamodb target in us west with FIPS required", + target: "DynamoDB_20120810.Scan", + region: "us-west-1", + wantSigningName: "dynamodb", + wantURL: "https://dynamodb-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, + { + desc: "dynamodb streams target in us west with FIPS required", + target: "DynamoDBStreams_20120810.ListStreams", + region: "us-west-1", + wantSigningName: "dynamodb", + wantURL: "https://streams.dynamodb-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, + { + desc: "dax target in us west with FIPS required", + target: "AmazonDAXV3.ListTags", + region: "us-west-1", + wantSigningName: "dax", + wantURL: "https://dax-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, { desc: "unrecognizable target", target: "DDB.Scan", @@ -105,25 +132,19 @@ func TestResolveEndpoint(t *testing.T) { req := &http.Request{Header: make(http.Header)} req.Header.Set(libaws.AmzTargetHeader, tt.target) - // check that the correct endpoint ID is extracted. - endpointID, err := extractEndpointID(req) - if tt.wantErrMsg != "" { - require.Error(t, err) - require.ErrorContains(t, err, tt.wantErrMsg) - return - } - require.Equal(t, tt.wantEndpointID, endpointID) - // check that the engine resolves the correct URL. db := &types.DatabaseV3{ Spec: types.DatabaseSpecV3{ URI: apiaws.DynamoDBURIForRegion(tt.region), AWS: types.AWS{ Region: tt.region, - AccountID: "12345", + AccountID: "123456789012", }, }, } + if tt.unsetAccountID { + db.Spec.AWS.AccountID = "" + } engine := &Engine{ EngineConfig: common.EngineConfig{ Log: slog.Default(), @@ -131,18 +152,26 @@ func TestResolveEndpoint(t *testing.T) { sessionCtx: &common.Session{ Database: db, }, + UseFIPS: tt.useFIPS, } re, err := engine.resolveEndpoint(req) + if tt.wantErrMsg != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrMsg) + return + } require.NoError(t, err) - require.Equal(t, tt.wantURL, re.URL) + require.Equal(t, tt.wantURL, re.URL.String()) require.Equal(t, tt.wantSigningName, re.SigningName) + require.Equal(t, tt.region, re.SigningRegion) // now use a custom URI and check that it overrides the resolved URL. db.Spec.URI = "foo.com" re, err = engine.resolveEndpoint(req) require.NoError(t, err) - require.Equal(t, "https://foo.com", re.URL) + require.Equal(t, "https://foo.com", re.URL.String()) require.Equal(t, tt.wantSigningName, re.SigningName) + require.Equal(t, tt.region, re.SigningRegion) }) } } diff --git a/lib/srv/db/dynamodb/test.go b/lib/srv/db/dynamodb/test.go index 462dc743782ab..cf7661dc044b4 100644 --- a/lib/srv/db/dynamodb/test.go +++ b/lib/srv/db/dynamodb/test.go @@ -27,10 +27,9 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -41,7 +40,10 @@ import ( ) // Client alias for easier use. -type Client = dynamodb.DynamoDB +type Client struct { + *dynamodb.Client + HTTPClient *http.Client +} // ClientOptionsParams is a struct for client configuration options. type ClientOptionsParams struct { @@ -54,19 +56,19 @@ type ClientOptions func(*ClientOptionsParams) // MakeTestClient returns DynamoDB client connection according to the provided // parameters. func MakeTestClient(_ context.Context, config common.TestClientConfig, opts ...ClientOptions) (*Client, error) { - provider := session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ - AccessKeyID: "fakeClientKeyID", - SecretAccessKey: "fakeClientSecret", - }}), - Region: aws.String("local"), - })) - dynamoClient := dynamodb.New(provider, &aws.Config{ - Endpoint: aws.String("http://" + config.Address), - MaxRetries: aws.Int(0), // disable automatic retries in tests - HTTPClient: &http.Client{Timeout: 5 * time.Second}, + httpClt := &http.Client{Timeout: 5 * time.Second} + dynamoClient := dynamodb.New(dynamodb.Options{ + Region: "local", + Credentials: credentials.NewStaticCredentialsProvider( + "fakeClientKeyID", + "fakeClientSecret", + "", + ), + BaseEndpoint: aws.String("http://" + config.Address), + RetryMaxAttempts: 0, // disable automatic retries in tests + HTTPClient: httpClt, }) - return dynamoClient, nil + return &Client{Client: dynamoClient, HTTPClient: httpClt}, nil } // TestServerOption allows setting test server options. @@ -107,7 +109,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (*T mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - err := awsutils.VerifyAWSSignature(r, credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION")) + err := awsutils.VerifyAWSSignatureV2(r, credentials.NewStaticCredentialsProvider("AKIDl", "SECRET", "SESSION")) if err != nil { code := trace.ErrorToCode(err) body, _ := json.Marshal(jsonErr{ diff --git a/lib/srv/db/dynamodb_test.go b/lib/srv/db/dynamodb_test.go index 0ed6355983381..f7a2b259e110b 100644 --- a/lib/srv/db/dynamodb_test.go +++ b/lib/srv/db/dynamodb_test.go @@ -25,9 +25,8 @@ import ( "net/http" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - awsdynamodb "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/credentials" + awsdynamodb "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -36,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/dynamodb" awsutils "github.com/gravitational/teleport/lib/utils/aws" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) func registerTestDynamoDBEngine() { @@ -50,7 +50,9 @@ func newTestDynamoDBEngine(ec common.EngineConfig) common.Engine { RoundTrippers: make(map[string]http.RoundTripper), // inject mock AWS credentials. CredentialsGetter: awsutils.NewStaticCredentialsGetter( - credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION"), + migration.NewCredentialsAdapter( + credentials.NewStaticCredentialsProvider("AKIDl", "SECRET", "SESSION"), + ), ), } } @@ -127,14 +129,14 @@ func TestAccessDynamoDB(t *testing.T) { require.NoError(t, err) // Execute a dynamodb query. - out, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + out, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) if test.wantErrMsg != "" { require.Error(t, err) require.ErrorContains(t, err, test.wantErrMsg) return } require.NoError(t, err) - require.ElementsMatch(t, mockTables, aws.StringValueSlice(out.TableNames)) + require.ElementsMatch(t, mockTables, out.TableNames) }) } } @@ -159,7 +161,7 @@ func TestAuditDynamoDB(t *testing.T) { require.NoError(t, err) // Execute a dynamodb query. - _, err = clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err = clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.Error(t, err) require.ErrorContains(t, err, "access to db denied") requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode) @@ -176,21 +178,21 @@ func TestAuditDynamoDB(t *testing.T) { require.NoError(t, err) t.Run("session starts and emits a request event", func(t *testing.T) { - _, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.NoError(t, err) requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) requireEvent(t, testCtx, libevents.DynamoDBRequestCode) }) t.Run("session ends when client closes the connection", func(t *testing.T) { - clt.Config.HTTPClient.CloseIdleConnections() + clt.HTTPClient.CloseIdleConnections() requireEvent(t, testCtx, libevents.DatabaseSessionEndCode) }) t.Run("session ends when local proxy closes the connection", func(t *testing.T) { // closing local proxy and canceling the context used to start it should trigger session end event. // without this cancel, the session will not end until the smaller of client_idle_timeout or the testCtx closes. - _, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.NoError(t, err) requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) requireEvent(t, testCtx, libevents.DynamoDBRequestCode)