Skip to content

Commit

Permalink
Fix pre-routing middlewares by using chi's management (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
vearutop committed Mar 23, 2022
1 parent 23b1a36 commit a0cbf55
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 23 deletions.
21 changes: 11 additions & 10 deletions chirouter/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper {

// Use appends one of more middlewares onto the Router stack.
func (r *Wrapper) Use(middlewares ...func(http.Handler) http.Handler) {
r.Router.Use(middlewares...)
r.middlewares = append(r.middlewares, middlewares...)
}

// With adds inline middlewares for an endpoint handler.
func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router {
c := r.copy(r.Router, "")
c.Use(middlewares...)
c := r.copy(r.Router.With(middlewares...), "")
c.middlewares = append(c.middlewares, middlewares...)

return c
}
Expand Down Expand Up @@ -76,18 +77,20 @@ func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router {

// Mount attaches another http.Handler along "./basePattern/*".
func (r *Wrapper) Mount(pattern string, h http.Handler) {
p := r.prepareHandler("", pattern, h)
r.Router.Mount(pattern, p)
r.captureHandler("", pattern, h)
r.Router.Mount(pattern, h)
}

// Handle adds routes for `basePattern` that matches all HTTP methods.
func (r *Wrapper) Handle(pattern string, h http.Handler) {
r.Router.Handle(pattern, r.prepareHandler("", pattern, h))
r.captureHandler("", pattern, h)
r.Router.Handle(pattern, h)
}

// Method adds routes for `basePattern` that matches the `method` HTTP method.
func (r *Wrapper) Method(method, pattern string, h http.Handler) {
r.Router.Method(method, pattern, r.prepareHandler(method, pattern, h))
r.captureHandler(method, pattern, h)
r.Router.Method(method, pattern, h)
}

// MethodFunc adds the route `pattern` that matches `method` http method to execute the `handlerFn` http.HandlerFunc.
Expand Down Expand Up @@ -144,10 +147,8 @@ func (r *Wrapper) resolvePattern(pattern string) string {
return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/")
}

func (r *Wrapper) prepareHandler(method, pattern string, h http.Handler) http.Handler {
func (r *Wrapper) captureHandler(method, pattern string, h http.Handler) {
mw := r.middlewares
mw = append(mw, nethttp.HandlerWithRouteMiddleware(method, r.resolvePattern(pattern)))
h = nethttp.WrapHandler(h, mw...)

return h
nethttp.WrapHandler(h, mw...)
}
151 changes: 139 additions & 12 deletions chirouter/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"testing"

"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/swaggest/rest"
"github.com/swaggest/rest/chirouter"
"github.com/swaggest/rest/nethttp"
)
Expand All @@ -23,28 +25,47 @@ type HandlerWithBar struct {
http.Handler
}

func (h HandlerWithFoo) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte("foo")); err != nil {
panic(err)
}

h.Handler.ServeHTTP(rw, r)
}

func (h HandlerWithBar) Bar() {}

func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.Handler.ServeHTTP(rw, r)

if _, err := rw.Write([]byte("bar")); err != nil {
panic(err)
}

h.Handler.ServeHTTP(rw, r)
}

func TestNewWrapper(t *testing.T) {
var r chi.Router
r = chi.NewRouter()

r = chirouter.NewWrapper(r).With(func(handler http.Handler) http.Handler {
r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler {
return http.HandlerFunc(handler.ServeHTTP)
})

handlersCnt := 0
totalCnt := 0

mw := func(handler http.Handler) http.Handler {
var bar interface{ Bar() }
var (
withRoute rest.HandlerWithRoute
bar interface{ Bar() }
foo interface{ Foo() }
)

assert.False(t, nethttp.HandlerAs(handler, &bar))
totalCnt++

if nethttp.HandlerAs(handler, &withRoute) {
handlersCnt++

assert.False(t, nethttp.HandlerAs(handler, &bar), "%s", handler)
assert.True(t, nethttp.HandlerAs(handler, &foo), "%s", handler)
}

return HandlerWithBar{Handler: handler}
}
Expand All @@ -62,11 +83,13 @@ func TestNewWrapper(t *testing.T) {
)
})

