From b596f44a0492a4312f3bb43d49066354550f2c28 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Sun, 16 Jun 2024 17:24:43 -0400 Subject: [PATCH 01/19] Add SetWithTTL --- generics/setttl.go | 75 +++++++++++++++++++++++++++++++++++++++++ generics/setttl_test.go | 25 ++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 generics/setttl.go create mode 100644 generics/setttl_test.go diff --git a/generics/setttl.go b/generics/setttl.go new file mode 100644 index 0000000000..5718112c09 --- /dev/null +++ b/generics/setttl.go @@ -0,0 +1,75 @@ +package generics + +import ( + "cmp" + "sort" + "time" + + "golang.org/x/exp/maps" +) + +// SetWithTTL is a unique set of items with a TTL (time to live) for each item. +// After the TTL expires, the item is automatically removed from the set. +type SetWithTTL[T cmp.Ordered] struct { + Items map[T]time.Time + TTL time.Duration +} + +// NewSetWithTTL returns a new SetWithTTL with elements `es` and a TTL of `ttl`. +func NewSetWithTTL[T cmp.Ordered](ttl time.Duration, es ...T) SetWithTTL[T] { + s := SetWithTTL[T]{ + Items: make(map[T]time.Time, len(es)), + TTL: ttl, + } + s.Add(es...) + return s +} + +// Add adds elements `es` to the SetWithTTL. +func (s SetWithTTL[T]) Add(es ...T) { + for _, e := range es { + s.Items[e] = time.Now().Add(s.TTL) + } +} + +// Remove removes elements `es` from the SetWithTTL. +func (s SetWithTTL[T]) Remove(es ...T) { + for _, e := range es { + delete(s.Items, e) + } +} + +// Contains returns true if the SetWithTTL contains `e`. +func (s SetWithTTL[T]) Contains(e T) bool { + item, ok := s.Items[e] + if !ok { + return false + } + return item.After(time.Now()) +} + +func (s SetWithTTL[T]) cleanup() { + maps.DeleteFunc(s.Items, func(k T, exp time.Time) bool { + return exp.Before(time.Now()) + }) +} + +// Members returns the unique elements of the SetWithTTL in sorted order. +// It also removes any items that have expired. +func (s SetWithTTL[T]) Members() []T { + s.cleanup() + members := make([]T, 0, len(s.Items)) + for member := range s.Items { + members = append(members, member) + } + sort.Slice(members, func(i, j int) bool { + return cmp.Less(members[i], members[j]) + }) + return members +} + +// Length returns the number of items in the SetWithTTL after removing any expired items. +func (s SetWithTTL[T]) Length() int { + s.cleanup() + return len(s.Items) +} diff --git a/generics/setttl_test.go b/generics/setttl_test.go new file mode 100644 index 0000000000..20177f2a3a --- /dev/null +++ b/generics/setttl_test.go @@ -0,0 +1,25 @@ +package generics + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSetTTLBasics(t *testing.T) { + s := NewSetWithTTL(100*time.Millisecond, "a", "b", "b") + assert.Equal(t, 2, s.Length()) + time.Sleep(50 * time.Millisecond) + s.Add("c") + assert.Equal(t, 3, s.Length()) + assert.Equal(t, s.Members(), []string{"a", "b", "c"}) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + assert.Equal(collect, 1, s.Length()) + assert.Equal(collect, s.Members(), []string{"c"}) + }, 100*time.Millisecond, 20*time.Millisecond) + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + assert.Equal(collect, 0, s.Length()) + assert.Equal(collect, s.Members(), []string{}) + }, 100*time.Millisecond, 20*time.Millisecond) +} From 99d130d9cdd8cec43c6df931e1af173f606c05b7 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Sun, 16 Jun 2024 17:27:44 -0400 Subject: [PATCH 02/19] Import fanout from R3 --- generics/fanout.go | 237 ++++++++++++++++++++++++++++++++++++++++ generics/fanout_test.go | 196 +++++++++++++++++++++++++++++++++ 2 files changed, 433 insertions(+) create mode 100644 generics/fanout.go create mode 100644 generics/fanout_test.go diff --git a/generics/fanout.go b/generics/fanout.go new file mode 100644 index 0000000000..d0c552e8fc --- /dev/null +++ b/generics/fanout.go @@ -0,0 +1,237 @@ +package generics + +import "sync" + +// Fanout takes a slice of input, a parallelism factor, and a worker factory. It +// calls the generated worker on every element of the input, and returns a +// (possibly filtered) slice of the outputs in no particular order. Only the +// outputs that pass the predicate (if it is not nil) will be added to the +// slice. +// +// The factory takes an integer (the worker number) and constructs a function of +// type func(T) U that processes a single input and produces a single output. It +// also constructs a cleanup function, which may be nil. The cleanup function is +// called once for each worker, after the worker has completed processing all of +// its inputs. It is given the same index as the corresponding worker factory. +// +// If predicate is not nil, it will only add the output to the result slice if +// the predicate returns true. It will fan out the input to the worker function +// in parallel, and fan in the results to the output slice. +func Fanout[T, U any](input []T, parallelism int, workerFactory func(int) (worker func(T) U, cleanup func(int)), predicate func(U) bool) []U { + result := make([]U, 0) + + fanoutChan := make(chan T, parallelism) + faninChan := make(chan U, parallelism) + + // send all the trace IDs to the fanout channel + wgFans := sync.WaitGroup{} + wgFans.Add(1) + go func() { + defer wgFans.Done() + defer close(fanoutChan) + for i := range input { + fanoutChan <- input[i] + } + }() + + wgFans.Add(1) + go func() { + defer wgFans.Done() + for r := range faninChan { + result = append(result, r) + } + }() + + wgWorkers := sync.WaitGroup{} + for i := 0; i < parallelism; i++ { + wgWorkers.Add(1) + worker, cleanup := workerFactory(i) + go func(i int) { + defer wgWorkers.Done() + if cleanup != nil { + defer cleanup(i) + } + for u := range fanoutChan { + product := worker(u) + if predicate == nil || predicate(product) { + faninChan <- product + } + } + }(i) + } + + // wait for the workers to finish + wgWorkers.Wait() + // now we can close the fanin channel and wait for the fanin goroutine to finish + // fanout should already be done but this makes sure we don't lose track of it + close(faninChan) + wgFans.Wait() + + return result +} + +// EasyFanout is a convenience function for when you don't need all the +// features. It takes a slice of input, a parallelism factor, and a worker +// function. It calls the worker on every element of the input with the +// specified parallelism, and returns a slice of the outputs in no particular +// order. +func EasyFanout[T, U any](input []T, parallelism int, worker func(T) U) []U { + return Fanout(input, parallelism, func(int) (func(T) U, func(int)) { + return worker, nil + }, nil) +} + +// FanoutToMap takes a slice of input, a parallelism factor, and a worker +// factory. It calls the generated worker on every element of the input, and +// returns a (possibly filtered) map of the inputs to the outputs. Only the +// outputs that pass the predicate (if it is not nil) will be added to the map. +// +// The factory takes an integer (the worker number) and constructs a function of +// type func(T) U that processes a single input and produces a single output. It +// also constructs a cleanup function, which may be nil. The cleanup function is +// called once for each worker, after the worker has completed processing all of +// its inputs. It is given the same index as the corresponding worker factory. +// +// If predicate is not nil, it will only add the output to the result slice if +// the predicate returns true. It will fan out the input to the worker function +// in parallel, and fan in the results to the output slice. +func FanoutToMap[T comparable, U any](input []T, parallelism int, workerFactory func(int) (worker func(T) U, cleanup func(int)), predicate func(U) bool) map[T]U { + result := make(map[T]U) + type resultPair struct { + key T + val U + } + + fanoutChan := make(chan T, parallelism) + faninChan := make(chan resultPair, parallelism) + + // send all the trace IDs to the fanout channel + wgFans := sync.WaitGroup{} + wgFans.Add(1) + go func() { + defer wgFans.Done() + defer close(fanoutChan) + for i := range input { + fanoutChan <- input[i] + } + }() + + wgFans.Add(1) + go func() { + defer wgFans.Done() + for r := range faninChan { + result[r.key] = r.val + } + }() + + wgWorkers := sync.WaitGroup{} + for i := 0; i < parallelism; i++ { + wgWorkers.Add(1) + worker, cleanup := workerFactory(i) + go func(i int) { + defer wgWorkers.Done() + if cleanup != nil { + defer cleanup(i) + } + for t := range fanoutChan { + product := worker(t) + if predicate == nil || predicate(product) { + faninChan <- resultPair{t, product} + } + } + }(i) + } + + // wait for the workers to finish + wgWorkers.Wait() + // now we can close the fanin channel and wait for the fanin goroutine to finish + // fanout should already be done but this makes sure we don't lose track of it + close(faninChan) + wgFans.Wait() + + return result +} + +// EasyFanoutToMap is a convenience function for when you don't need all the +// features. It takes a slice of input, a parallelism factor, and a worker +// function. It calls the worker on every element of the input with the +// specified parallelism, and returns a map of the inputs to the outputs. +func EasyFanoutToMap[T comparable, U any](input []T, parallelism int, worker func(T) U) map[T]U { + return FanoutToMap(input, parallelism, func(int) (func(T) U, func(int)) { + return worker, nil + }, nil) +} + +// FanoutChunksToMap takes a slice of input, a chunk size, a maximum parallelism +// factor, and a worker factory. It calls the generated worker on every chunk of +// the input, and returns a (possibly filtered) map of the inputs to the +// outputs. Only the outputs that pass the predicate (if it is not nil) will be +// added to the map. +// +// The maximum parallelism factor is the maximum number of workers that will be +// run in parallel. The actual number of workers will be the minimum of the +// maximum parallelism factor and the number of chunks in the input. +func FanoutChunksToMap[T comparable, U any](input []T, chunkSize int, maxParallelism int, workerFactory func(int) (worker func([]T) map[T]U, cleanup func(int)), predicate func(U) bool) map[T]U { + result := make(map[T]U, 0) + + if chunkSize <= 0 { + chunkSize = 1 + } + + type resultPair struct { + key T + val U + } + parallelism := min(maxParallelism, max(len(input)/chunkSize, 1)) + fanoutChan := make(chan []T, parallelism) + faninChan := make(chan resultPair, parallelism) + + // send all the trace IDs to the fanout channel + wgFans := sync.WaitGroup{} + wgFans.Add(1) + go func() { + defer wgFans.Done() + defer close(fanoutChan) + for i := 0; i < len(input); i += chunkSize { + end := min(i+chunkSize, len(input)) + fanoutChan <- input[i:end] + } + }() + + wgFans.Add(1) + go func() { + defer wgFans.Done() + for r := range faninChan { + result[r.key] = r.val + } + }() + + wgWorkers := sync.WaitGroup{} + for i := 0; i < parallelism; i++ { + wgWorkers.Add(1) + worker, cleanup := workerFactory(i) + go func(i int) { + defer wgWorkers.Done() + if cleanup != nil { + defer cleanup(i) + } + for u := range fanoutChan { + products := worker(u) + for key, product := range products { + if predicate == nil || predicate(product) { + faninChan <- resultPair{key: key, val: product} + } + } + } + }(i) + } + + // wait for the workers to finish + wgWorkers.Wait() + // now we can close the fanin channel and wait for the fanin goroutine to finish + // fanout should already be done but this makes sure we don't lose track of it + close(faninChan) + wgFans.Wait() + + return result +} diff --git a/generics/fanout_test.go b/generics/fanout_test.go new file mode 100644 index 0000000000..a054ced37c --- /dev/null +++ b/generics/fanout_test.go @@ -0,0 +1,196 @@ +package generics + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestFanout(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + + result := Fanout(input, parallelism, workerFactory, nil) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) +} + +func TestFanoutWithPredicate(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + predicate := func(i int) bool { + return i%4 == 0 + } + + result := Fanout(input, parallelism, workerFactory, predicate) + assert.ElementsMatch(t, []int{4, 8}, result) +} + +func TestFanoutWithCleanup(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 4 + cleanupTotal := 0 + mut := sync.Mutex{} + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + cleanup := func(i int) { + mut.Lock() + cleanupTotal += i + mut.Unlock() + } + return worker, cleanup + } + + result := Fanout(input, parallelism, workerFactory, nil) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) + assert.Equal(t, 6, cleanupTotal) // 0 + 1 + 2 + 3 +} + +var expected = map[int]int{ + 1: 2, + 2: 4, + 3: 6, + 4: 8, + 5: 10, +} + +func TestFanoutMap(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + + result := FanoutToMap(input, parallelism, workerFactory, nil) + assert.EqualValues(t, expected, result) +} + +func TestFanoutMapWithPredicate(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + predicate := func(i int) bool { + return i%4 == 0 + } + + result := FanoutToMap(input, parallelism, workerFactory, predicate) + assert.EqualValues(t, map[int]int{2: 4, 4: 8}, result) +} + +func TestFanoutMapWithCleanup(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 4 + cleanupTotal := 0 + mut := sync.Mutex{} + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + cleanup := func(i int) { + mut.Lock() + cleanupTotal += i + mut.Unlock() + } + return worker, cleanup + } + + result := FanoutToMap(input, parallelism, workerFactory, nil) + assert.EqualValues(t, expected, result) + assert.Equal(t, 6, cleanupTotal) // 0 + 1 + 2 + 3 +} + +func TestEasyFanout(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + worker := func(i int) int { + return i * 2 + } + + result := EasyFanout(input, 3, worker) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) +} + +func TestEasyFanoutToMap(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + worker := func(i int) int { + return i * 2 + } + + result := EasyFanoutToMap(input, 3, worker) + assert.EqualValues(t, expected, result) +} + +func BenchmarkFanoutParallelism(b *testing.B) { + parallelisms := []int{1, 3, 6, 10, 25, 100} + for _, parallelism := range parallelisms { + b.Run(fmt.Sprintf("parallelism%02d", parallelism), func(b *testing.B) { + + input := make([]int, b.N) + for i := range input { + input[i] = i + } + + workerFactory := func(i int) (func(int) string, func(int)) { + worker := func(i int) string { + h := sha256.Sum256(([]byte(fmt.Sprintf("%d", i)))) + time.Sleep(1 * time.Millisecond) + return hex.EncodeToString(h[:]) + } + cleanup := func(i int) {} + return worker, cleanup + } + b.ResetTimer() + _ = Fanout(input, parallelism, workerFactory, nil) + }) + } +} + +func BenchmarkFanoutMapParallelism(b *testing.B) { + parallelisms := []int{1, 3, 6, 10, 25, 100} + for _, parallelism := range parallelisms { + b.Run(fmt.Sprintf("parallelism%02d", parallelism), func(b *testing.B) { + + input := make([]int, b.N) + for i := range input { + input[i] = i + } + + workerFactory := func(i int) (func(int) string, func(int)) { + worker := func(i int) string { + h := sha256.Sum256(([]byte(fmt.Sprintf("%d", i)))) + time.Sleep(1 * time.Millisecond) + return hex.EncodeToString(h[:]) + } + cleanup := func(i int) {} + return worker, cleanup + } + b.ResetTimer() + _ = FanoutToMap(input, parallelism, workerFactory, nil) + }) + } +} From 70ff10729850c0562304aee211281b733751d382 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 15:32:51 -0400 Subject: [PATCH 03/19] make fanout test easier to reason about --- generics/fanout_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/generics/fanout_test.go b/generics/fanout_test.go index a054ced37c..988acba831 100644 --- a/generics/fanout_test.go +++ b/generics/fanout_test.go @@ -45,7 +45,7 @@ func TestFanoutWithPredicate(t *testing.T) { func TestFanoutWithCleanup(t *testing.T) { input := []int{1, 2, 3, 4, 5} parallelism := 4 - cleanupTotal := 0 + cleanups := []int{} mut := sync.Mutex{} workerFactory := func(i int) (func(int) int, func(int)) { worker := func(i int) int { @@ -53,7 +53,7 @@ func TestFanoutWithCleanup(t *testing.T) { } cleanup := func(i int) { mut.Lock() - cleanupTotal += i + cleanups = append(cleanups, i) mut.Unlock() } return worker, cleanup @@ -61,7 +61,7 @@ func TestFanoutWithCleanup(t *testing.T) { result := Fanout(input, parallelism, workerFactory, nil) assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) - assert.Equal(t, 6, cleanupTotal) // 0 + 1 + 2 + 3 + assert.ElementsMatch(t, []int{0, 1, 2, 3}, cleanups) } var expected = map[int]int{ @@ -106,7 +106,7 @@ func TestFanoutMapWithPredicate(t *testing.T) { func TestFanoutMapWithCleanup(t *testing.T) { input := []int{1, 2, 3, 4, 5} parallelism := 4 - cleanupTotal := 0 + cleanups := []int{} mut := sync.Mutex{} workerFactory := func(i int) (func(int) int, func(int)) { worker := func(i int) int { @@ -114,7 +114,7 @@ func TestFanoutMapWithCleanup(t *testing.T) { } cleanup := func(i int) { mut.Lock() - cleanupTotal += i + cleanups = append(cleanups, i) mut.Unlock() } return worker, cleanup @@ -122,7 +122,7 @@ func TestFanoutMapWithCleanup(t *testing.T) { result := FanoutToMap(input, parallelism, workerFactory, nil) assert.EqualValues(t, expected, result) - assert.Equal(t, 6, cleanupTotal) // 0 + 1 + 2 + 3 + assert.ElementsMatch(t, []int{0, 1, 2, 3}, cleanups) } func TestEasyFanout(t *testing.T) { From 48eba394e491ccbc6d28ea14e5e1ad810a993ba6 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 16:27:39 -0400 Subject: [PATCH 04/19] SetWithTTL now has clockwork and is concurrency-safe --- generics/setttl.go | 40 ++++++++++++++++++++++++++++------------ generics/setttl_test.go | 19 ++++++++++--------- go.mod | 1 + go.sum | 2 ++ 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/generics/setttl.go b/generics/setttl.go index 5718112c09..1b00c18b39 100644 --- a/generics/setttl.go +++ b/generics/setttl.go @@ -3,65 +3,82 @@ package generics import ( "cmp" "sort" + "sync" "time" + "github.com/jonboulle/clockwork" "golang.org/x/exp/maps" ) // SetWithTTL is a unique set of items with a TTL (time to live) for each item. // After the TTL expires, the item is automatically removed from the set. +// It is safe for concurrent use. type SetWithTTL[T cmp.Ordered] struct { Items map[T]time.Time TTL time.Duration + Clock clockwork.Clock + mut sync.RWMutex } // NewSetWithTTL returns a new SetWithTTL with elements `es` and a TTL of `ttl`. -func NewSetWithTTL[T cmp.Ordered](ttl time.Duration, es ...T) SetWithTTL[T] { - s := SetWithTTL[T]{ +func NewSetWithTTL[T cmp.Ordered](ttl time.Duration, es ...T) *SetWithTTL[T] { + s := &SetWithTTL[T]{ Items: make(map[T]time.Time, len(es)), TTL: ttl, + Clock: clockwork.NewRealClock(), } s.Add(es...) return s } // Add adds elements `es` to the SetWithTTL. -func (s SetWithTTL[T]) Add(es ...T) { +func (s *SetWithTTL[T]) Add(es ...T) { + s.mut.Lock() + defer s.mut.Unlock() for _, e := range es { - s.Items[e] = time.Now().Add(s.TTL) + s.Items[e] = s.Clock.Now().Add(s.TTL) } } // Remove removes elements `es` from the SetWithTTL. -func (s SetWithTTL[T]) Remove(es ...T) { +func (s *SetWithTTL[T]) Remove(es ...T) { + s.mut.Lock() + defer s.mut.Unlock() for _, e := range es { delete(s.Items, e) } } // Contains returns true if the SetWithTTL contains `e`. -func (s SetWithTTL[T]) Contains(e T) bool { +func (s *SetWithTTL[T]) Contains(e T) bool { + s.mut.RLock() item, ok := s.Items[e] + s.mut.RUnlock() if !ok { return false } return item.After(time.Now()) } -func (s SetWithTTL[T]) cleanup() { +func (s *SetWithTTL[T]) cleanup() int { + s.mut.Lock() + defer s.mut.Unlock() maps.DeleteFunc(s.Items, func(k T, exp time.Time) bool { - return exp.Before(time.Now()) + return exp.Before(s.Clock.Now()) }) + return len(s.Items) } // Members returns the unique elements of the SetWithTTL in sorted order. // It also removes any items that have expired. -func (s SetWithTTL[T]) Members() []T { +func (s *SetWithTTL[T]) Members() []T { s.cleanup() members := make([]T, 0, len(s.Items)) + s.mut.RLock() for member := range s.Items { members = append(members, member) } + s.mut.RUnlock() sort.Slice(members, func(i, j int) bool { return cmp.Less(members[i], members[j]) }) @@ -69,7 +86,6 @@ func (s SetWithTTL[T]) Members() []T { } // Length returns the number of items in the SetWithTTL after removing any expired items. -func (s SetWithTTL[T]) Length() int { - s.cleanup() - return len(s.Items) +func (s *SetWithTTL[T]) Length() int { + return s.cleanup() } diff --git a/generics/setttl_test.go b/generics/setttl_test.go index 20177f2a3a..1fca1316d6 100644 --- a/generics/setttl_test.go +++ b/generics/setttl_test.go @@ -4,22 +4,23 @@ import ( "testing" "time" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" ) func TestSetTTLBasics(t *testing.T) { s := NewSetWithTTL(100*time.Millisecond, "a", "b", "b") + fakeclock := clockwork.NewFakeClock() + s.Clock = fakeclock assert.Equal(t, 2, s.Length()) - time.Sleep(50 * time.Millisecond) + fakeclock.Advance(50 * time.Millisecond) s.Add("c") assert.Equal(t, 3, s.Length()) assert.Equal(t, s.Members(), []string{"a", "b", "c"}) - assert.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.Equal(collect, 1, s.Length()) - assert.Equal(collect, s.Members(), []string{"c"}) - }, 100*time.Millisecond, 20*time.Millisecond) - assert.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.Equal(collect, 0, s.Length()) - assert.Equal(collect, s.Members(), []string{}) - }, 100*time.Millisecond, 20*time.Millisecond) + fakeclock.Advance(60 * time.Millisecond) + assert.Equal(t, 1, s.Length()) + assert.Equal(t, s.Members(), []string{"c"}) + fakeclock.Advance(100 * time.Millisecond) + assert.Equal(t, 0, s.Length()) + assert.Equal(t, s.Members(), []string{}) } diff --git a/go.mod b/go.mod index a75132594b..dfd01ecf5f 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/honeycombio/husky v0.30.0 github.com/honeycombio/libhoney-go v1.23.1 github.com/jessevdk/go-flags v1.5.0 + github.com/jonboulle/clockwork v0.4.0 github.com/json-iterator/go v1.1.12 github.com/klauspost/compress v1.17.8 github.com/panmari/cuckoofilter v1.0.3 diff --git a/go.sum b/go.sum index 775e438ae0..79d0a8677f 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,8 @@ github.com/honeycombio/opentelemetry-proto-go/otlp v0.19.0-compat h1:fMpIzVAl5C2 github.com/honeycombio/opentelemetry-proto-go/otlp v0.19.0-compat/go.mod h1:mC2aK20Z/exugKpqCgcpwEadiS0im8K6mZsD4Is/hCY= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= From fc02f8976a639a776707a2d84faed060d37e27b9 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 16:29:13 -0400 Subject: [PATCH 05/19] Use maps.Keys --- generics/setttl.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/generics/setttl.go b/generics/setttl.go index 1b00c18b39..f39d6a65a4 100644 --- a/generics/setttl.go +++ b/generics/setttl.go @@ -73,11 +73,8 @@ func (s *SetWithTTL[T]) cleanup() int { // It also removes any items that have expired. func (s *SetWithTTL[T]) Members() []T { s.cleanup() - members := make([]T, 0, len(s.Items)) s.mut.RLock() - for member := range s.Items { - members = append(members, member) - } + members := maps.Keys(s.Items) s.mut.RUnlock() sort.Slice(members, func(i, j int) bool { return cmp.Less(members[i], members[j]) From 67fc1a1573d9a2798e19c529d4d242d6000da0f8 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Sat, 15 Jun 2024 14:16:23 -0400 Subject: [PATCH 06/19] working, tested, but unused pubsub system --- go.mod | 2 + go.sum | 8 ++ pubsub/pubsub.go | 36 ++++++ pubsub/pubsub_goredis.go | 132 ++++++++++++++++++++ pubsub/pubsub_local.go | 116 +++++++++++++++++ pubsub/pubsub_test.go | 264 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 558 insertions(+) create mode 100644 pubsub/pubsub.go create mode 100644 pubsub/pubsub_goredis.go create mode 100644 pubsub/pubsub_local.go create mode 100644 pubsub/pubsub_test.go diff --git a/go.mod b/go.mod index 091b0ecb71..65f65fb447 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.19.1 github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 + github.com/redis/go-redis/v9 v9.5.3 github.com/sirupsen/logrus v1.9.3 github.com/sourcegraph/conc v0.3.0 github.com/stretchr/testify v1.9.0 @@ -51,6 +52,7 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect github.com/facebookgo/limitgroup v0.0.0-20150612190941-6abd8d71ec01 // indirect github.com/facebookgo/muster v0.0.0-20150708232844-fd3d7953fd52 // indirect diff --git a/go.sum b/go.sum index 0261222db5..1408c7d7e1 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,10 @@ github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= @@ -17,6 +21,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165 h1:BS21ZUJ/B5X2UVUbczfmdWH7GapPWAhxcMsDnjJTU1E= github.com/dgryski/go-metro v0.0.0-20200812162917-85c65e2d0165/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-wyhash v0.0.0-20191203203029-c4841ae36371 h1:bz5ApY1kzFBvw3yckuyRBCtqGvprWrKswYK468nm+Gs= github.com/dgryski/go-wyhash v0.0.0-20191203203029-c4841ae36371/go.mod h1:/ENMIO1SQeJ5YQeUWWpbX8f+bS8INHrrhFjXgEqi4LA= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= @@ -104,6 +110,8 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU= +github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go new file mode 100644 index 0000000000..96a39702af --- /dev/null +++ b/pubsub/pubsub.go @@ -0,0 +1,36 @@ +package pubsub + +import "context" + +// general usage: +// pubsub := pubsub.NewXXXPubSub() +// topic := pubsub.NewTopic(ctx, "name") +// topic.Publish(ctx, "message") +// ch := topic.Subscribe(ctx) +// for msg := range ch { +// fmt.Println(msg) +// } +// topic.Close() // optional if you want to close the topic independently +// pubsub.Close() + +type PubSub interface { + // NewTopic creates a new topic with the given name. + // When a topic is created, it is stored in the topics map and a goroutine is + // started to listen for messages on the topic; each message is sent to all + // subscribers to the topic. Close the topic to stop the goroutine and all subscriber + // channels. + NewTopic(ctx context.Context, topic string) Topic + // Close shuts down all topics and the pubsub connection. + Close() +} + +type Topic interface { + // Publish sends a message to all subscribers of the topic. + Publish(ctx context.Context, message string) error + // Subscribe returns a channel that will receive all messages published to the topic. + // There is no unsubscribe method; close the topic to stop receiving messages. + Subscribe(ctx context.Context) <-chan string + // Close shuts down the topic and all subscriber channels. Calling this is optional; + // the topic will be closed when the pubsub connection is closed. + Close() +} diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go new file mode 100644 index 0000000000..d74bd232eb --- /dev/null +++ b/pubsub/pubsub_goredis.go @@ -0,0 +1,132 @@ +package pubsub + +import ( + "context" + "sync" + + "github.com/redis/go-redis/v9" + "golang.org/x/exp/maps" +) + +// GoRedisPubSub is a PubSub implementation that uses Redis as the message broker +// and the go-redis library to interact with Redis. +type GoRedisPubSub struct { + rdb *redis.Client + topics map[string]*GoRedisTopic + mut sync.RWMutex +} + +type GoRedisTopic struct { + topic string + rdb *redis.Client // duplicating this avoids a lock + redisSub *redis.PubSub + subscribers []chan string + done chan struct{} + mut sync.RWMutex + closed bool +} + +func NewGoRedisPubSub() *GoRedisPubSub { + + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + + return &GoRedisPubSub{ + rdb: rdb, + topics: make(map[string]*GoRedisTopic), + } +} + +// assert that GoRedisPubSub implements PubSub +var _ PubSub = (*GoRedisPubSub)(nil) + +// when a topic is created, it is stored in the topics map and a goroutine is +// started to listen for messages on the topic; each message is sent to all +// subscribers to the topic. +func (ps *GoRedisPubSub) NewTopic(ctx context.Context, topic string) Topic { + t := &GoRedisTopic{ + rdb: ps.rdb, + topic: topic, + redisSub: ps.rdb.Subscribe(ctx, topic), + subscribers: make([]chan string, 0), + done: make(chan struct{}), + closed: false, + } + + go func() { + for { + select { + case msg := <-t.redisSub.Channel(): + t.mut.RLock() + subscribers := t.subscribers + t.mut.RUnlock() + for _, sub := range subscribers { + select { + case sub <- msg.Payload: + case <-t.done: + return + } + } + case <-t.done: + t.redisSub.Close() + t.mut.RLock() + subscribers := t.subscribers + for _, sub := range subscribers { + close(sub) + } + t.subscribers = nil + t.mut.RUnlock() + return + } + } + }() + + ps.mut.Lock() + ps.topics[topic] = t + ps.mut.Unlock() + return t +} + +// Close shuts down all topics and the redis connection +func (ps *GoRedisPubSub) Close() { + ps.mut.Lock() + topics := maps.Values(ps.topics) + ps.mut.Unlock() + + for _, t := range topics { + t.Close() + } + ps.rdb.Close() +} + +// Publish sends a message to all subscribers of the topic +func (t *GoRedisTopic) Publish(ctx context.Context, message string) error { + err := t.rdb.Publish(ctx, t.topic, message).Err() + if err != nil { + return err + } + return nil +} + +// Subscribe returns a channel that will receive all messages published to the topic +func (t *GoRedisTopic) Subscribe(ctx context.Context) <-chan string { + ch := make(chan string) + t.mut.Lock() + t.subscribers = append(t.subscribers, ch) + t.mut.Unlock() + return ch +} + +// Close shuts down the topic and unsubscribes all subscribers +func (t *GoRedisTopic) Close() { + t.mut.Lock() + defer t.mut.Unlock() + if t.closed { + return + } + close(t.done) + t.closed = true +} diff --git a/pubsub/pubsub_local.go b/pubsub/pubsub_local.go new file mode 100644 index 0000000000..983ef674d1 --- /dev/null +++ b/pubsub/pubsub_local.go @@ -0,0 +1,116 @@ +package pubsub + +import ( + "context" + "sync" + + "golang.org/x/exp/maps" +) + +// LocalPubSub is a PubSub implementation that uses local channels to send messages; it does +// not communicate with any external processes. +type LocalPubSub struct { + topics map[string]*LocalTopic + mut sync.RWMutex +} + +type LocalTopic struct { + topic string + pubChan chan string + subscribers []chan string + done chan struct{} + mut sync.RWMutex + closed bool +} + +func NewLocalPubSub() *LocalPubSub { + return &LocalPubSub{ + topics: make(map[string]*LocalTopic), + } +} + +// assert that LocalPubSub implements PubSub +var _ PubSub = (*LocalPubSub)(nil) + +// when a topic is created, it is stored in the topics map and a goroutine is +// started to listen for messages on the topic; each message is sent to all +// subscribers to the topic. +func (ps *LocalPubSub) NewTopic(ctx context.Context, topic string) Topic { + t := &LocalTopic{ + topic: topic, + pubChan: make(chan string, 10), + subscribers: make([]chan string, 0), + done: make(chan struct{}), + closed: false, + } + + go func() { + for { + select { + case msg := <-t.pubChan: + t.mut.RLock() + subscribers := t.subscribers + t.mut.RUnlock() + for _, sub := range subscribers { + select { + case sub <- msg: + case <-t.done: + return + } + } + case <-t.done: + close(t.pubChan) + t.mut.RLock() + subscribers := t.subscribers + for _, sub := range subscribers { + close(sub) + } + t.subscribers = nil + t.mut.RUnlock() + return + } + } + }() + + ps.mut.Lock() + ps.topics[topic] = t + ps.mut.Unlock() + return t +} + +// Close shuts down all topics and the redis connection +func (ps *LocalPubSub) Close() { + ps.mut.Lock() + topics := maps.Values(ps.topics) + ps.mut.Unlock() + + for _, t := range topics { + t.Close() + } +} + +// Publish sends a message to all subscribers of the topic +func (t *LocalTopic) Publish(ctx context.Context, message string) error { + t.pubChan <- message + return nil +} + +// Subscribe returns a channel that will receive all messages published to the topic +func (t *LocalTopic) Subscribe(ctx context.Context) <-chan string { + ch := make(chan string) + t.mut.Lock() + t.subscribers = append(t.subscribers, ch) + t.mut.Unlock() + return ch +} + +// Close shuts down the topic and unsubscribes all subscribers +func (t *LocalTopic) Close() { + t.mut.Lock() + defer t.mut.Unlock() + if t.closed { + return + } + close(t.done) + t.closed = true +} diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go new file mode 100644 index 0000000000..23c51a9906 --- /dev/null +++ b/pubsub/pubsub_test.go @@ -0,0 +1,264 @@ +package pubsub_test + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/honeycombio/refinery/pubsub" + "github.com/stretchr/testify/require" +) + +var types = []string{"goredis", "local"} + +func newPubSub(typ string) pubsub.PubSub { + switch typ { + case "goredis": + return pubsub.NewGoRedisPubSub() + case "local": + return pubsub.NewLocalPubSub() + default: + panic("unknown pubsub type") + } +} + +func TestPubSubBasics(t *testing.T) { + ctx := context.Background() + for _, typ := range types { + t.Run(typ, func(t *testing.T) { + ps := newPubSub(typ) + topic := ps.NewTopic(ctx, "name") + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + time.Sleep(100 * time.Millisecond) + for i := 0; i < 10; i++ { + err := topic.Publish(ctx, fmt.Sprintf("message %d", i)) + require.NoError(t, err) + } + time.Sleep(100 * time.Millisecond) + topic.Close() + wg.Done() + }() + go func() { + ch := topic.Subscribe(ctx) + for msg := range ch { + require.Contains(t, msg, "message") + } + wg.Done() + }() + wg.Wait() + ps.Close() + }) + } +} + +func TestPubSubMultiTopic(t *testing.T) { + const topicCount = 100 + const messageCount = 10 + const expectedTotal = 55 // sum of 1 to messageCount + ctx := context.Background() + for _, typ := range types { + t.Run(typ, func(t *testing.T) { + ps := newPubSub(typ) + topics := make([]pubsub.Topic, topicCount) + for i := 0; i < topicCount; i++ { + topics[i] = ps.NewTopic(ctx, fmt.Sprintf("topic%d", i)) + } + time.Sleep(100 * time.Millisecond) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + for j := 0; j < topicCount; j++ { + for i := 0; i < messageCount; i++ { + // we want a different sum for each topic + err := topics[j].Publish(ctx, fmt.Sprintf("%d", (i+1)*(j+1))) + require.NoError(t, err) + } + } + time.Sleep(100 * time.Millisecond) + for i := 0; i < topicCount; i++ { + topics[i].Close() + } + wg.Done() + }() + mut := sync.Mutex{} + totals := make([]int, topicCount) + for i := 0; i < topicCount; i++ { + wg.Add(1) + go func(ix int) { + ch := topics[ix].Subscribe(ctx) + for msg := range ch { + n, _ := strconv.Atoi(msg) + mut.Lock() + totals[ix] += n + mut.Unlock() + } + wg.Done() + }(i) + } + wg.Wait() + ps.Close() + // validate that all the topics each add up to the desired total + for i := 0; i < topicCount; i++ { + require.Equal(t, expectedTotal*(i+1), totals[i]) + } + }) + } +} + +func TestPubSubLatency(t *testing.T) { + const messageCount = 1000 + ctx := context.Background() + for _, typ := range types { + t.Run(typ, func(t *testing.T) { + ps := newPubSub(typ) + topic := ps.NewTopic(ctx, "name") + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + for i := 0; i < messageCount; i++ { + err := topic.Publish(ctx, fmt.Sprintf("%d", time.Now().UnixNano())) + require.NoError(t, err) + } + time.Sleep(100 * time.Millisecond) + topic.Close() + wg.Done() + }() + var count, total, tmin, tmax int64 + go func() { + ch := topic.Subscribe(ctx) + for msg := range ch { + sent, err := strconv.Atoi(msg) + require.NoError(t, err) + rcvd := time.Now().UnixNano() + latency := rcvd - int64(sent) + require.True(t, latency >= 0) + total += latency + if tmin == 0 || latency < tmin { + tmin = latency + } + if latency > tmax { + tmax = latency + } + count++ + } + wg.Done() + }() + wg.Wait() + ps.Close() + require.Equal(t, int64(messageCount), count) + require.True(t, total > 0) + average := total / int64(count) + t.Logf("average: %d ns, min: %d ns, max: %d ns", average, tmin, tmax) + // in general, we want low latency, so we put some ballpark numbers here + // to make sure we're not doing something crazy + require.Less(t, average, int64(1*time.Millisecond)) + require.Less(t, tmax, int64(10*time.Millisecond)) + }) + } +} + +func BenchmarkPublish(b *testing.B) { + ctx := context.Background() + for _, typ := range types { + b.Run(typ, func(b *testing.B) { + ps := newPubSub(typ) + topic := ps.NewTopic(ctx, "name") + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := topic.Publish(ctx, "message") + require.NoError(b, err) + } + topic.Close() + ps.Close() + }) + } +} + +func BenchmarkPubSub(b *testing.B) { + ctx := context.Background() + for _, typ := range types { + b.Run(typ, func(b *testing.B) { + ps := newPubSub(typ) + topic := ps.NewTopic(ctx, "name") + time.Sleep(100 * time.Millisecond) + wg := sync.WaitGroup{} + wg.Add(2) + b.ResetTimer() + go func() { + for i := 0; i < b.N; i++ { + err := topic.Publish(ctx, fmt.Sprintf("message %d", i)) + require.NoError(b, err) + } + time.Sleep(1 * time.Millisecond) + topic.Close() + wg.Done() + }() + count := 0 + go func() { + ch := topic.Subscribe(ctx) + for range ch { + count++ + } + wg.Done() + }() + wg.Wait() + ps.Close() + require.Equal(b, b.N, count) + }) + } +} + +func BenchmarkPubSubMultiTopic(b *testing.B) { + const topicCount = 10 + ctx := context.Background() + for _, typ := range types { + b.Run(typ, func(b *testing.B) { + ps := newPubSub(typ) + topics := make([]pubsub.Topic, topicCount) + for i := 0; i < topicCount; i++ { + topics[i] = ps.NewTopic(ctx, fmt.Sprintf("topic%d", i)) + } + time.Sleep(100 * time.Millisecond) + wg := sync.WaitGroup{} + wg.Add(1) + b.ResetTimer() + go func() { + for i := 0; i < b.N; i++ { + err := topics[i%topicCount].Publish(ctx, fmt.Sprintf("message %d", i)) + require.NoError(b, err) + } + time.Sleep(1 * time.Millisecond) + for i := 0; i < topicCount; i++ { + topics[i].Close() + } + wg.Done() + }() + mut := sync.Mutex{} + counts := make([]int, topicCount) + for i := 0; i < topicCount; i++ { + wg.Add(1) + go func(ix int) { + ch := topics[ix].Subscribe(ctx) + for range ch { + mut.Lock() + counts[ix]++ + mut.Unlock() + } + wg.Done() + }(i) + } + wg.Wait() + ps.Close() + count := 0 + for i := 0; i < topicCount; i++ { + count += counts[i] + } + require.Equal(b, b.N, count) + }) + } +} From d602aaa35ac440935d46ca3e4d5bc1da93c9378d Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Sat, 15 Jun 2024 14:26:03 -0400 Subject: [PATCH 07/19] better limits for CI --- pubsub/pubsub_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 23c51a9906..249c3f0306 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -156,8 +156,8 @@ func TestPubSubLatency(t *testing.T) { t.Logf("average: %d ns, min: %d ns, max: %d ns", average, tmin, tmax) // in general, we want low latency, so we put some ballpark numbers here // to make sure we're not doing something crazy - require.Less(t, average, int64(1*time.Millisecond)) - require.Less(t, tmax, int64(10*time.Millisecond)) + require.Less(t, average, int64(100*time.Millisecond)) + require.Less(t, tmax, int64(500*time.Millisecond)) }) } } From ba009423f206c93e50979334204d158135a7fbfc Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 11:16:25 -0400 Subject: [PATCH 08/19] Use sync.Once instead of closed flag --- pubsub/pubsub_goredis.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index d74bd232eb..b6eb4d2364 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -23,7 +23,7 @@ type GoRedisTopic struct { subscribers []chan string done chan struct{} mut sync.RWMutex - closed bool + once sync.Once } func NewGoRedisPubSub() *GoRedisPubSub { @@ -53,7 +53,6 @@ func (ps *GoRedisPubSub) NewTopic(ctx context.Context, topic string) Topic { redisSub: ps.rdb.Subscribe(ctx, topic), subscribers: make([]chan string, 0), done: make(chan struct{}), - closed: false, } go func() { @@ -124,9 +123,7 @@ func (t *GoRedisTopic) Subscribe(ctx context.Context) <-chan string { func (t *GoRedisTopic) Close() { t.mut.Lock() defer t.mut.Unlock() - if t.closed { - return - } - close(t.done) - t.closed = true + t.once.Do(func() { + close(t.done) + }) } From c62e4a8b3151815afc2e4331b2623a42265ef61a Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 14:06:12 -0400 Subject: [PATCH 09/19] Further updates, add comments --- pubsub/pubsub_goredis.go | 10 ++++++++-- pubsub/pubsub_local.go | 13 ++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index b6eb4d2364..6bc82f9d65 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -8,6 +8,14 @@ import ( "golang.org/x/exp/maps" ) +// Notes for the future: we implemented a Redis-based PubSub system using 3 +// different libraries: go-redis, redigo, and rueidis. All three implementations +// perform similarly, but go-redis is definitely the easiest to use for PubSub. +// The rueidis library is probably the fastest for high-performance Redis use +// when you want Redis to be a database or cache, and it has some nice features +// like automatic pipelining, but it's pretty low-level and the documentation is +// poor. Redigo is feeling pretty old at this point. + // GoRedisPubSub is a PubSub implementation that uses Redis as the message broker // and the go-redis library to interact with Redis. type GoRedisPubSub struct { @@ -121,8 +129,6 @@ func (t *GoRedisTopic) Subscribe(ctx context.Context) <-chan string { // Close shuts down the topic and unsubscribes all subscribers func (t *GoRedisTopic) Close() { - t.mut.Lock() - defer t.mut.Unlock() t.once.Do(func() { close(t.done) }) diff --git a/pubsub/pubsub_local.go b/pubsub/pubsub_local.go index 983ef674d1..730ec3eb86 100644 --- a/pubsub/pubsub_local.go +++ b/pubsub/pubsub_local.go @@ -20,7 +20,7 @@ type LocalTopic struct { subscribers []chan string done chan struct{} mut sync.RWMutex - closed bool + once sync.Once } func NewLocalPubSub() *LocalPubSub { @@ -41,7 +41,6 @@ func (ps *LocalPubSub) NewTopic(ctx context.Context, topic string) Topic { pubChan: make(chan string, 10), subscribers: make([]chan string, 0), done: make(chan struct{}), - closed: false, } go func() { @@ -106,11 +105,7 @@ func (t *LocalTopic) Subscribe(ctx context.Context) <-chan string { // Close shuts down the topic and unsubscribes all subscribers func (t *LocalTopic) Close() { - t.mut.Lock() - defer t.mut.Unlock() - if t.closed { - return - } - close(t.done) - t.closed = true + t.once.Do(func() { + close(t.done) + }) } From c51b365b3112eb1d479ae3d6b62c8a2478364816 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Mon, 17 Jun 2024 21:42:50 -0400 Subject: [PATCH 10/19] Make them start/stoppers --- pubsub/pubsub.go | 11 ++++++++++- pubsub/pubsub_goredis.go | 17 +++++++++++------ pubsub/pubsub_local.go | 15 +++++++++++---- pubsub/pubsub_test.go | 7 +++++-- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 96a39702af..f19f72de0a 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -1,6 +1,10 @@ package pubsub -import "context" +import ( + "context" + + "github.com/facebookgo/startstop" +) // general usage: // pubsub := pubsub.NewXXXPubSub() @@ -22,6 +26,11 @@ type PubSub interface { NewTopic(ctx context.Context, topic string) Topic // Close shuts down all topics and the pubsub connection. Close() + + // we want to embed startstop.Starter and startstop.Stopper so that we + // can participate in injection + startstop.Starter + startstop.Stopper } type Topic interface { diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index 6bc82f9d65..aa6e310ade 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/honeycombio/refinery/config" "github.com/redis/go-redis/v9" "golang.org/x/exp/maps" ) @@ -19,6 +20,7 @@ import ( // GoRedisPubSub is a PubSub implementation that uses Redis as the message broker // and the go-redis library to interact with Redis. type GoRedisPubSub struct { + Config *config.Config `inject:""` rdb *redis.Client topics map[string]*GoRedisTopic mut sync.RWMutex @@ -34,18 +36,21 @@ type GoRedisTopic struct { once sync.Once } -func NewGoRedisPubSub() *GoRedisPubSub { - +func (ps *GoRedisPubSub) Start() error { rdb := redis.NewClient(&redis.Options{ Addr: "localhost:6379", Password: "", // no password set DB: 0, // use default DB }) - return &GoRedisPubSub{ - rdb: rdb, - topics: make(map[string]*GoRedisTopic), - } + ps.rdb = rdb + ps.topics = make(map[string]*GoRedisTopic) + return nil +} + +func (ps *GoRedisPubSub) Stop() error { + ps.Close() + return nil } // assert that GoRedisPubSub implements PubSub diff --git a/pubsub/pubsub_local.go b/pubsub/pubsub_local.go index 730ec3eb86..ba01af7398 100644 --- a/pubsub/pubsub_local.go +++ b/pubsub/pubsub_local.go @@ -4,12 +4,14 @@ import ( "context" "sync" + "github.com/honeycombio/refinery/config" "golang.org/x/exp/maps" ) // LocalPubSub is a PubSub implementation that uses local channels to send messages; it does // not communicate with any external processes. type LocalPubSub struct { + Config *config.Config `inject:""` topics map[string]*LocalTopic mut sync.RWMutex } @@ -23,10 +25,14 @@ type LocalTopic struct { once sync.Once } -func NewLocalPubSub() *LocalPubSub { - return &LocalPubSub{ - topics: make(map[string]*LocalTopic), - } +func (ps *LocalPubSub) Start() error { + ps.topics = make(map[string]*LocalTopic) + return nil +} + +func (ps *LocalPubSub) Stop() error { + ps.Close() + return nil } // assert that LocalPubSub implements PubSub @@ -81,6 +87,7 @@ func (ps *LocalPubSub) NewTopic(ctx context.Context, topic string) Topic { func (ps *LocalPubSub) Close() { ps.mut.Lock() topics := maps.Values(ps.topics) + ps.topics = make(map[string]*LocalTopic) ps.mut.Unlock() for _, t := range topics { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 249c3f0306..0ede5dfde6 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -15,14 +15,17 @@ import ( var types = []string{"goredis", "local"} func newPubSub(typ string) pubsub.PubSub { + var ps pubsub.PubSub switch typ { case "goredis": - return pubsub.NewGoRedisPubSub() + ps = &pubsub.GoRedisPubSub{} case "local": - return pubsub.NewLocalPubSub() + ps = &pubsub.LocalPubSub{} default: panic("unknown pubsub type") } + ps.Start() + return ps } func TestPubSubBasics(t *testing.T) { From 2790c6ee18d4d9fe64f0f2290148b8b6f411e7a3 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Tue, 18 Jun 2024 21:27:43 -0400 Subject: [PATCH 11/19] Set up pubsub with config-based parms --- pubsub/pubsub_goredis.go | 56 ++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index aa6e310ade..a3ccfab719 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -20,15 +20,15 @@ import ( // GoRedisPubSub is a PubSub implementation that uses Redis as the message broker // and the go-redis library to interact with Redis. type GoRedisPubSub struct { - Config *config.Config `inject:""` - rdb *redis.Client + Config config.Config `inject:""` + client *redis.Client topics map[string]*GoRedisTopic mut sync.RWMutex } type GoRedisTopic struct { topic string - rdb *redis.Client // duplicating this avoids a lock + client *redis.Client // duplicating this avoids a lock redisSub *redis.PubSub subscribers []chan string done chan struct{} @@ -37,13 +37,43 @@ type GoRedisTopic struct { } func (ps *GoRedisPubSub) Start() error { - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password set - DB: 0, // use default DB - }) + options := &redis.Options{} + authcode := "" + + if ps.Config != nil { + host, err := ps.Config.GetRedisHost() + if err != nil { + return err + } + username, err := ps.Config.GetRedisUsername() + if err != nil { + return err + } + pw, err := ps.Config.GetRedisPassword() + if err != nil { + return err + } + + authcode, err = ps.Config.GetRedisAuthCode() + if err != nil { + return err + } + + options.Addr = host + options.Username = username + options.Password = pw + options.DB = ps.Config.GetRedisDatabase() + } + client := redis.NewClient(options) + + // if an authcode was provided, use it to authenticate the connection + if authcode != "" { + if err := client.Conn().Auth(context.Background(), authcode).Err(); err != nil { + return err + } + } - ps.rdb = rdb + ps.client = client ps.topics = make(map[string]*GoRedisTopic) return nil } @@ -61,9 +91,9 @@ var _ PubSub = (*GoRedisPubSub)(nil) // subscribers to the topic. func (ps *GoRedisPubSub) NewTopic(ctx context.Context, topic string) Topic { t := &GoRedisTopic{ - rdb: ps.rdb, + client: ps.client, topic: topic, - redisSub: ps.rdb.Subscribe(ctx, topic), + redisSub: ps.client.Subscribe(ctx, topic), subscribers: make([]chan string, 0), done: make(chan struct{}), } @@ -111,12 +141,12 @@ func (ps *GoRedisPubSub) Close() { for _, t := range topics { t.Close() } - ps.rdb.Close() + ps.client.Close() } // Publish sends a message to all subscribers of the topic func (t *GoRedisTopic) Publish(ctx context.Context, message string) error { - err := t.rdb.Publish(ctx, t.topic, message).Err() + err := t.client.Publish(ctx, t.topic, message).Err() if err != nil { return err } From 287cbd27cd459d47ebab71b2fe135c84b05cfa3a Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Tue, 18 Jun 2024 21:40:28 -0400 Subject: [PATCH 12/19] Wait a little longer for CI --- pubsub/pubsub_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 0ede5dfde6..e7fd9d31d1 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -127,7 +127,9 @@ func TestPubSubLatency(t *testing.T) { err := topic.Publish(ctx, fmt.Sprintf("%d", time.Now().UnixNano())) require.NoError(t, err) } - time.Sleep(100 * time.Millisecond) + // give the subscribers a chance to catch up + // before we close the topic + time.Sleep(500 * time.Millisecond) topic.Close() wg.Done() }() From 804777de2de9d2d54fc9eef08ee368a431cd4520 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Tue, 18 Jun 2024 21:51:00 -0400 Subject: [PATCH 13/19] Alternate approach to timing --- pubsub/pubsub_test.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index e7fd9d31d1..eaa8852994 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -120,6 +120,9 @@ func TestPubSubLatency(t *testing.T) { t.Run(typ, func(t *testing.T) { ps := newPubSub(typ) topic := ps.NewTopic(ctx, "name") + var count, total, tmin, tmax int64 + mut := sync.Mutex{} + wg := sync.WaitGroup{} wg.Add(2) go func() { @@ -127,13 +130,19 @@ func TestPubSubLatency(t *testing.T) { err := topic.Publish(ctx, fmt.Sprintf("%d", time.Now().UnixNano())) require.NoError(t, err) } - // give the subscribers a chance to catch up - // before we close the topic - time.Sleep(500 * time.Millisecond) + + // now wait for all messages to arrive + require.Eventually(t, func() bool { + mut.Lock() + done := count == messageCount + mut.Unlock() + return done + }, 5*time.Second, 100*time.Millisecond) + topic.Close() wg.Done() }() - var count, total, tmin, tmax int64 + go func() { ch := topic.Subscribe(ctx) for msg := range ch { @@ -149,7 +158,9 @@ func TestPubSubLatency(t *testing.T) { if latency > tmax { tmax = latency } + mut.Lock() count++ + mut.Unlock() } wg.Done() }() From 0fb1bc3e3c1be9bd043bbfb59dace33df77800aa Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 14:15:54 -0400 Subject: [PATCH 14/19] Redesign the API (no more topics) --- pubsub/pubsub.go | 34 +++++----- pubsub/pubsub_goredis.go | 134 +++++++++++++++----------------------- pubsub/pubsub_local.go | 136 +++++++++++++++++---------------------- pubsub/pubsub_test.go | 127 ++++++++++++++++++------------------ 4 files changed, 190 insertions(+), 241 deletions(-) diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index f19f72de0a..e5b7672eea 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -8,22 +8,23 @@ import ( // general usage: // pubsub := pubsub.NewXXXPubSub() -// topic := pubsub.NewTopic(ctx, "name") -// topic.Publish(ctx, "message") -// ch := topic.Subscribe(ctx) -// for msg := range ch { +// pubsub.Start() +// defer pubsub.Stop() +// ctx := context.Background() +// pubsub.Publish(ctx, "topic", "message") +// sub := pubsub.Subscribe(ctx, "topic") +// for msg := range sub.Channel() { // fmt.Println(msg) // } -// topic.Close() // optional if you want to close the topic independently +// sub.Close() // optional // pubsub.Close() type PubSub interface { - // NewTopic creates a new topic with the given name. - // When a topic is created, it is stored in the topics map and a goroutine is - // started to listen for messages on the topic; each message is sent to all - // subscribers to the topic. Close the topic to stop the goroutine and all subscriber - // channels. - NewTopic(ctx context.Context, topic string) Topic + // Publish sends a message to all subscribers of the specified topic. + Publish(ctx context.Context, topic, message string) error + // Subscribe returns a Subscription that will receive all messages published to the specified topic. + // There is no unsubscribe method; close the subscription to stop receiving messages. + Subscribe(ctx context.Context, topic string) Subscription // Close shuts down all topics and the pubsub connection. Close() @@ -33,13 +34,10 @@ type PubSub interface { startstop.Stopper } -type Topic interface { - // Publish sends a message to all subscribers of the topic. - Publish(ctx context.Context, message string) error - // Subscribe returns a channel that will receive all messages published to the topic. - // There is no unsubscribe method; close the topic to stop receiving messages. - Subscribe(ctx context.Context) <-chan string - // Close shuts down the topic and all subscriber channels. Calling this is optional; +type Subscription interface { + // Channel returns the channel that will receive all messages published to the topic. + Channel() <-chan string + // Close stops the subscription and closes the channel. Calling this is optional; // the topic will be closed when the pubsub connection is closed. Close() } diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index a3ccfab719..952a4ddd99 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -6,7 +6,6 @@ import ( "github.com/honeycombio/refinery/config" "github.com/redis/go-redis/v9" - "golang.org/x/exp/maps" ) // Notes for the future: we implemented a Redis-based PubSub system using 3 @@ -22,20 +21,23 @@ import ( type GoRedisPubSub struct { Config config.Config `inject:""` client *redis.Client - topics map[string]*GoRedisTopic + subs []*GoRedisSubscription mut sync.RWMutex } -type GoRedisTopic struct { - topic string - client *redis.Client // duplicating this avoids a lock - redisSub *redis.PubSub - subscribers []chan string - done chan struct{} - mut sync.RWMutex - once sync.Once +// Ensure that GoRedisPubSub implements PubSub +var _ PubSub = (*GoRedisPubSub)(nil) + +type GoRedisSubscription struct { + topic string + pubsub *redis.PubSub + ch chan string + done chan struct{} } +// Ensure that GoRedisSubscription implements Subscription +var _ Subscription = (*GoRedisSubscription)(nil) + func (ps *GoRedisPubSub) Start() error { options := &redis.Options{} authcode := "" @@ -74,7 +76,7 @@ func (ps *GoRedisPubSub) Start() error { } ps.client = client - ps.topics = make(map[string]*GoRedisTopic) + ps.subs = make([]*GoRedisSubscription, 0) return nil } @@ -83,88 +85,58 @@ func (ps *GoRedisPubSub) Stop() error { return nil } -// assert that GoRedisPubSub implements PubSub -var _ PubSub = (*GoRedisPubSub)(nil) - -// when a topic is created, it is stored in the topics map and a goroutine is -// started to listen for messages on the topic; each message is sent to all -// subscribers to the topic. -func (ps *GoRedisPubSub) NewTopic(ctx context.Context, topic string) Topic { - t := &GoRedisTopic{ - client: ps.client, - topic: topic, - redisSub: ps.client.Subscribe(ctx, topic), - subscribers: make([]chan string, 0), - done: make(chan struct{}), +func (ps *GoRedisPubSub) Close() { + ps.mut.Lock() + defer ps.mut.Unlock() + for _, sub := range ps.subs { + sub.Close() } + ps.subs = nil + ps.client.Close() +} + +func (ps *GoRedisPubSub) Publish(ctx context.Context, topic, message string) error { + ps.mut.RLock() + defer ps.mut.RUnlock() + return ps.client.Publish(ctx, topic, message).Err() +} +func (ps *GoRedisPubSub) Subscribe(ctx context.Context, topic string) Subscription { + ps.mut.Lock() + defer ps.mut.Unlock() + sub := &GoRedisSubscription{ + topic: topic, + pubsub: ps.client.Subscribe(ctx, topic), + ch: make(chan string, 100), + done: make(chan struct{}), + } + ps.subs = append(ps.subs, sub) go func() { + redisch := sub.pubsub.Channel() for { select { - case msg := <-t.redisSub.Channel(): - t.mut.RLock() - subscribers := t.subscribers - t.mut.RUnlock() - for _, sub := range subscribers { - select { - case sub <- msg.Payload: - case <-t.done: - return - } + case <-sub.done: + close(sub.ch) + return + case msg := <-redisch: + if msg == nil { + continue } - case <-t.done: - t.redisSub.Close() - t.mut.RLock() - subscribers := t.subscribers - for _, sub := range subscribers { - close(sub) + select { + case sub.ch <- msg.Payload: + default: } - t.subscribers = nil - t.mut.RUnlock() - return } } }() - - ps.mut.Lock() - ps.topics[topic] = t - ps.mut.Unlock() - return t -} - -// Close shuts down all topics and the redis connection -func (ps *GoRedisPubSub) Close() { - ps.mut.Lock() - topics := maps.Values(ps.topics) - ps.mut.Unlock() - - for _, t := range topics { - t.Close() - } - ps.client.Close() -} - -// Publish sends a message to all subscribers of the topic -func (t *GoRedisTopic) Publish(ctx context.Context, message string) error { - err := t.client.Publish(ctx, t.topic, message).Err() - if err != nil { - return err - } - return nil + return sub } -// Subscribe returns a channel that will receive all messages published to the topic -func (t *GoRedisTopic) Subscribe(ctx context.Context) <-chan string { - ch := make(chan string) - t.mut.Lock() - t.subscribers = append(t.subscribers, ch) - t.mut.Unlock() - return ch +func (s *GoRedisSubscription) Channel() <-chan string { + return s.ch } -// Close shuts down the topic and unsubscribes all subscribers -func (t *GoRedisTopic) Close() { - t.once.Do(func() { - close(t.done) - }) +func (s *GoRedisSubscription) Close() { + s.pubsub.Close() + close(s.done) } diff --git a/pubsub/pubsub_local.go b/pubsub/pubsub_local.go index ba01af7398..a2ff3f09f7 100644 --- a/pubsub/pubsub_local.go +++ b/pubsub/pubsub_local.go @@ -5,114 +5,98 @@ import ( "sync" "github.com/honeycombio/refinery/config" - "golang.org/x/exp/maps" ) // LocalPubSub is a PubSub implementation that uses local channels to send messages; it does // not communicate with any external processes. type LocalPubSub struct { Config *config.Config `inject:""` - topics map[string]*LocalTopic + subs []*LocalSubscription + topics map[string]chan string mut sync.RWMutex } -type LocalTopic struct { - topic string - pubChan chan string - subscribers []chan string - done chan struct{} - mut sync.RWMutex - once sync.Once +// Ensure that LocalPubSub implements PubSub +var _ PubSub = (*LocalPubSub)(nil) + +type LocalSubscription struct { + topic string + ch chan string + done chan struct{} } +// Ensure that LocalSubscription implements Subscription +var _ Subscription = (*LocalSubscription)(nil) + +// Start initializes the LocalPubSub func (ps *LocalPubSub) Start() error { - ps.topics = make(map[string]*LocalTopic) + ps.subs = make([]*LocalSubscription, 0) + ps.topics = make(map[string]chan string) return nil } +// Stop shuts down the LocalPubSub func (ps *LocalPubSub) Stop() error { ps.Close() return nil } -// assert that LocalPubSub implements PubSub -var _ PubSub = (*LocalPubSub)(nil) +func (ps *LocalPubSub) Close() { + ps.mut.Lock() + defer ps.mut.Unlock() + for _, sub := range ps.subs { + sub.Close() + } + ps.subs = nil +} + +func (ps *LocalPubSub) ensureTopic(topic string) chan string { + if _, ok := ps.topics[topic]; !ok { + ps.topics[topic] = make(chan string, 100) + } + return ps.topics[topic] +} -// when a topic is created, it is stored in the topics map and a goroutine is -// started to listen for messages on the topic; each message is sent to all -// subscribers to the topic. -func (ps *LocalPubSub) NewTopic(ctx context.Context, topic string) Topic { - t := &LocalTopic{ - topic: topic, - pubChan: make(chan string, 10), - subscribers: make([]chan string, 0), - done: make(chan struct{}), +func (ps *LocalPubSub) Publish(ctx context.Context, topic, message string) error { + ps.mut.RLock() + ch := ps.ensureTopic(topic) + ps.mut.RUnlock() + select { + case ch <- message: + case <-ctx.Done(): + return ctx.Err() } + return nil +} +func (ps *LocalPubSub) Subscribe(ctx context.Context, topic string) Subscription { + ps.mut.Lock() + defer ps.mut.Unlock() + ch := ps.ensureTopic(topic) + sub := &LocalSubscription{ + topic: topic, + ch: ch, + done: make(chan struct{}), + } + ps.subs = append(ps.subs, sub) go func() { for { select { - case msg := <-t.pubChan: - t.mut.RLock() - subscribers := t.subscribers - t.mut.RUnlock() - for _, sub := range subscribers { - select { - case sub <- msg: - case <-t.done: - return - } - } - case <-t.done: - close(t.pubChan) - t.mut.RLock() - subscribers := t.subscribers - for _, sub := range subscribers { - close(sub) - } - t.subscribers = nil - t.mut.RUnlock() + case <-sub.done: + close(ch) return + case msg := <-ch: + sub.ch <- msg } } }() - - ps.mut.Lock() - ps.topics[topic] = t - ps.mut.Unlock() - return t -} - -// Close shuts down all topics and the redis connection -func (ps *LocalPubSub) Close() { - ps.mut.Lock() - topics := maps.Values(ps.topics) - ps.topics = make(map[string]*LocalTopic) - ps.mut.Unlock() - - for _, t := range topics { - t.Close() - } -} - -// Publish sends a message to all subscribers of the topic -func (t *LocalTopic) Publish(ctx context.Context, message string) error { - t.pubChan <- message - return nil + return sub } -// Subscribe returns a channel that will receive all messages published to the topic -func (t *LocalTopic) Subscribe(ctx context.Context) <-chan string { - ch := make(chan string) - t.mut.Lock() - t.subscribers = append(t.subscribers, ch) - t.mut.Unlock() - return ch +func (s *LocalSubscription) Channel() <-chan string { + return s.ch } -// Close shuts down the topic and unsubscribes all subscribers -func (t *LocalTopic) Close() { - t.once.Do(func() { - close(t.done) - }) +func (s *LocalSubscription) Close() { + close(s.done) } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index eaa8852994..42f4a32221 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -12,7 +12,10 @@ import ( "github.com/stretchr/testify/require" ) -var types = []string{"goredis", "local"} +var types = []string{ + "goredis", + "local", +} func newPubSub(typ string) pubsub.PubSub { var ps pubsub.PubSub @@ -33,24 +36,24 @@ func TestPubSubBasics(t *testing.T) { for _, typ := range types { t.Run(typ, func(t *testing.T) { ps := newPubSub(typ) - topic := ps.NewTopic(ctx, "name") wg := sync.WaitGroup{} wg.Add(2) go func() { - time.Sleep(100 * time.Millisecond) - for i := 0; i < 10; i++ { - err := topic.Publish(ctx, fmt.Sprintf("message %d", i)) - require.NoError(t, err) + sub := ps.Subscribe(ctx, "topic") + ch := sub.Channel() + for msg := range ch { + require.Contains(t, msg, "message") } - time.Sleep(100 * time.Millisecond) - topic.Close() wg.Done() }() go func() { - ch := topic.Subscribe(ctx) - for msg := range ch { - require.Contains(t, msg, "message") + time.Sleep(100 * time.Millisecond) + for i := 0; i < 10; i++ { + err := ps.Publish(ctx, "topic", fmt.Sprintf("message %d", i)) + require.NoError(t, err) } + time.Sleep(100 * time.Millisecond) + ps.Close() wg.Done() }() wg.Wait() @@ -67,34 +70,34 @@ func TestPubSubMultiTopic(t *testing.T) { for _, typ := range types { t.Run(typ, func(t *testing.T) { ps := newPubSub(typ) - topics := make([]pubsub.Topic, topicCount) + topics := make([]string, topicCount) for i := 0; i < topicCount; i++ { - topics[i] = ps.NewTopic(ctx, fmt.Sprintf("topic%d", i)) + topics[i] = fmt.Sprintf("topic%d", i) } time.Sleep(100 * time.Millisecond) wg := sync.WaitGroup{} wg.Add(1) go func() { + time.Sleep(100 * time.Millisecond) for j := 0; j < topicCount; j++ { for i := 0; i < messageCount; i++ { // we want a different sum for each topic - err := topics[j].Publish(ctx, fmt.Sprintf("%d", (i+1)*(j+1))) + err := ps.Publish(ctx, topics[j], fmt.Sprintf("%d", (i+1)*(j+1))) require.NoError(t, err) } } time.Sleep(100 * time.Millisecond) - for i := 0; i < topicCount; i++ { - topics[i].Close() - } + ps.Close() wg.Done() }() mut := sync.Mutex{} totals := make([]int, topicCount) + subs := make([]pubsub.Subscription, topicCount) for i := 0; i < topicCount; i++ { wg.Add(1) go func(ix int) { - ch := topics[ix].Subscribe(ctx) - for msg := range ch { + subs[ix] = ps.Subscribe(ctx, topics[ix]) + for msg := range subs[ix].Channel() { n, _ := strconv.Atoi(msg) mut.Lock() totals[ix] += n @@ -119,7 +122,6 @@ func TestPubSubLatency(t *testing.T) { for _, typ := range types { t.Run(typ, func(t *testing.T) { ps := newPubSub(typ) - topic := ps.NewTopic(ctx, "name") var count, total, tmin, tmax int64 mut := sync.Mutex{} @@ -127,7 +129,7 @@ func TestPubSubLatency(t *testing.T) { wg.Add(2) go func() { for i := 0; i < messageCount; i++ { - err := topic.Publish(ctx, fmt.Sprintf("%d", time.Now().UnixNano())) + err := ps.Publish(ctx, "topic", fmt.Sprintf("%d", time.Now().UnixNano())) require.NoError(t, err) } @@ -139,13 +141,13 @@ func TestPubSubLatency(t *testing.T) { return done }, 5*time.Second, 100*time.Millisecond) - topic.Close() + ps.Close() wg.Done() }() go func() { - ch := topic.Subscribe(ctx) - for msg := range ch { + sub := ps.Subscribe(ctx, "topic") + for msg := range sub.Channel() { sent, err := strconv.Atoi(msg) require.NoError(t, err) rcvd := time.Now().UnixNano() @@ -178,47 +180,39 @@ func TestPubSubLatency(t *testing.T) { } } -func BenchmarkPublish(b *testing.B) { - ctx := context.Background() - for _, typ := range types { - b.Run(typ, func(b *testing.B) { - ps := newPubSub(typ) - topic := ps.NewTopic(ctx, "name") - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := topic.Publish(ctx, "message") - require.NoError(b, err) - } - topic.Close() - ps.Close() - }) - } -} - func BenchmarkPubSub(b *testing.B) { ctx := context.Background() for _, typ := range types { b.Run(typ, func(b *testing.B) { ps := newPubSub(typ) - topic := ps.NewTopic(ctx, "name") time.Sleep(100 * time.Millisecond) + count := 0 + mut := sync.Mutex{} + wg := sync.WaitGroup{} wg.Add(2) b.ResetTimer() go func() { + time.Sleep(100 * time.Millisecond) for i := 0; i < b.N; i++ { - err := topic.Publish(ctx, fmt.Sprintf("message %d", i)) + err := ps.Publish(ctx, "topic", fmt.Sprintf("message %d", i)) require.NoError(b, err) } - time.Sleep(1 * time.Millisecond) - topic.Close() + require.Eventually(b, func() bool { + mut.Lock() + defer mut.Unlock() + return count == b.N + }, 5*time.Second, 10*time.Millisecond) + ps.Close() wg.Done() }() - count := 0 + go func() { - ch := topic.Subscribe(ctx) - for range ch { + sub := ps.Subscribe(ctx, "topic") + for range sub.Channel() { + mut.Lock() count++ + mut.Unlock() } wg.Done() }() @@ -231,37 +225,43 @@ func BenchmarkPubSub(b *testing.B) { func BenchmarkPubSubMultiTopic(b *testing.B) { const topicCount = 10 - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() for _, typ := range types { b.Run(typ, func(b *testing.B) { ps := newPubSub(typ) - topics := make([]pubsub.Topic, topicCount) + topics := make([]string, topicCount) for i := 0; i < topicCount; i++ { - topics[i] = ps.NewTopic(ctx, fmt.Sprintf("topic%d", i)) + topics[i] = fmt.Sprintf("topic%d", i) } - time.Sleep(100 * time.Millisecond) + mut := sync.RWMutex{} + count := 0 + counts := make([]int, topicCount) + wg := sync.WaitGroup{} wg.Add(1) b.ResetTimer() go func() { + time.Sleep(100 * time.Millisecond) for i := 0; i < b.N; i++ { - err := topics[i%topicCount].Publish(ctx, fmt.Sprintf("message %d", i)) + err := ps.Publish(ctx, topics[i%topicCount], fmt.Sprintf("message %d", i)) require.NoError(b, err) } - time.Sleep(1 * time.Millisecond) - for i := 0; i < topicCount; i++ { - topics[i].Close() - } + require.Eventually(b, func() bool { + mut.RLock() + defer mut.RUnlock() + return count == b.N + }, 1*time.Second, 100*time.Millisecond) + ps.Close() wg.Done() }() - mut := sync.Mutex{} - counts := make([]int, topicCount) for i := 0; i < topicCount; i++ { wg.Add(1) go func(ix int) { - ch := topics[ix].Subscribe(ctx) - for range ch { + sub := ps.Subscribe(ctx, topics[ix]) + for range sub.Channel() { mut.Lock() + count++ counts[ix]++ mut.Unlock() } @@ -270,11 +270,6 @@ func BenchmarkPubSubMultiTopic(b *testing.B) { } wg.Wait() ps.Close() - count := 0 - for i := 0; i < topicCount; i++ { - count += counts[i] - } - require.Equal(b, b.N, count) }) } } From 3f89fb03264c94daeb1f0b464b9306b364748932 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 14:19:34 -0400 Subject: [PATCH 15/19] Wait a bit before we send --- pubsub/pubsub_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 42f4a32221..e092d6b89a 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -78,7 +78,7 @@ func TestPubSubMultiTopic(t *testing.T) { wg := sync.WaitGroup{} wg.Add(1) go func() { - time.Sleep(100 * time.Millisecond) + time.Sleep(500 * time.Millisecond) for j := 0; j < topicCount; j++ { for i := 0; i < messageCount; i++ { // we want a different sum for each topic From a4bbe7218416b9718219a74c0986a53570ec234b Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 15:39:34 -0400 Subject: [PATCH 16/19] Respond to feedback --- pubsub/pubsub_goredis.go | 10 ++++++++-- pubsub/pubsub_test.go | 7 ------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index 952a4ddd99..7a4db93584 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/honeycombio/refinery/config" + "github.com/honeycombio/refinery/logger" "github.com/redis/go-redis/v9" ) @@ -20,6 +21,7 @@ import ( // and the go-redis library to interact with Redis. type GoRedisPubSub struct { Config config.Config `inject:""` + Logger logger.Logger `inject:""` client *redis.Client subs []*GoRedisSubscription mut sync.RWMutex @@ -33,6 +35,7 @@ type GoRedisSubscription struct { pubsub *redis.PubSub ch chan string done chan struct{} + once sync.Once } // Ensure that GoRedisSubscription implements Subscription @@ -125,6 +128,7 @@ func (ps *GoRedisPubSub) Subscribe(ctx context.Context, topic string) Subscripti select { case sub.ch <- msg.Payload: default: + ps.Logger.Warn().WithField("topic", topic).Logf("Dropping subscription message because channel is full") } } } @@ -137,6 +141,8 @@ func (s *GoRedisSubscription) Channel() <-chan string { } func (s *GoRedisSubscription) Close() { - s.pubsub.Close() - close(s.done) + s.once.Do(func() { + s.pubsub.Close() + close(s.done) + }) } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index e092d6b89a..19c9e476fe 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -57,7 +57,6 @@ func TestPubSubBasics(t *testing.T) { wg.Done() }() wg.Wait() - ps.Close() }) } } @@ -107,7 +106,6 @@ func TestPubSubMultiTopic(t *testing.T) { }(i) } wg.Wait() - ps.Close() // validate that all the topics each add up to the desired total for i := 0; i < topicCount; i++ { require.Equal(t, expectedTotal*(i+1), totals[i]) @@ -167,7 +165,6 @@ func TestPubSubLatency(t *testing.T) { wg.Done() }() wg.Wait() - ps.Close() require.Equal(t, int64(messageCount), count) require.True(t, total > 0) average := total / int64(count) @@ -217,7 +214,6 @@ func BenchmarkPubSub(b *testing.B) { wg.Done() }() wg.Wait() - ps.Close() require.Equal(b, b.N, count) }) } @@ -236,7 +232,6 @@ func BenchmarkPubSubMultiTopic(b *testing.B) { } mut := sync.RWMutex{} count := 0 - counts := make([]int, topicCount) wg := sync.WaitGroup{} wg.Add(1) @@ -262,14 +257,12 @@ func BenchmarkPubSubMultiTopic(b *testing.B) { for range sub.Channel() { mut.Lock() count++ - counts[ix]++ mut.Unlock() } wg.Done() }(i) } wg.Wait() - ps.Close() }) } } From 2e424cd443cbfb3189dc9ab233fba63b9c734201 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 15:43:17 -0400 Subject: [PATCH 17/19] Gah, CI slowness --- pubsub/pubsub_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 19c9e476fe..b7bf040b3b 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -126,6 +126,7 @@ func TestPubSubLatency(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) go func() { + time.Sleep(300 * time.Millisecond) for i := 0; i < messageCount; i++ { err := ps.Publish(ctx, "topic", fmt.Sprintf("%d", time.Now().UnixNano())) require.NoError(t, err) From c696875b9f0c0417650ba0172ea03aeee942bc9e Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 15:51:29 -0400 Subject: [PATCH 18/19] redis client is concurrency-safe --- pubsub/pubsub_goredis.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index 7a4db93584..fe7b9d373e 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -90,30 +90,28 @@ func (ps *GoRedisPubSub) Stop() error { func (ps *GoRedisPubSub) Close() { ps.mut.Lock() - defer ps.mut.Unlock() for _, sub := range ps.subs { sub.Close() } ps.subs = nil + ps.mut.Unlock() ps.client.Close() } func (ps *GoRedisPubSub) Publish(ctx context.Context, topic, message string) error { - ps.mut.RLock() - defer ps.mut.RUnlock() return ps.client.Publish(ctx, topic, message).Err() } func (ps *GoRedisPubSub) Subscribe(ctx context.Context, topic string) Subscription { - ps.mut.Lock() - defer ps.mut.Unlock() sub := &GoRedisSubscription{ topic: topic, pubsub: ps.client.Subscribe(ctx, topic), ch: make(chan string, 100), done: make(chan struct{}), } + ps.mut.Lock() ps.subs = append(ps.subs, sub) + ps.mut.Unlock() go func() { redisch := sub.pubsub.Channel() for { From b0d09e64c546292248854eea5ca16df9a21a1a45 Mon Sep 17 00:00:00 2001 From: Kent Quirk Date: Fri, 21 Jun 2024 16:07:56 -0400 Subject: [PATCH 19/19] Switch to universal client and support cluster. --- pubsub/pubsub_goredis.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pubsub/pubsub_goredis.go b/pubsub/pubsub_goredis.go index fe7b9d373e..94b23542d5 100644 --- a/pubsub/pubsub_goredis.go +++ b/pubsub/pubsub_goredis.go @@ -2,6 +2,7 @@ package pubsub import ( "context" + "strings" "sync" "github.com/honeycombio/refinery/config" @@ -22,7 +23,7 @@ import ( type GoRedisPubSub struct { Config config.Config `inject:""` Logger logger.Logger `inject:""` - client *redis.Client + client redis.UniversalClient subs []*GoRedisSubscription mut sync.RWMutex } @@ -42,7 +43,7 @@ type GoRedisSubscription struct { var _ Subscription = (*GoRedisSubscription)(nil) func (ps *GoRedisPubSub) Start() error { - options := &redis.Options{} + options := &redis.UniversalOptions{} authcode := "" if ps.Config != nil { @@ -64,16 +65,22 @@ func (ps *GoRedisPubSub) Start() error { return err } - options.Addr = host + // we may have multiple hosts, separated by commas, so split them up and + // use them as the addrs for the client (if there are multiples, it will + // create a cluster client) + hosts := strings.Split(host, ",") + options.Addrs = hosts options.Username = username options.Password = pw options.DB = ps.Config.GetRedisDatabase() } - client := redis.NewClient(options) + client := redis.NewUniversalClient(options) // if an authcode was provided, use it to authenticate the connection if authcode != "" { - if err := client.Conn().Auth(context.Background(), authcode).Err(); err != nil { + pipe := client.Pipeline() + pipe.Auth(context.Background(), authcode) + if _, err := pipe.Exec(context.Background()); err != nil { return err } }