From b09bd83773770c3ee1e027cad8239ced8532bb68 Mon Sep 17 00:00:00 2001 From: Eugene R Date: Wed, 10 Jan 2024 16:07:22 +0200 Subject: [PATCH] fix: throttler functionality; improve code coverage (#101) --- flow/throttler.go | 82 +++++++++++++++----------------- flow/throttler_test.go | 105 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 43 deletions(-) create mode 100644 flow/throttler_test.go diff --git a/flow/throttler.go b/flow/throttler.go index eff33ab..31d6144 100644 --- a/flow/throttler.go +++ b/flow/throttler.go @@ -1,20 +1,22 @@ package flow import ( + "fmt" "sync/atomic" "time" "github.com/reugn/go-streams" ) -// ThrottleMode defines the Throttler behavior on buffer overflow. +// ThrottleMode represents Throttler's processing behavior when its element +// buffer overflows. type ThrottleMode int8 const ( - // Backpressure on overflow mode. + // Backpressure slows down upstream ingestion when the element buffer overflows. Backpressure ThrottleMode = iota - // Discard elements on overflow mode. + // Discard drops incoming elements when the element buffer overflows. Discard ) @@ -25,7 +27,7 @@ type Throttler struct { mode ThrottleMode in chan interface{} out chan interface{} - notify chan struct{} + quotaSignal chan struct{} done chan struct{} counter uint64 } @@ -36,8 +38,8 @@ var _ streams.Flow = (*Throttler)(nil) // NewThrottler returns a new Throttler instance. // // elements is the maximum number of elements to be produced per the given period of time. -// bufferSize defines the incoming elements buffer size. -// mode defines the Throttler flow behavior on elements buffer overflow. +// bufferSize specifies the buffer size for incoming elements. +// mode specifies the processing behavior when the elements buffer overflows. func NewThrottler(elements uint, period time.Duration, bufferSize uint, mode ThrottleMode) *Throttler { throttler := &Throttler{ maxElements: uint64(elements), @@ -45,38 +47,31 @@ func NewThrottler(elements uint, period time.Duration, bufferSize uint, mode Thr mode: mode, in: make(chan interface{}), out: make(chan interface{}, bufferSize), - notify: make(chan struct{}), + quotaSignal: make(chan struct{}), done: make(chan struct{}), - counter: 0, } - go throttler.resetCounterLoop(period) + go throttler.resetQuotaCounterLoop() go throttler.bufferize() return throttler } -// incrementCounter increments the elements counter. -func (th *Throttler) incrementCounter() { - atomic.AddUint64(&th.counter, 1) -} - -// quotaHit verifies if the quota per time unit is exceeded. -func (th *Throttler) quotaHit() bool { +// quotaExceeded checks whether the quota per time unit has been exceeded. +func (th *Throttler) quotaExceeded() bool { return atomic.LoadUint64(&th.counter) >= th.maxElements } -// resetCounterLoop is the scheduled quota refresher. -func (th *Throttler) resetCounterLoop(after time.Duration) { - ticker := time.NewTicker(after) +// resetQuotaCounterLoop resets the throttler quota counter every th.period +// and sends a release notification to the downstream processor. +func (th *Throttler) resetQuotaCounterLoop() { + ticker := time.NewTicker(th.period) defer ticker.Stop() for { select { case <-ticker.C: - if th.quotaHit() { - atomic.StoreUint64(&th.counter, 0) - th.doNotify() - } + atomic.StoreUint64(&th.counter, 0) + th.notifyQuotaReset() // send quota reset case <-th.done: return @@ -84,46 +79,44 @@ func (th *Throttler) resetCounterLoop(after time.Duration) { } } -// doNotify notifies the producer goroutine with quota reset. -func (th *Throttler) doNotify() { +// notifyQuotaReset notifies the downstream processor with quota reset. +func (th *Throttler) notifyQuotaReset() { select { - case th.notify <- struct{}{}: + case th.quotaSignal <- struct{}{}: default: } } // bufferize starts buffering incoming elements. -// panics on an unsupported ThrottleMode. +// The method will panic if an unsupported ThrottleMode is specified. func (th *Throttler) bufferize() { switch th.mode { case Discard: - for e := range th.in { + for element := range th.in { select { - case th.out <- e: + case th.out <- element: default: } } case Backpressure: - for e := range th.in { - th.out <- e + for element := range th.in { + th.out <- element } default: - panic("Unsupported ThrottleMode") + panic(fmt.Sprintf("Unsupported ThrottleMode: %d", th.mode)) } - close(th.done) close(th.out) - close(th.notify) } // Via streams data through the given flow func (th *Throttler) Via(flow streams.Flow) streams.Flow { - go th.doStream(flow) + go th.streamPortioned(flow) return flow } // To streams data to the given sink func (th *Throttler) To(sink streams.Sink) { - th.doStream(sink) + th.streamPortioned(sink) } // Out returns an output channel for sending data @@ -136,14 +129,17 @@ func (th *Throttler) In() chan<- interface{} { return th.in } -// doStream streams data to the next Inlet. -func (th *Throttler) doStream(inlet streams.Inlet) { - for elem := range th.Out() { - if th.quotaHit() { - <-th.notify +// streamPortioned streams elements to the next Inlet. +// Subsequent processing of elements will be suspended when the quota limit is reached +// until the next quota reset event. +func (th *Throttler) streamPortioned(inlet streams.Inlet) { + for element := range th.out { + if th.quotaExceeded() { + <-th.quotaSignal // wait for quota reset } - th.incrementCounter() - inlet.In() <- elem + atomic.AddUint64(&th.counter, 1) + inlet.In() <- element } + close(th.done) close(inlet.In()) } diff --git a/flow/throttler_test.go b/flow/throttler_test.go new file mode 100644 index 0000000..3144649 --- /dev/null +++ b/flow/throttler_test.go @@ -0,0 +1,105 @@ +package flow_test + +import ( + "fmt" + "testing" + "time" + + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" +) + +func TestThrottlerWithBackpressure(t *testing.T) { + in := make(chan interface{}) + out := make(chan interface{}) + + interval := 10 * time.Millisecond + source := ext.NewChanSource(in) + throttler := flow.NewThrottler(2, interval, 2, flow.Backpressure) + sink := ext.NewChanSink(out) + + go writeValues(in) + + go func() { + source. + Via(throttler). + To(sink) + }() + + outputValues := readValues(interval/2, out) + assertEquals(t, []interface{}{"a", "b"}, outputValues) + fmt.Println(outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + assertEquals(t, []interface{}{"c", "d"}, outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + assertEquals(t, []interface{}{"e", "f"}, outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + assertEquals(t, []interface{}{"g"}, outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + var empty []interface{} + assertEquals(t, empty, outputValues) +} + +func TestThrottlerWithDiscard(t *testing.T) { + in := make(chan interface{}, 7) + out := make(chan interface{}, 7) + + interval := 20 * time.Millisecond + source := ext.NewChanSource(in) + throttler := flow.NewThrottler(2, interval, 1, flow.Discard) + sink := ext.NewChanSink(out) + + go writeValues(in) + + go func() { + source. + Via(throttler). + To(sink) + }() + + outputValues := readValues(interval/2, out) + assertEquals(t, []interface{}{"a", "b"}, outputValues) + fmt.Println(outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + + outputValues = readValues(interval, out) + fmt.Println(outputValues) + var empty []interface{} + assertEquals(t, empty, outputValues) +} + +func writeValues(in chan interface{}) { + inputValues := []string{"a", "b", "c", "d", "e", "f", "g"} + ingestSlice(inputValues, in) + close(in) + fmt.Println("Closed input channel") +} + +func readValues(timeout time.Duration, out <-chan interface{}) []interface{} { + var outputValues []interface{} + timer := time.NewTimer(timeout) + for { + select { + case e := <-out: + if e != nil { + outputValues = append(outputValues, e) + } else { + fmt.Println("Got nil in output") + timer.Stop() + return outputValues + } + case <-timer.C: + return outputValues + } + } +}