Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nhulston committed Oct 8, 2024
1 parent 384b46f commit d6e6c66
Show file tree
Hide file tree
Showing 4 changed files with 729 additions and 3 deletions.
140 changes: 137 additions & 3 deletions contrib/aws/aws-sdk-go-v2/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package aws
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -24,12 +25,13 @@ import (
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/eventbridge"
eventBridgeTypes "github.com/aws/aws-sdk-go-v2/service/eventbridge/types"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sfn"
"github.com/aws/aws-sdk-go-v2/service/sns"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/smithy-go/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -281,6 +283,66 @@ func TestAppendMiddlewareSqsReceiveMessage(t *testing.T) {
}
}

func TestAppendMiddlewareSqsSendMessage(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

expectedStatusCode := 200
server := mockAWS(expectedStatusCode)
defer server.Close()

resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: server.URL,
SigningRegion: "eu-west-1",
}, nil
})

awsCfg := aws.Config{
Region: "eu-west-1",
Credentials: aws.AnonymousCredentials{},
EndpointResolver: resolver,
}

AppendMiddleware(&awsCfg)

sqsClient := sqs.NewFromConfig(awsCfg)
sendMessageInput := &sqs.SendMessageInput{
MessageBody: aws.String("test message"),
QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"),
}
_, err := sqsClient.SendMessage(context.Background(), sendMessageInput)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)

s := spans[0]
assert.Equal(t, "SQS.request", s.OperationName())
assert.Equal(t, "SendMessage", s.Tag("aws.operation"))
assert.Equal(t, "SQS", s.Tag("aws.service"))
assert.Equal(t, "MyQueueName", s.Tag("queuename"))
assert.Equal(t, "SQS.SendMessage", s.Tag(ext.ResourceName))
assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName))

// Check for trace context injection
assert.NotNil(t, sendMessageInput.MessageAttributes)
assert.Contains(t, sendMessageInput.MessageAttributes, "_datadog")
ddAttr := sendMessageInput.MessageAttributes["_datadog"]
assert.Equal(t, "String", *ddAttr.DataType)
assert.NotEmpty(t, *ddAttr.StringValue)

// Decode and verify the injected trace context
var traceContext map[string]string
err = json.Unmarshal([]byte(*ddAttr.StringValue), &traceContext)
assert.NoError(t, err)
assert.Contains(t, traceContext, "x-datadog-trace-id")
assert.Contains(t, traceContext, "x-datadog-parent-id")
assert.NotEmpty(t, traceContext["x-datadog-trace-id"])
assert.NotEmpty(t, traceContext["x-datadog-parent-id"])
}

func TestAppendMiddlewareS3ListObjects(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -441,6 +503,22 @@ func TestAppendMiddlewareSnsPublish(t *testing.T) {
assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL))
assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component))
assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind))

// Check for trace context injection
assert.NotNil(t, tt.publishInput.MessageAttributes)
assert.Contains(t, tt.publishInput.MessageAttributes, "_datadog")
ddAttr := tt.publishInput.MessageAttributes["_datadog"]
assert.Equal(t, "String", *ddAttr.DataType)
assert.NotEmpty(t, *ddAttr.StringValue)

// Decode and verify the injected trace context
var traceContext map[string]string
err := json.Unmarshal([]byte(*ddAttr.StringValue), &traceContext)
assert.NoError(t, err)
assert.Contains(t, traceContext, "x-datadog-trace-id")
assert.Contains(t, traceContext, "x-datadog-parent-id")
assert.NotEmpty(t, traceContext["x-datadog-trace-id"])
assert.NotEmpty(t, traceContext["x-datadog-parent-id"])
})
}
}
Expand Down Expand Up @@ -657,6 +735,62 @@ func TestAppendMiddlewareEventBridgePutRule(t *testing.T) {
}
}

func TestAppendMiddlewareEventBridgePutEvents(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

expectedStatusCode := 200
server := mockAWS(expectedStatusCode)
defer server.Close()

resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: server.URL,
SigningRegion: "eu-west-1",
}, nil
})

awsCfg := aws.Config{
Region: "eu-west-1",
Credentials: aws.AnonymousCredentials{},
EndpointResolver: resolver,
}

AppendMiddleware(&awsCfg)

eventbridgeClient := eventbridge.NewFromConfig(awsCfg)
putEventsInput := &eventbridge.PutEventsInput{
Entries: []eventBridgeTypes.PutEventsRequestEntry{
{
EventBusName: aws.String("my-event-bus"),
Detail: aws.String(`{"key": "value"}`),
},
},
}
eventbridgeClient.PutEvents(context.Background(), putEventsInput)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)

s := spans[0]
assert.Equal(t, "PutEvents", s.Tag("aws.operation"))
assert.Equal(t, "EventBridge.PutEvents", s.Tag(ext.ResourceName))

// Check for trace context injection
assert.Len(t, putEventsInput.Entries, 1)
entry := putEventsInput.Entries[0]
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
assert.NoError(t, err)
assert.Contains(t, detail, "_datadog")
ddData, ok := detail["_datadog"].(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, ddData, "x-datadog-start-time")
assert.Contains(t, ddData, "x-datadog-resource-name")
assert.Equal(t, "my-event-bus", ddData["x-datadog-resource-name"])
}

func TestAppendMiddlewareSfnDescribeStateMachine(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -971,8 +1105,8 @@ func TestMessagingNamingSchema(t *testing.T) {
_, err = sqsClient.SendMessage(ctx, msg)
require.NoError(t, err)

entry := types.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")}
batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []types.SendMessageBatchRequestEntry{entry}}
entry := sqsTypes.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")}
batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []sqsTypes.SendMessageBatchRequestEntry{entry}}
_, err = sqsClient.SendMessageBatch(ctx, batchMsg)
require.NoError(t, err)