r.Mount("/mount", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}))
r.Mount("/mount",
HandlerWithFoo{Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})},
)

r.Route("/deeper/", func(r chi.Router) {
r.Route("/deeper", func(r chi.Router) {
r.Use(func(handler http.Handler) http.Handler {
return handler
return HandlerWithFoo{Handler: handler}
})

r.Get("/foo", func(writer http.ResponseWriter, request *http.Request) {})
Expand All @@ -91,6 +114,110 @@ func TestNewWrapper(t *testing.T) {
rw := httptest.NewRecorder()
r.ServeHTTP(rw, req)

assert.Equal(t, "bar", rw.Body.String(), u)
assert.Equal(t, "foobar", rw.Body.String(), u)
}

assert.Equal(t, 13, handlersCnt)
assert.Equal(t, 20, totalCnt)
}

func TestWrapper_Use_precedence(t *testing.T) {
var log []string

// Vanilla chi router.
cr := chi.NewRouter()
cr.Use(
func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log = append(log, "cmw1 before")
handler.ServeHTTP(writer, request)
log = append(log, "cmw1 after")
})
},

func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log = append(log, "cmw2 before")
handler.ServeHTTP(writer, request)
log = append(log, "cmw2 after")
})
},
)

// Wrapped chi router.
wr := chirouter.NewWrapper(chi.NewRouter())
wr.Use(
func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log = append(log, "wmw1 before")
handler.ServeHTTP(writer, request)
log = append(log, "wmw1 after")
})
},

func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log = append(log, "wmw2 before")
handler.ServeHTTP(writer, request)
log = append(log, "wmw2 after")
})
},
)

req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)

h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
log = append(log, "h")
})

// Both routers should invoke middlewares in the same order.
cr.Method(http.MethodGet, "/", h)
wr.Method(http.MethodGet, "/", h)

cr.ServeHTTP(nil, req)
wr.ServeHTTP(nil, req)
assert.Equal(t, []string{
"cmw1 before", "cmw2 before", "h", "cmw2 after", "cmw1 after",
"wmw1 before", "wmw2 before", "h", "wmw2 after", "wmw1 after",
}, log)
}

func TestWrapper_Use_StripSlashes(t *testing.T) {
var log []string

// Wrapped chi router.
wr := chirouter.NewWrapper(chi.NewRouter())
wr.Use(
middleware.StripSlashes,

func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
handler.ServeHTTP(writer, request)
})
},
)

req, err := http.NewRequest(http.MethodGet, "/foo/", nil)
require.NoError(t, err)

rw := httptest.NewRecorder()

h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if _, err := writer.Write([]byte("OK")); err != nil {
log = append(log, err.Error())
}

log = append(log, "h")
})

wr.Method(http.MethodGet, "/foo", h)
wr.ServeHTTP(rw, req)

assert.Equal(t, http.StatusOK, rw.Code)
assert.Equal(t, "OK", rw.Body.String())

assert.Equal(t, []string{
"h",
}, log)
}
8 changes: 8 additions & 0 deletions nethttp/wrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ type wrappedHandler struct {
wrapped http.Handler
mwName string
}

func (w *wrappedHandler) String() string {
if h, ok := w.wrapped.(*wrappedHandler); ok {
return w.mwName + "(" + h.String() + ")"
}

return "handler"
}
3 changes: 2 additions & 1 deletion web/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func TestDefaultService(t *testing.T) {
service.Trace("/albums", postAlbums(), nethttp.SuccessStatus(http.StatusCreated))
service.Options("/albums", postAlbums(), nethttp.SuccessStatus(http.StatusCreated))
service.Docs("/docs", func(title, schemaURL, basePath string) http.Handler {
return nil
// Mount github.com/swaggest/swgui/v4emb.New here.
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})
})

rw := httptest.NewRecorder()
Expand Down

0 comments on commit a0cbf55

Please sign in to comment.