Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(kitmw): limit maximum concurrency #67

Merged
merged 3 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions kitmw/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ import (

// MakeAsyncMiddleware returns a go kit middleware that calls the next handler in
// a detached goroutine. Timeout and cancellation of the previous context no
// logger apply to the detached goroutine, the tracing context however is
// carried over.
func MakeAsyncMiddleware(logger log.Logger) endpoint.Middleware {
// logger apply to the detached goroutine, the tracing context however is carried
// over. A concurrency limit can be passed into the middleware. If the limit is
// reached, next endpoint call will block until the level of concurrency is below
// the limit.
func MakeAsyncMiddleware(logger log.Logger, concurrency int) endpoint.Middleware {
limit := make(chan struct{}, concurrency)
for i := 0; i < concurrency; i++ {
limit <- struct{}{}
}
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
span := opentracing.SpanFromContext(ctx)
<-limit
go func() {
defer func() {
limit <- struct{}{}
}()
var err error
ctx := opentracing.ContextWithSpan(context.Background(), span)
_, err = next(ctx, request)
if err != nil {
Expand Down
53 changes: 53 additions & 0 deletions kitmw/async_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package kitmw

import (
"context"
"math/rand"
"testing"
"time"

"github.com/go-kit/kit/log"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)

func TestMakeAsyncMiddleware(t *testing.T) {
var c atomic.Int32
m := MakeAsyncMiddleware(log.NewNopLogger(), 5)
f := m(func(ctx context.Context, request interface{}) (response interface{}, err error) {
c.Inc()
assert.Less(t, int(c.Load()), 5)
time.Sleep(time.Duration(rand.Float64()) * time.Second)
c.Dec()
return nil, nil
})

for i := 0; i < 10; i++ {
t.Run("", func(t *testing.T) {
t.Parallel()
f(context.Background(), nil)
})
}
}

func TestMakeAsyncMiddleware_tracing(t *testing.T) {
tracer := mocktracer.New()
span, ctx := opentracing.StartSpanFromContextWithTracer(context.Background(), tracer, "foo")
var done = make(chan struct{})

m := MakeAsyncMiddleware(log.NewNopLogger(), 5)
f := m(func(ctx context.Context, request interface{}) (response interface{}, err error) {
span := opentracing.SpanFromContext(ctx)
span.SetBaggageItem("foo", "bar")
done <- struct{}{}
return nil, nil
})

f(ctx, nil)
<-done
span.Finish()
assert.Equal(t, "bar", tracer.FinishedSpans()[0].BaggageItem("foo"))

}