diff --git a/middleware/middleware.go b/middleware/middleware.go index 0bce584b0..d69a41c32 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -42,9 +42,10 @@ func ChainMiddlewares(m ...Middleware) Middleware { return next(req) } } + tail := ChainMiddlewares(m[1:]...) return func(req Request, next Next) (Response, error) { return m[0](req, func(req Request) (Response, error) { - return ChainMiddlewares(m[1:]...)(req, next) + return tail(req, next) }) } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 6d1793388..e4dee0dda 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -36,17 +36,66 @@ func TestChainMiddlewares(t *testing.T) { ) a := require.New(t) - req := Request{ - Context: context.Background(), - Body: []string{}, - Params: map[string]any{}, + for i := range [2]struct{}{} { + req := Request{ + Context: context.Background(), + Body: []string{}, + Params: map[string]any{ + "call": i, + }, + } + resp, err := chain(req, func(req Request) (Response, error) { + a.Equal([]string{"first", "second", "third"}, req.Body) + a.Equal("bar", req.Params["second"]) + a.Equal("baz", req.Context.Value(testKey{})) + a.Equal(i, req.Params["call"]) + return Response{Type: "ok"}, nil + }) + a.NoError(err) + a.Equal("ok", resp.Type) + } +} + +func BenchmarkChainMiddlewares(b *testing.B) { + const N = 20 + noop := func(req Request, next Next) (Response, error) { + return next(req) + } + + var ( + chain = ChainMiddlewares(func() (r []Middleware) { + for i := 0; i < N; i++ { + r = append(r, noop) + } + return r + }()...) + req = Request{ + Context: context.Background(), + Body: []string{}, + Params: map[string]any{}, + } + resp = Response{Type: "ok"} + next = func(req Request) (Response, error) { + return resp, nil + } + ) + + b.ReportAllocs() + b.ResetTimer() + + var ( + sinkResp Response + sinkErr error + ) + + for i := 0; i < b.N; i++ { + sinkResp, sinkErr = chain(req, next) + } + + if sinkErr != nil { + b.Fatal(sinkErr) + } + if sinkResp != resp { + b.Fatalf("Expected %v, got %v", resp, sinkResp) } - resp, err := chain(req, func(req Request) (Response, error) { - a.Equal([]string{"first", "second", "third"}, req.Body) - a.Equal("bar", req.Params["second"]) - a.Equal("baz", req.Context.Value(testKey{})) - return Response{Type: "ok"}, nil - }) - a.NoError(err) - a.Equal("ok", resp.Type) }