Skip to content

Commit

Permalink
add additional tracing info to message attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartnelson3 committed Apr 15, 2021
1 parent 24984a7 commit 0de1b64
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 31 deletions.
2 changes: 1 addition & 1 deletion module/apmawssdkgo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func build(req *request.Request) {
if req.ClientInfo.ServiceName != serviceSQS {
return
}
addMessageAttributes(req, span)
addMessageAttributes(req, span, tx.ShouldPropagateLegacyHeader())
}

func send(req *request.Request) {
Expand Down
26 changes: 22 additions & 4 deletions module/apmawssdkgo/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *apmSQS) setAdditional(span *apm.Span) {

// addMessageAttributes adds message attributes to `SendMessage` and
// `SendMessageBatch` RPC calls. Other SQS RPC calls are ignored.
func addMessageAttributes(req *request.Request, span *apm.Span) {
func addMessageAttributes(req *request.Request, span *apm.Span, propagateLegacyHeader bool) {
switch req.Operation.Name {
case "SendMessage", "SendMessageBatch":
break
Expand All @@ -92,20 +92,38 @@ func addMessageAttributes(req *request.Request, span *apm.Span) {
DataType: aws.String("String"),
StringValue: aws.String(apmhttp.FormatTraceparentHeader(traceContext)),
}

tracestate := traceContext.State.String()
if req.Operation.Name == "SendMessage" {
input, ok := req.Params.(*sqs.SendMessageInput)
if !ok {
return
}
input.MessageAttributes["traceContext"] = msgAttr
setTracingAttributes(input.MessageAttributes, msgAttr, tracestate, propagateLegacyHeader)
} else if req.Operation.Name == "SendMessageBatch" {
input, ok := req.Params.(*sqs.SendMessageBatchInput)
if !ok {
return
}
for _, entry := range input.Entries {
entry.MessageAttributes["traceContext"] = msgAttr
setTracingAttributes(entry.MessageAttributes, msgAttr, tracestate, propagateLegacyHeader)
}
}
}

func setTracingAttributes(
attrs map[string]*sqs.MessageAttributeValue,
value *sqs.MessageAttributeValue,
tracestate string,
propagateLegacyHeader bool,
) {
attrs[apmhttp.W3CTraceparentHeader] = value
if propagateLegacyHeader {
attrs[apmhttp.ElasticTraceparentHeader] = value
}
if tracestate != "" {
attrs[apmhttp.TracestateHeader] = &sqs.MessageAttributeValue{
DataType: aws.String("String"),
StringValue: aws.String(tracestate),
}
}
}
Expand Down
58 changes: 32 additions & 26 deletions module/apmawssdkgo/sqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,19 @@
package apmawssdkgo // import "go.elastic.co/apm/module/apmawssdkgo"

import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"

"go.elastic.co/apm/apmtest"
"go.elastic.co/apm/module/apmhttp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -131,13 +130,11 @@ func TestSQS(t *testing.T) {
},
},
} {
buf := new(bytes.Buffer)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if tc.hasError {
w.WriteHeader(http.StatusInternalServerError)
return
}
io.Copy(buf, r.Body)
}))
defer ts.Close()

Expand All @@ -150,6 +147,13 @@ func TestSQS(t *testing.T) {

session := session.Must(session.NewSession(cfg))
wrapped := WrapSession(session)
if tc.hasTraceContext {
wrapped.Handlers.Build.PushBackNamed(request.NamedHandler{
Name: "spy_message_attrs_added",
Fn: testTracingAttributes(t),
})
}

svc := sqs.New(wrapped)

tx, spans, errors := apmtest.WithTransaction(func(ctx context.Context) {
Expand All @@ -162,27 +166,6 @@ func TestSQS(t *testing.T) {
return
}

if tc.hasTraceContext {
kvs := make(map[string]string)
var traceContextPresent bool
for _, kvPair := range strings.Split(buf.String(), "&") {
kv := strings.Split(kvPair, "=")
kvs[kv[0]] = kv[1]
}
if v, ok := kvs["MessageAttribute.2.Name"]; ok {
traceContextPresent = true
assert.Equal(t, "traceContext", v)
assert.NotEmpty(t, kvs["MessageAttribute.2.Value.StringValue"])
}
if v, ok := kvs["SendMessageBatchRequestEntry.1.MessageAttribute.2.Name"]; ok {
traceContextPresent = true
assert.Equal(t, "traceContext", v)
assert.NotEmpty(t, kvs["SendMessageBatchRequestEntry.1.MessageAttribute.2.Value.StringValue"])
}
require.True(t, traceContextPresent)
}
buf.Reset()

require.Len(t, spans, 1)
span := spans[0]

Expand Down Expand Up @@ -213,3 +196,26 @@ func TestSQS(t *testing.T) {
assert.Equal(t, tx.ID, span.ParentID)
}
}

func testTracingAttributes(t *testing.T) func(*request.Request) {
return func(req *request.Request) {
testAttrs := func(t *testing.T, attrs map[string]*sqs.MessageAttributeValue) {
assert.Contains(t, attrs, apmhttp.W3CTraceparentHeader)
assert.Contains(t, attrs, apmhttp.ElasticTraceparentHeader)
assert.Contains(t, attrs, apmhttp.TracestateHeader)
}
if req.Operation.Name == "SendMessage" {
input, ok := req.Params.(*sqs.SendMessageInput)
require.True(t, ok)
testAttrs(t, input.MessageAttributes)
} else if req.Operation.Name == "SendMessageBatch" {
input, ok := req.Params.(*sqs.SendMessageBatchInput)
require.True(t, ok)
for _, entry := range input.Entries {
testAttrs(t, entry.MessageAttributes)
}
} else {
t.Fail()
}
}
}

0 comments on commit 0de1b64

Please sign in to comment.