From a0cbf5592ee4437af255c75766ca07f67c8dbeee Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Wed, 23 Mar 2022 08:07:50 +0100 Subject: [PATCH] Fix pre-routing middlewares by using chi's management (#64) --- chirouter/wrapper.go | 21 +++--- chirouter/wrapper_test.go | 151 +++++++++++++++++++++++++++++++++++--- nethttp/wrap.go | 8 ++ web/service_test.go | 3 +- 4 files changed, 160 insertions(+), 23 deletions(-) diff --git a/chirouter/wrapper.go b/chirouter/wrapper.go index 8ba56ff..66194de 100644 --- a/chirouter/wrapper.go +++ b/chirouter/wrapper.go @@ -39,13 +39,14 @@ func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper { // Use appends one of more middlewares onto the Router stack. func (r *Wrapper) Use(middlewares ...func(http.Handler) http.Handler) { + r.Router.Use(middlewares...) r.middlewares = append(r.middlewares, middlewares...) } // With adds inline middlewares for an endpoint handler. func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router { - c := r.copy(r.Router, "") - c.Use(middlewares...) + c := r.copy(r.Router.With(middlewares...), "") + c.middlewares = append(c.middlewares, middlewares...) return c } @@ -76,18 +77,20 @@ func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router { // Mount attaches another http.Handler along "./basePattern/*". func (r *Wrapper) Mount(pattern string, h http.Handler) { - p := r.prepareHandler("", pattern, h) - r.Router.Mount(pattern, p) + r.captureHandler("", pattern, h) + r.Router.Mount(pattern, h) } // Handle adds routes for `basePattern` that matches all HTTP methods. func (r *Wrapper) Handle(pattern string, h http.Handler) { - r.Router.Handle(pattern, r.prepareHandler("", pattern, h)) + r.captureHandler("", pattern, h) + r.Router.Handle(pattern, h) } // Method adds routes for `basePattern` that matches the `method` HTTP method. func (r *Wrapper) Method(method, pattern string, h http.Handler) { - r.Router.Method(method, pattern, r.prepareHandler(method, pattern, h)) + r.captureHandler(method, pattern, h) + r.Router.Method(method, pattern, h) } // MethodFunc adds the route `pattern` that matches `method` http method to execute the `handlerFn` http.HandlerFunc. @@ -144,10 +147,8 @@ func (r *Wrapper) resolvePattern(pattern string) string { return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/") } -func (r *Wrapper) prepareHandler(method, pattern string, h http.Handler) http.Handler { +func (r *Wrapper) captureHandler(method, pattern string, h http.Handler) { mw := r.middlewares mw = append(mw, nethttp.HandlerWithRouteMiddleware(method, r.resolvePattern(pattern))) - h = nethttp.WrapHandler(h, mw...) - - return h + nethttp.WrapHandler(h, mw...) } diff --git a/chirouter/wrapper_test.go b/chirouter/wrapper_test.go index c2867fb..e4f02a7 100644 --- a/chirouter/wrapper_test.go +++ b/chirouter/wrapper_test.go @@ -7,8 +7,10 @@ import ( "testing" "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/rest" "github.com/swaggest/rest/chirouter" "github.com/swaggest/rest/nethttp" ) @@ -23,28 +25,47 @@ type HandlerWithBar struct { http.Handler } +func (h HandlerWithFoo) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if _, err := rw.Write([]byte("foo")); err != nil { + panic(err) + } + + h.Handler.ServeHTTP(rw, r) +} + func (h HandlerWithBar) Bar() {} func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + h.Handler.ServeHTTP(rw, r) + if _, err := rw.Write([]byte("bar")); err != nil { panic(err) } - - h.Handler.ServeHTTP(rw, r) } func TestNewWrapper(t *testing.T) { - var r chi.Router - r = chi.NewRouter() - - r = chirouter.NewWrapper(r).With(func(handler http.Handler) http.Handler { + r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler { return http.HandlerFunc(handler.ServeHTTP) }) + handlersCnt := 0 + totalCnt := 0 + mw := func(handler http.Handler) http.Handler { - var bar interface{ Bar() } + var ( + withRoute rest.HandlerWithRoute + bar interface{ Bar() } + foo interface{ Foo() } + ) - assert.False(t, nethttp.HandlerAs(handler, &bar)) + totalCnt++ + + if nethttp.HandlerAs(handler, &withRoute) { + handlersCnt++ + + assert.False(t, nethttp.HandlerAs(handler, &bar), "%s", handler) + assert.True(t, nethttp.HandlerAs(handler, &foo), "%s", handler) + } return HandlerWithBar{Handler: handler} } @@ -62,11 +83,13 @@ func TestNewWrapper(t *testing.T) { ) }) - r.Mount("/mount", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + r.Mount("/mount", + HandlerWithFoo{Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})}, + ) - r.Route("/deeper/", func(r chi.Router) { + r.Route("/deeper", func(r chi.Router) { r.Use(func(handler http.Handler) http.Handler { - return handler + return HandlerWithFoo{Handler: handler} }) r.Get("/foo", func(writer http.ResponseWriter, request *http.Request) {}) @@ -91,6 +114,110 @@ func TestNewWrapper(t *testing.T) { rw := httptest.NewRecorder() r.ServeHTTP(rw, req) - assert.Equal(t, "bar", rw.Body.String(), u) + assert.Equal(t, "foobar", rw.Body.String(), u) } + + assert.Equal(t, 13, handlersCnt) + assert.Equal(t, 20, totalCnt) +} + +func TestWrapper_Use_precedence(t *testing.T) { + var log []string + + // Vanilla chi router. + cr := chi.NewRouter() + cr.Use( + func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + log = append(log, "cmw1 before") + handler.ServeHTTP(writer, request) + log = append(log, "cmw1 after") + }) + }, + + func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + log = append(log, "cmw2 before") + handler.ServeHTTP(writer, request) + log = append(log, "cmw2 after") + }) + }, + ) + + // Wrapped chi router. + wr := chirouter.NewWrapper(chi.NewRouter()) + wr.Use( + func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + log = append(log, "wmw1 before") + handler.ServeHTTP(writer, request) + log = append(log, "wmw1 after") + }) + }, + + func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + log = append(log, "wmw2 before") + handler.ServeHTTP(writer, request) + log = append(log, "wmw2 after") + }) + }, + ) + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + log = append(log, "h") + }) + + // Both routers should invoke middlewares in the same order. + cr.Method(http.MethodGet, "/", h) + wr.Method(http.MethodGet, "/", h) + + cr.ServeHTTP(nil, req) + wr.ServeHTTP(nil, req) + assert.Equal(t, []string{ + "cmw1 before", "cmw2 before", "h", "cmw2 after", "cmw1 after", + "wmw1 before", "wmw2 before", "h", "wmw2 after", "wmw1 after", + }, log) +} + +func TestWrapper_Use_StripSlashes(t *testing.T) { + var log []string + + // Wrapped chi router. + wr := chirouter.NewWrapper(chi.NewRouter()) + wr.Use( + middleware.StripSlashes, + + func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + handler.ServeHTTP(writer, request) + }) + }, + ) + + req, err := http.NewRequest(http.MethodGet, "/foo/", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + + h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if _, err := writer.Write([]byte("OK")); err != nil { + log = append(log, err.Error()) + } + + log = append(log, "h") + }) + + wr.Method(http.MethodGet, "/foo", h) + wr.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "OK", rw.Body.String()) + + assert.Equal(t, []string{ + "h", + }, log) } diff --git a/nethttp/wrap.go b/nethttp/wrap.go index 082c0d0..9772260 100644 --- a/nethttp/wrap.go +++ b/nethttp/wrap.go @@ -88,3 +88,11 @@ type wrappedHandler struct { wrapped http.Handler mwName string } + +func (w *wrappedHandler) String() string { + if h, ok := w.wrapped.(*wrappedHandler); ok { + return w.mwName + "(" + h.String() + ")" + } + + return "handler" +} diff --git a/web/service_test.go b/web/service_test.go index 64d8456..a7c7abb 100644 --- a/web/service_test.go +++ b/web/service_test.go @@ -28,7 +28,8 @@ func TestDefaultService(t *testing.T) { service.Trace("/albums", postAlbums(), nethttp.SuccessStatus(http.StatusCreated)) service.Options("/albums", postAlbums(), nethttp.SuccessStatus(http.StatusCreated)) service.Docs("/docs", func(title, schemaURL, basePath string) http.Handler { - return nil + // Mount github.com/swaggest/swgui/v4emb.New here. + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) }) rw := httptest.NewRecorder()