diff --git a/collect/cache/cuckoo.go b/collect/cache/cuckoo.go index 4cf228127d..e6c4e4774d 100644 --- a/collect/cache/cuckoo.go +++ b/collect/cache/cuckoo.go @@ -56,47 +56,52 @@ func NewCuckooTraceChecker(capacity uint, m metrics.Metrics) *CuckooTraceChecker // To try to avoid blocking on Add, we have a goroutine that pulls from a // channel and adds to the filter. go func() { - for { - n := len(c.addch) - if n == 0 { - // if the channel is empty, wait for a bit - time.Sleep(AddQueueSleepTime) - continue + ticker := time.NewTicker(AddQueueSleepTime) + for range ticker.C { + // as long as there's anything still in the channel, keep trying to drain it + for len(c.addch) > 0 { + c.drain() } - c.drain() } }() return c } -// This function records all the traces that were in the channel at the -// start of the call. The idea is to add them all under a single lock. We -// tested limiting it so as to not hold the lock for too long, but it didn't -// seem to matter and it made the code more complicated. -// We track a histogram metric about lock time, though, so we can watch it. +// This function records all the traces that were in the channel at the start of +// the call. The idea is to add as many as possible under a single lock. We do +// limit our lock hold time to 1ms, so if we can't add them all in that time, we +// stop and let the next call pick up the rest. We track a histogram metric +// about lock time. func (c *CuckooTraceChecker) drain() { n := len(c.addch) if n == 0 { return } - lockStart := time.Now() c.mut.Lock() + // we don't start the timer until we have the lock, because we don't want to be counting + // the time we're waiting for the lock. + lockStart := time.Now() + timeout := time.NewTimer(1 * time.Millisecond) outer: for i := 0; i < n; i++ { select { case t := <-c.addch: - c.current.Insert([]byte(t)) + s := []byte(t) + c.current.Insert(s) // don't add anything to future if it doesn't exist yet if c.future != nil { - c.future.Insert([]byte(t)) + c.future.Insert(s) } + case <-timeout.C: + break outer default: // if the channel is empty, stop break outer } } c.mut.Unlock() + timeout.Stop() qlt := time.Since(lockStart) c.met.Histogram(AddQueueLockTime, qlt.Microseconds()) } diff --git a/collect/cache/cuckooSentCache.go b/collect/cache/cuckooSentCache.go index 20c0519e05..7c3b9bfa09 100644 --- a/collect/cache/cuckooSentCache.go +++ b/collect/cache/cuckooSentCache.go @@ -6,6 +6,7 @@ import ( lru "github.com/hashicorp/golang-lru/v2" "github.com/honeycombio/refinery/config" + "github.com/honeycombio/refinery/generics" "github.com/honeycombio/refinery/metrics" "github.com/honeycombio/refinery/types" ) @@ -139,9 +140,10 @@ func (t *cuckooDroppedRecord) Reason() uint { var _ TraceSentRecord = (*cuckooDroppedRecord)(nil) type cuckooSentCache struct { - kept *lru.Cache[string, *keptTraceCacheEntry] - dropped *CuckooTraceChecker - cfg config.SampleCacheConfig + kept *lru.Cache[string, *keptTraceCacheEntry] + dropped *CuckooTraceChecker + recentDroppedIDs *generics.SetWithTTL[string] + cfg config.SampleCacheConfig // The done channel is used to decide when to terminate the monitor // goroutine. When resizing the cache, we write to the channel, but @@ -164,13 +166,26 @@ func NewCuckooSentCache(cfg config.SampleCacheConfig, met metrics.Metrics) (Trac return nil, err } dropped := NewCuckooTraceChecker(cfg.DroppedSize, met) + // we want to keep track of the most recent dropped traces so we can avoid + // checking them in the dropped filter, which can have contention issues + // under high load. So we use a cache with TTL to keep track of the most + // recent dropped trace IDs, which lets us avoid checking the dropped filter + // for them for a short period of time. This means that when a whole batch + // of spans from the same trace arrives late, we don't have to check the + // dropped filter for each one. Benchmarks indicate that the Set cache is + // maybe 2-4x faster than the cuckoo filter and it also avoids lock + // contention issues in the cuckoo filter, so in practical use saves more + // than that. The TTL in this cache is short, because it's refreshed on each + // request. + recentDroppedIDs := generics.NewSetWithTTL[string](3 * time.Second) cache := &cuckooSentCache{ - kept: stc, - dropped: dropped, - cfg: cfg, - sentReasons: NewSentReasonsCache(met), - done: make(chan struct{}), + kept: stc, + dropped: dropped, + recentDroppedIDs: recentDroppedIDs, + cfg: cfg, + sentReasons: NewSentReasonsCache(met), + done: make(chan struct{}), } go cache.monitor() return cache, nil @@ -206,13 +221,21 @@ func (c *cuckooSentCache) Record(trace KeptTrace, keep bool, reason string) { return } - // if we're not keeping it, save it in the dropped trace filter + // if we're not keeping it, save it in the recentDroppedIDs cache + c.recentDroppedIDs.Add(trace.ID()) + // and also save it in the dropped trace filter c.dropped.Add(trace.ID()) } func (c *cuckooSentCache) CheckSpan(span *types.Span) (TraceSentRecord, string, bool) { - // was it dropped? + // was it recently dropped? + if c.recentDroppedIDs.Contains(span.TraceID) { + c.recentDroppedIDs.Add(span.TraceID) // refresh the TTL on this key + return &cuckooDroppedRecord{}, "", true + } + // was it in the drop cache? if c.dropped.Check(span.TraceID) { + c.recentDroppedIDs.Add(span.TraceID) // we recognize it as dropped, so just say so; there's nothing else to do return &cuckooDroppedRecord{}, "", true } diff --git a/collect/cache/cuckoo_test.go b/collect/cache/cuckoo_test.go index da3ecc20b5..881a715448 100644 --- a/collect/cache/cuckoo_test.go +++ b/collect/cache/cuckoo_test.go @@ -11,14 +11,15 @@ import ( ) // genID returns a random hex string of length numChars -func genID(numChars int) string { - seed := 3565269841805 +var seed = 3565269841805 +var rng = wyhash.Rng(seed) - const charset = "abcdef0123456789" +const charset = "abcdef0123456789" +func genID(numChars int) string { id := make([]byte, numChars) for i := 0; i < numChars; i++ { - id[i] = charset[int(wyhash.Rng(seed))%len(charset)] + id[i] = charset[int(rng.Next()%uint64(len(charset)))] } return string(id) } diff --git a/generics/setttl.go b/generics/setttl.go index 333e8bf63c..be4aa4d2e4 100644 --- a/generics/setttl.go +++ b/generics/setttl.go @@ -58,7 +58,7 @@ func (s *SetWithTTL[T]) Contains(e T) bool { if !ok { return false } - return item.After(time.Now()) + return item.After(s.Clock.Now()) } func (s *SetWithTTL[T]) cleanup() int { diff --git a/generics/setttl_test.go b/generics/setttl_test.go index 1fca1316d6..d86f805b40 100644 --- a/generics/setttl_test.go +++ b/generics/setttl_test.go @@ -4,10 +4,25 @@ import ( "testing" "time" + "github.com/dgryski/go-wyhash" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" ) +var seed = 3565269841805 +var rng = wyhash.Rng(seed) + +const charset = "abcdef0123456789" + +func genID(numChars int) string { + + id := make([]byte, numChars) + for i := 0; i < numChars; i++ { + id[i] = charset[int(rng.Next()%uint64(len(charset)))] + } + return string(id) +} + func TestSetTTLBasics(t *testing.T) { s := NewSetWithTTL(100*time.Millisecond, "a", "b", "b") fakeclock := clockwork.NewFakeClock() @@ -24,3 +39,52 @@ func TestSetTTLBasics(t *testing.T) { assert.Equal(t, 0, s.Length()) assert.Equal(t, s.Members(), []string{}) } + +func BenchmarkSetWithTTLContains(b *testing.B) { + s := NewSetWithTTL[string](10 * time.Second) + fc := clockwork.NewFakeClock() + s.Clock = fc + + n := 10000 + traceIDs := make([]string, n) + for i := 0; i < n; i++ { + traceIDs[i] = genID(32) + if i%2 == 0 { + s.Add(traceIDs[i]) + } + fc.Advance(1 * time.Microsecond) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Contains(traceIDs[i%n]) + } +} + +func BenchmarkSetWithTTLExpire(b *testing.B) { + s := NewSetWithTTL[string](1 * time.Second) + fc := clockwork.NewFakeClock() + s.Clock = fc + + // 1K ids created at 1ms intervals + // we'll check them over the course of 1 second as well, so they should all expire by the end + n := 1000 + traceIDs := make([]string, n) + for i := 0; i < n; i++ { + traceIDs[i] = genID(32) + s.Add(traceIDs[i]) + fc.Advance(1 * time.Millisecond) + } + // make sure we have 1000 ids now + assert.Equal(b, n, s.Length()) + b.ResetTimer() + advanceTime := 100 * time.Second / time.Duration(b.N) + for i := 0; i < b.N; i++ { + s.Contains(traceIDs[i%n]) + if i%100 == 0 { + fc.Advance(advanceTime) + } + } + b.StopTimer() + // make sure all ids have expired by now (there might be 1 or 2 that haven't) + assert.GreaterOrEqual(b, 2, s.Length()) +}