diff --git a/kitmw/async.go b/kitmw/async.go index 95db24e8..9ac625ec 100644 --- a/kitmw/async.go +++ b/kitmw/async.go @@ -11,13 +11,24 @@ import ( // MakeAsyncMiddleware returns a go kit middleware that calls the next handler in // a detached goroutine. Timeout and cancellation of the previous context no -// logger apply to the detached goroutine, the tracing context however is -// carried over. -func MakeAsyncMiddleware(logger log.Logger) endpoint.Middleware { +// logger apply to the detached goroutine, the tracing context however is carried +// over. A concurrency limit can be passed into the middleware. If the limit is +// reached, next endpoint call will block until the level of concurrency is below +// the limit. +func MakeAsyncMiddleware(logger log.Logger, concurrency int) endpoint.Middleware { + limit := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + limit <- struct{}{} + } return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { span := opentracing.SpanFromContext(ctx) + <-limit go func() { + defer func() { + limit <- struct{}{} + }() + var err error ctx := opentracing.ContextWithSpan(context.Background(), span) _, err = next(ctx, request) if err != nil { diff --git a/kitmw/async_test.go b/kitmw/async_test.go new file mode 100644 index 00000000..de244e08 --- /dev/null +++ b/kitmw/async_test.go @@ -0,0 +1,53 @@ +package kitmw + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/go-kit/kit/log" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestMakeAsyncMiddleware(t *testing.T) { + var c atomic.Int32 + m := MakeAsyncMiddleware(log.NewNopLogger(), 5) + f := m(func(ctx context.Context, request interface{}) (response interface{}, err error) { + c.Inc() + assert.Less(t, int(c.Load()), 5) + time.Sleep(time.Duration(rand.Float64()) * time.Second) + c.Dec() + return nil, nil + }) + + for i := 0; i < 10; i++ { + t.Run("", func(t *testing.T) { + t.Parallel() + f(context.Background(), nil) + }) + } +} + +func TestMakeAsyncMiddleware_tracing(t *testing.T) { + tracer := mocktracer.New() + span, ctx := opentracing.StartSpanFromContextWithTracer(context.Background(), tracer, "foo") + var done = make(chan struct{}) + + m := MakeAsyncMiddleware(log.NewNopLogger(), 5) + f := m(func(ctx context.Context, request interface{}) (response interface{}, err error) { + span := opentracing.SpanFromContext(ctx) + span.SetBaggageItem("foo", "bar") + done <- struct{}{} + return nil, nil + }) + + f(ctx, nil) + <-done + span.Finish() + assert.Equal(t, "bar", tracer.FinishedSpans()[0].BaggageItem("foo")) + +}