diff --git a/collect/collect.go b/collect/collect.go index 17c6412654..608096ee61 100644 --- a/collect/collect.go +++ b/collect/collect.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "math" - "math/rand/v2" "os" "runtime" "sort" @@ -293,7 +291,7 @@ func (i *InMemCollector) checkAlloc(ctx context.Context) { tracesSent := generics.NewSet[string]() // Send the traces we can't keep. for _, trace := range allTraces { - if !i.IsMyTrace(trace.ID()) { + if _, ok := i.IsMyTrace(trace.ID()); !ok { i.Logger.Debug().WithFields(map[string]interface{}{ "trace_id": trace.ID(), }).Logf("cannot eject trace that does not belong to this peer") @@ -573,7 +571,7 @@ func (i *InMemCollector) sendExpiredTracesInCache(ctx context.Context, now time. traceTimeout := i.Config.GetTracesConfig().GetTraceTimeout() var orphanTraceCount int traces := i.cache.TakeExpiredTraces(now, int(i.Config.GetTracesConfig().MaxExpiredTraces), func(t *types.Trace) bool { - if i.IsMyTrace(t.ID()) { + if _, ok := i.IsMyTrace(t.ID()); ok { return true } @@ -640,6 +638,8 @@ func (i *InMemCollector) sendExpiredTracesInCache(ctx context.Context, now time. TraceID: trace.ID(), Event: types.Event{ Context: trace.GetSpans()[0].Context, + APIKey: trace.APIKey, + Dataset: trace.Dataset, }, }, trace, i.Sharder.WhichShard(trace.ID()))) } @@ -656,6 +656,18 @@ func (i *InMemCollector) processSpan(ctx context.Context, sp *types.Span) { span.End() }() + targetShard, isMyTrace := i.IsMyTrace(sp.TraceID) + // if the span is a decision span and the trace no longer belong to us, we should not forward it to the peer + if !isMyTrace && sp.IsDecisionSpan() { + return + } + + // if trace locality is enabled, we should forward all spans to its correct peer + if i.Config.GetCollectionConfig().EnableTraceLocality && !isMyTrace { + i.PeerTransmission.EnqueueSpan(sp) + return + } + tcfg := i.Config.GetTracesConfig() trace := i.cache.Get(sp.TraceID) @@ -712,22 +724,17 @@ func (i *InMemCollector) processSpan(ctx context.Context, sp *types.Span) { trace.AddSpan(sp) span.SetAttributes(attribute.String("disposition", "live_trace")) - // Figure out if we should handle this span locally or pass on to a peer var spanForwarded bool - if !i.Config.GetCollectionConfig().EnableTraceLocality { - // if this trace doesn't belong to us, we should forward a decision span to its decider - targetShard := i.Sharder.WhichShard(trace.ID()) - if !targetShard.Equals(i.Sharder.MyShard()) && !sp.IsDecisionSpan() { - i.Metrics.Increment("incoming_router_peer") - i.Logger.Debug(). - WithString("peer", targetShard.GetAddress()). - Logf("Sending span to peer") - - dc := i.createDecisionSpan(sp, trace, targetShard) - - i.PeerTransmission.EnqueueEvent(dc) - spanForwarded = true - } + // if this trace doesn't belong to us and it's not in sent state, we should forward a decision span to its decider + if !trace.Sent && !isMyTrace { + i.Metrics.Increment("incoming_router_peer") + i.Logger.Debug(). + Logf("Sending span to peer") + + dc := i.createDecisionSpan(sp, trace, targetShard) + + i.PeerTransmission.EnqueueEvent(dc) + spanForwarded = true } // we may override these values in conditions below @@ -1329,107 +1336,6 @@ func (i *InMemCollector) sendTraces() { } } -type redistributeNotifier struct { - clock clockwork.Clock - logger logger.Logger - initialDelay time.Duration - maxAttempts int - maxDelay time.Duration - metrics metrics.Metrics - - reset chan struct{} - done chan struct{} - triggered chan struct{} - once sync.Once -} - -func newRedistributeNotifier(logger logger.Logger, met metrics.Metrics, clock clockwork.Clock) *redistributeNotifier { - r := &redistributeNotifier{ - initialDelay: 3 * time.Second, - maxDelay: 30 * time.Second, - maxAttempts: 5, - done: make(chan struct{}), - clock: clock, - logger: logger, - metrics: met, - triggered: make(chan struct{}), - reset: make(chan struct{}), - } - - return r -} - -func (r *redistributeNotifier) Notify() <-chan struct{} { - return r.triggered -} - -func (r *redistributeNotifier) Reset() { - var started bool - r.once.Do(func() { - go r.run() - started = true - }) - - if started { - return - } - - select { - case r.reset <- struct{}{}: - case <-r.done: - return - default: - r.logger.Debug().Logf("A trace redistribution is ongoing. Ignoring reset.") - } -} - -func (r *redistributeNotifier) Stop() { - close(r.done) -} - -func (r *redistributeNotifier) run() { - var attempts int - lastBackoff := r.initialDelay - for { - // if we've reached the max attempts, reset the backoff and attempts - // only when the reset signal is received. - if attempts >= r.maxAttempts { - r.metrics.Gauge("trace_redistribution_count", 0) - <-r.reset - lastBackoff = r.initialDelay - attempts = 0 - } - select { - case <-r.done: - return - case r.triggered <- struct{}{}: - } - - attempts++ - r.metrics.Gauge("trace_redistribution_count", attempts) - - // Calculate the backoff interval using exponential backoff with a base time. - backoff := time.Duration(math.Min(float64(lastBackoff)*2, float64(r.maxDelay))) - // Add jitter to the backoff to avoid retry collisions. - jitter := time.Duration(rand.Float64() * float64(backoff) * 0.5) - nextBackoff := backoff + jitter - lastBackoff = nextBackoff - - timer := r.clock.NewTimer(nextBackoff) - select { - case <-timer.Chan(): - timer.Stop() - case <-r.reset: - lastBackoff = r.initialDelay - attempts = 0 - timer.Stop() - case <-r.done: - timer.Stop() - return - } - } -} - func (i *InMemCollector) signalKeptTraceDecisions(ctx context.Context, msg string) { if len(msg) == 0 { return @@ -1588,13 +1494,15 @@ func (i *InMemCollector) makeDecision(trace *types.Trace, sendReason string) (*T return &td, nil } -func (i *InMemCollector) IsMyTrace(traceID string) bool { +func (i *InMemCollector) IsMyTrace(traceID string) (sharder.Shard, bool) { // if trace locality is enabled, we should always process the trace if i.Config.GetCollectionConfig().EnableTraceLocality { - return true + return i.Sharder.MyShard(), true } - return i.Sharder.WhichShard(traceID).Equals(i.Sharder.MyShard()) + targeShard := i.Sharder.WhichShard(traceID) + + return targeShard, i.Sharder.MyShard().Equals(targeShard) } func (i *InMemCollector) publishTraceDecision(ctx context.Context, td TraceDecision) { diff --git a/collect/collect_test.go b/collect/collect_test.go index 972db401ba..dd0cf509a3 100644 --- a/collect/collect_test.go +++ b/collect/collect_test.go @@ -57,6 +57,8 @@ func newTestCollector(conf config.Config, transmission transmit.Transmission, pe Metrics: s, } localPubSub.Start() + redistributeNotifier := newRedistributeNotifier(&logger.NullLogger{}, &metrics.NullMetrics{}, clock) + redistributeNotifier.initialDelay = 2 * time.Millisecond c := &InMemCollector{ Config: conf, @@ -90,7 +92,7 @@ func newTestCollector(conf config.Config, transmission transmit.Transmission, pe TraceIDs: peerTraceIDs, }, }, - redistributeTimer: newRedistributeNotifier(&logger.NullLogger{}, &metrics.NullMetrics{}, clock), + redistributeTimer: redistributeNotifier, } if !conf.GetCollectionConfig().EnableTraceLocality { @@ -1748,9 +1750,20 @@ func TestRedistributeTraces(t *testing.T) { } coll.Sharder = s + coll.incoming = make(chan *types.Span, 5) + coll.fromPeer = make(chan *types.Span, 5) + coll.outgoingTraces = make(chan sendableTrace, 5) + coll.datasetSamplers = make(map[string]sample.Sampler) + + c := cache.NewInMemCache(3, &metrics.NullMetrics{}, &logger.NullLogger{}) + coll.cache = c + stc, err := newCache() + assert.NoError(t, err, "lru cache should start") + coll.sampleTraceCache = stc + + go coll.collect() + go coll.sendTraces() - err := coll.Start() - assert.NoError(t, err) defer coll.Stop() dataset := "aoeu" @@ -1803,7 +1816,7 @@ func TestRedistributeTraces(t *testing.T) { coll.mutex.Lock() coll.cache.Set(trace) coll.mutex.Unlock() - coll.Peers.RegisterUpdatedPeersCallback(coll.redistributeTimer.Reset) + coll.redistributeTimer.Reset() peerEvents := peerTransmission.GetBlock(1) assert.Len(t, peerEvents, 1) diff --git a/collect/trace_redistributer.go b/collect/trace_redistributer.go new file mode 100644 index 0000000000..ad9ccfa73f --- /dev/null +++ b/collect/trace_redistributer.go @@ -0,0 +1,111 @@ +package collect + +import ( + "math/rand/v2" + "sync" + "time" + + "github.com/honeycombio/refinery/logger" + "github.com/honeycombio/refinery/metrics" + "github.com/jonboulle/clockwork" +) + +type redistributeNotifier struct { + clock clockwork.Clock + logger logger.Logger + initialDelay time.Duration + maxDelay float64 + metrics metrics.Metrics + + reset chan struct{} + done chan struct{} + triggered chan struct{} + once sync.Once +} + +func newRedistributeNotifier(logger logger.Logger, met metrics.Metrics, clock clockwork.Clock) *redistributeNotifier { + r := &redistributeNotifier{ + initialDelay: 3 * time.Second, + maxDelay: float64(30 * time.Second), + done: make(chan struct{}), + clock: clock, + logger: logger, + metrics: met, + triggered: make(chan struct{}), + reset: make(chan struct{}), + } + + return r +} + +func (r *redistributeNotifier) Notify() <-chan struct{} { + return r.triggered +} + +func (r *redistributeNotifier) Reset() { + var started bool + r.once.Do(func() { + go r.run() + started = true + }) + + if started { + return + } + + select { + case r.reset <- struct{}{}: + case <-r.done: + return + default: + r.logger.Debug().Logf("A trace redistribution is ongoing. Ignoring reset.") + } +} + +func (r *redistributeNotifier) Stop() { + close(r.done) +} + +// run runs the redistribution notifier loop. +// It will notify the trigger channel when it's time to redistribute traces, which we want +// to happen when the number of peers changes. But we don't want to do it immediately, +// because peer membership changes often happen in bunches, so we wait a while +// before triggering the redistribution. +func (r *redistributeNotifier) run() { + currentDelay := r.calculateDelay(r.initialDelay) + + // start a back off timer with the initial delay + timer := r.clock.NewTimer(currentDelay) + for { + select { + case <-r.done: + timer.Stop() + return + case <-r.reset: + // reset the delay timer when we receive a reset signal. + currentDelay = r.calculateDelay(r.initialDelay) + if !timer.Stop() { + // drain the timer channel + select { + case <-timer.Chan(): + default: + } + } + timer.Reset(currentDelay) + case <-timer.Chan(): + select { + case <-r.done: + return + case r.triggered <- struct{}{}: + } + } + } +} + +// calculateBackoff calculates the backoff interval for the next redistribution cycle. +// It uses exponential backoff with a base time and adds jitter to avoid retry collisions. +func (r *redistributeNotifier) calculateDelay(currentDelay time.Duration) time.Duration { + // Add jitter to the backoff to avoid retry collisions. + jitter := time.Duration(rand.Float64() * float64(currentDelay) * 0.5) + return currentDelay + jitter +} diff --git a/collect/trace_redistributer_test.go b/collect/trace_redistributer_test.go new file mode 100644 index 0000000000..03253a4c4f --- /dev/null +++ b/collect/trace_redistributer_test.go @@ -0,0 +1,55 @@ +package collect + +import ( + "testing" + "time" + + "github.com/honeycombio/refinery/metrics" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" +) + +// TestRedistributeNotifier tests the timer logic in redistributeNotifier +func TestRedistributeNotifier(t *testing.T) { + // Set up the notifier with a mock clock + + clock := clockwork.NewFakeClock() + r := &redistributeNotifier{ + clock: clock, + initialDelay: 50 * time.Millisecond, // Set the initial delay + metrics: &metrics.NullMetrics{}, + reset: make(chan struct{}), + done: make(chan struct{}), + triggered: make(chan struct{}, 4), // Buffered to allow easier testing + } + + defer r.Stop() + + go r.run() + + assert.Len(t, r.triggered, 0) + // Test that the notifier is not triggered before the initial delay + clock.BlockUntil(1) + clock.Advance(20 * time.Millisecond) + assert.Len(t, r.triggered, 0) + + // Test that the notifier is triggered after the initial delay + currentBackOff := r.initialDelay + clock.BlockUntil(1) + currentBackOff = r.calculateDelay(currentBackOff) + clock.Advance(currentBackOff + 100*time.Millisecond) // Advance the clock by the backoff time plus a little extra + + // Check that the notifier has been triggered + assert.Eventually(t, func() bool { + return len(r.triggered) == 1 + }, 200*time.Millisecond, 10*time.Millisecond, "Expected to be triggered %d times", 1) + + // Once we receive another reset signal, the timer should start again + r.Reset() + clock.BlockUntil(1) + clock.Advance(500 * time.Millisecond) + assert.Eventually(t, func() bool { + return len(r.triggered) == 2 + }, 200*time.Millisecond, 10*time.Millisecond, "Expected to be triggered 4 times") + +}