Skip to content

Commit

Permalink
Make SDK span concurrent safe (#1300)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrAlias authored Nov 19, 2024
1 parent eb87959 commit c8ccb8b
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 20 deletions.
67 changes: 51 additions & 16 deletions sdk/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"

"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -54,14 +56,17 @@ var _ trace.Tracer = tracer{}

func (t tracer) Start(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
var psc trace.SpanContext
span := &span{sampled: true}
sampled := true
span := new(span)

// Ask eBPF for sampling decision and span context info.
t.start(ctx, span, &psc, &span.sampled, &span.spanContext)
t.start(ctx, span, &psc, &sampled, &span.spanContext)

span.sampled.Store(sampled)

ctx = trace.ContextWithSpan(ctx, span)

if span.sampled {
if sampled {
// Only build traces if sampled.
cfg := trace.NewSpanStartConfig(opts...)
span.traces, span.span = t.traces(ctx, name, cfg, span.spanContext, psc)
Expand Down Expand Up @@ -142,9 +147,10 @@ func spanKind(kind trace.SpanKind) telemetry.SpanKind {
type span struct {
noop.Span

sampled bool
spanContext trace.SpanContext
sampled atomic.Bool

mu sync.Mutex
traces *telemetry.Traces
span *telemetry.Span
}
Expand All @@ -153,21 +159,26 @@ func (s *span) SpanContext() trace.SpanContext {
if s == nil {
return trace.SpanContext{}
}
// s.spanContext is immutable, do not acquire lock s.mu.
return s.spanContext
}

func (s *span) IsRecording() bool {
if s == nil {
return false
}
return s.sampled

return s.sampled.Load()
}

func (s *span) SetStatus(c codes.Code, msg string) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Load() {
return
}

s.mu.Lock()
defer s.mu.Unlock()

if s.span.Status == nil {
s.span.Status = new(telemetry.Status)
}
Expand All @@ -185,10 +196,13 @@ func (s *span) SetStatus(c codes.Code, msg string) {
}

func (s *span) SetAttributes(attrs ...attribute.KeyValue) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Load() {
return
}

s.mu.Lock()
defer s.mu.Unlock()

// TODO: handle attribute limits.

m := make(map[string]int)
Expand Down Expand Up @@ -273,10 +287,18 @@ func convAttrValue(value attribute.Value) telemetry.Value {
}

func (s *span) End(opts ...trace.SpanEndOption) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Swap(false) {
return
}

// s.end exists so the lock (s.mu) is not held while s.ended is called.
s.ended(s.end(opts))
}

func (s *span) end(opts []trace.SpanEndOption) []byte {
s.mu.Lock()
defer s.mu.Unlock()

cfg := trace.NewSpanEndConfig(opts...)
if t := cfg.Timestamp(); !t.IsZero() {
s.span.EndTime = cfg.Timestamp()
Expand All @@ -285,10 +307,7 @@ func (s *span) End(opts ...trace.SpanEndOption) {
}

b, _ := json.Marshal(s.traces) // TODO: do not ignore this error.

s.sampled = false

s.ended(b)
return b
}

// Expected to be implemented in eBPF.
Expand All @@ -300,7 +319,7 @@ func (*span) ended(buf []byte) { ended(buf) }
var ended = func([]byte) {}

func (s *span) RecordError(err error, opts ...trace.EventOption) {
if s == nil || err == nil || !s.sampled {
if s == nil || err == nil || !s.sampled.Load() {
return
}

Expand All @@ -317,6 +336,9 @@ func (s *span) RecordError(err error, opts ...trace.EventOption) {
attrs = append(attrs, semconv.ExceptionStacktrace(string(buf[0:n])))
}

s.mu.Lock()
defer s.mu.Unlock()

s.addEvent(semconv.ExceptionEventName, cfg.Timestamp(), attrs)
}

Expand All @@ -330,14 +352,20 @@ func typeStr(i any) string {
}

func (s *span) AddEvent(name string, opts ...trace.EventOption) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Load() {
return
}

cfg := trace.NewEventConfig(opts...)

s.mu.Lock()
defer s.mu.Unlock()

s.addEvent(name, cfg.Timestamp(), cfg.Attributes())
}

// addEvent adds an event with name and attrs at tStamp to the span. The span
// lock (s.mu) needs to be held by the caller.
func (s *span) addEvent(name string, tStamp time.Time, attrs []attribute.KeyValue) {
// TODO: handle event limits.

Expand All @@ -349,10 +377,13 @@ func (s *span) addEvent(name string, tStamp time.Time, attrs []attribute.KeyValu
}

func (s *span) AddLink(link trace.Link) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Load() {
return
}

s.mu.Lock()
defer s.mu.Unlock()

// TODO: handle link limits.

s.span.Links = append(s.span.Links, convLink(link))
Expand All @@ -377,9 +408,13 @@ func convLink(link trace.Link) *telemetry.SpanLink {
}

func (s *span) SetName(name string) {
if s == nil || !s.sampled {
if s == nil || !s.sampled.Load() {
return
}

s.mu.Lock()
defer s.mu.Unlock()

s.span.Name = name
}

Expand Down
100 changes: 96 additions & 4 deletions sdk/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"encoding/json"
"errors"
"math"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -144,7 +146,7 @@ func TestSpanCreation(t *testing.T) {
Eval: func(t *testing.T, _ context.Context, s *span) {
assertTracer(s.traces)

assert.True(t, s.sampled, "not sampled by default.")
assert.True(t, s.sampled.Load(), "not sampled by default.")
},
},
{
Expand Down Expand Up @@ -195,7 +197,7 @@ func TestSpanCreation(t *testing.T) {
}
},
Eval: func(t *testing.T, _ context.Context, s *span) {
assert.False(t, s.sampled, "sampled")
assert.False(t, s.sampled.Load(), "sampled")
},
},
{
Expand Down Expand Up @@ -319,7 +321,7 @@ func TestSpanEnd(t *testing.T) {
s := spanBuilder{}.Build()
s.End(test.Options...)

assert.False(t, s.sampled, "ended span should not be sampled")
assert.False(t, s.sampled.Load(), "ended span should not be sampled")
require.NotNil(t, buf, "no span data emitted")

var traces telemetry.Traces
Expand Down Expand Up @@ -489,7 +491,8 @@ type spanBuilder struct {

func (b spanBuilder) Build() *span {
tracer := new(tracer)
s := &span{sampled: !b.NotSampled, spanContext: b.SpanContext}
s := &span{spanContext: b.SpanContext}
s.sampled.Store(!b.NotSampled)
s.traces, s.span = tracer.traces(
context.Background(),
b.Name,
Expand All @@ -500,3 +503,92 @@ func (b spanBuilder) Build() *span {

return s
}

// TestSpanConcurrentSafe calls every span method concurrently from many
// goroutines, across several spans, tracers, and tracer providers. It does not
// inspect the recorded telemetry; it only checks that the concurrent calls
// complete. It is most useful when run with the race detector enabled
// (go test -race), which reports unsynchronized access to span state.
func TestSpanConcurrentSafe(t *testing.T) {
	t.Parallel()

	const (
		nTracers   = 2
		nSpans     = 2
		nGoroutine = 10
	)

	// runSpan invokes all methods of s from nGoroutine goroutines at once.
	// The returned channel is closed after every goroutine has returned.
	runSpan := func(s trace.Span) <-chan struct{} {
		done := make(chan struct{})
		// The parameter is named sp (not span) so it does not shadow the
		// package's span type, and it is the value the workers operate on.
		go func(sp trace.Span) {
			defer close(done)

			var wg sync.WaitGroup
			for i := 0; i < nGoroutine; i++ {
				wg.Add(1)
				go func(n int) {
					defer wg.Done()

					// Read-only accessors.
					_ = sp.IsRecording()
					_ = sp.SpanContext()
					_ = sp.TracerProvider()

					// Mutating methods.
					sp.AddEvent("event")
					sp.AddLink(trace.Link{})
					sp.RecordError(errors.New("err"))
					sp.SetStatus(codes.Error, "error")
					sp.SetName("span" + strconv.Itoa(n))
					sp.SetAttributes(attribute.Bool("key", true))

					// Every goroutine ends the span, so End races both with
					// itself and with the mutating methods above.
					sp.End()
				}(i)
			}

			wg.Wait()
		}(s)
		return done
	}

	// runTracer starts nSpans spans from tr concurrently and runs runSpan on
	// each of them. The returned channel is closed once all spans are done.
	runTracer := func(tr trace.Tracer) <-chan struct{} {
		done := make(chan struct{})
		go func(tracer trace.Tracer) {
			defer close(done)

			ctx := context.Background()

			var wg sync.WaitGroup
			for i := 0; i < nSpans; i++ {
				wg.Add(1)
				go func(n int) {
					defer wg.Done()
					_, s := tracer.Start(ctx, "span"+strconv.Itoa(n))
					<-runSpan(s)
				}(i)
			}

			wg.Wait()
		}(tr)
		return done
	}

	// run creates nTracers tracers from tp concurrently and runs runTracer on
	// each of them. The returned channel is closed once all tracers are done.
	run := func(tp trace.TracerProvider) <-chan struct{} {
		done := make(chan struct{})
		go func(provider trace.TracerProvider) {
			defer close(done)

			var wg sync.WaitGroup
			for i := 0; i < nTracers; i++ {
				wg.Add(1)
				go func(n int) {
					defer wg.Done()
					<-runTracer(provider.Tracer("tracer" + strconv.Itoa(n)))
				}(i)
			}

			wg.Wait()
		}(tp)
		return done
	}

	// NOTE(review): recover only observes panics raised on the calling
	// goroutine, so NotPanics presumably cannot catch a panic from one of the
	// worker goroutines above; such a panic would abort the test binary
	// instead (which still fails the run). Confirm this is the intent.
	assert.NotPanics(t, func() {
		done0, done1 := run(TracerProvider()), run(TracerProvider())

		<-done0
		<-done1
	})
}

0 comments on commit c8ccb8b

Please sign in to comment.