Expand Down
207 changes: 207 additions & 0 deletions contrib/aws/internal/eventbridge/eventbridge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package eventbridge

import (
"context"
"encoding/json"
"strconv"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/eventbridge"
"github.com/aws/aws-sdk-go-v2/service/eventbridge/types"
"github.com/aws/smithy-go/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)

type testCarrier struct {
m map[string]string
}

func (c *testCarrier) Set(key, val string) {
c.m[key] = val
}

func (c *testCarrier) ForeachKey(handler func(key, val string) error) error {
for k, v := range c.m {
if err := handler(k, v); err != nil {
return err
}
}
return nil
}

func TestEnrichOperation(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")

input := middleware.InitializeInput{
Parameters: &eventbridge.PutEventsInput{
Entries: []types.PutEventsRequestEntry{
{
Detail: aws.String(`{"key": "value"}`),
EventBusName: aws.String("test-bus"),
},
{
Detail: aws.String(`{"another": "data"}`),
EventBusName: aws.String("test-bus-2"),
},
},
},
}

EnrichOperation(ctx, input, "PutEvents")

params, ok := input.Parameters.(*eventbridge.PutEventsInput)
require.True(t, ok)
require.Len(t, params.Entries, 2)

for _, entry := range params.Entries {
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
require.NoError(t, err)

assert.Contains(t, detail, datadogKey)
ddData, ok := detail[datadogKey].(map[string]interface{})
require.True(t, ok)

assert.Contains(t, ddData, startTimeKey)
assert.Contains(t, ddData, resourceNameKey)
assert.Equal(t, *entry.EventBusName, ddData[resourceNameKey])
}
}

func TestInjectTraceContext(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
span, ctx := tracer.StartSpanFromContext(ctx, "test-span")

tests := []struct {
name string
entry types.PutEventsRequestEntry
expected func(*testing.T, *types.PutEventsRequestEntry)
}{
{
name: "Inject into empty detail",
entry: types.PutEventsRequestEntry{
EventBusName: aws.String("test-bus"),
},
expected: func(t *testing.T, entry *types.PutEventsRequestEntry) {
assert.NotNil(t, entry.Detail)
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
require.NoError(t, err)
assert.Contains(t, detail, datadogKey)
},
},
{
name: "Inject into existing detail",
entry: types.PutEventsRequestEntry{
Detail: aws.String(`{"existing": "data"}`),
EventBusName: aws.String("test-bus"),
},
expected: func(t *testing.T, entry *types.PutEventsRequestEntry) {
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
require.NoError(t, err)
assert.Contains(t, detail, "existing")
assert.Equal(t, "data", detail["existing"])
assert.Contains(t, detail, datadogKey)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
injectTraceContext(ctx, &tt.entry)
tt.expected(t, &tt.entry)

var detail map[string]interface{}
err := json.Unmarshal([]byte(*tt.entry.Detail), &detail)
require.NoError(t, err)

ddData := detail[datadogKey].(map[string]interface{})
assert.Contains(t, ddData, startTimeKey)
assert.Contains(t, ddData, resourceNameKey)
assert.Equal(t, *tt.entry.EventBusName, ddData[resourceNameKey])

// Check that start time exists and is not empty
startTimeStr, ok := ddData[startTimeKey].(string)
assert.True(t, ok)
startTime, err := strconv.ParseInt(startTimeStr, 10, 64)
assert.NoError(t, err)
assert.Greater(t, startTime, int64(0))

var carrier testCarrier
carrier.m = make(map[string]string)
for k, v := range ddData {
if s, ok := v.(string); ok {
carrier.m[k] = s
}
}

extractedSpanContext, err := tracer.Extract(&carrier)
assert.NoError(t, err)
assert.Equal(t, span.Context().TraceID(), extractedSpanContext.TraceID())
assert.Equal(t, span.Context().SpanID(), extractedSpanContext.SpanID())
})
}
}

func TestInjectTraceContextSizeLimit(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")

tests := []struct {
name string
entry types.PutEventsRequestEntry
expected func(*testing.T, *types.PutEventsRequestEntry)
}{
{
name: "Do not inject when payload is too large",
entry: types.PutEventsRequestEntry{
Detail: aws.String(`{"large": "` + strings.Repeat("a", maxSizeBytes-15) + `"}`),
EventBusName: aws.String("test-bus"),
},
expected: func(t *testing.T, entry *types.PutEventsRequestEntry) {
assert.GreaterOrEqual(t, len(*entry.Detail), maxSizeBytes-15)
assert.NotContains(t, *entry.Detail, datadogKey)
assert.True(t, strings.HasPrefix(*entry.Detail, `{"large": "`))
assert.True(t, strings.HasSuffix(*entry.Detail, `"}`))
},
},
{
name: "Inject when payload is just under the limit",
entry: types.PutEventsRequestEntry{
Detail: aws.String(`{"large": "` + strings.Repeat("a", maxSizeBytes-1000) + `"}`),
EventBusName: aws.String("test-bus"),
},
expected: func(t *testing.T, entry *types.PutEventsRequestEntry) {
assert.Less(t, len(*entry.Detail), maxSizeBytes)
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
require.NoError(t, err)
assert.Contains(t, detail, datadogKey)
assert.Contains(t, detail, "large")
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
injectTraceContext(ctx, &tt.entry)
tt.expected(t, &tt.entry)
})
}
}
Loading

0 comments on commit d6e6c66

Please sign in to comment.