diff --git a/middleware.go b/middleware.go index cb51c565..a79815ca 100644 --- a/middleware.go +++ b/middleware.go @@ -32,6 +32,22 @@ func (r *Router) useInterface(mw middleware) { r.middlewares = append(r.middlewares, mw) } +// RouteMiddleware ------------------------------------------------------------- + +// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Route. Route middleware are executed after the Router middleware but before the Route handler. +func (r *Route) Use(mwf ...MiddlewareFunc) *Route { + for _, fn := range mwf { + r.middlewares = append(r.middlewares, fn) + } + + return r +} + +// useInterface appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Route. Route middleware are executed after the Router middleware but before the Route handler. +func (r *Route) useInterface(mw middleware) { + r.middlewares = append(r.middlewares, mw) +} + // CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header // on requests for routes that have an OPTIONS method matcher to all the method matchers on // the route. Routes that do not explicitly handle OPTIONS requests will not be processed diff --git a/middleware_test.go b/middleware_test.go index 4963b66f..8835b56c 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -2,6 +2,7 @@ package mux import ( "bytes" + "fmt" "net/http" "testing" ) @@ -42,6 +43,17 @@ func TestMiddlewareAdd(t *testing.T) { if len(router.middlewares) != 3 { t.Fatal("Middleware function was not added correctly") } + + route := router.HandleFunc("/route", dummyHandler) + route.useInterface(mw) + if len(route.middlewares) != 1 { + t.Fatal("Route middleware function was not added correctly") + } + + route.Use(banalMw) + if len(route.middlewares) != 2 { + t.Fatal("Route middleware function was not added correctly") + } } func TestMiddleware(t *testing.T) { @@ -85,6 +97,24 @@ func TestMiddleware(t *testing.T) { t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled) } }) + + t.Run("regular call using route middleware func", func(t *testing.T) { + router.HandleFunc("/route", dummyHandler).Use(mw.Middleware) + req = newRequest("GET", "/route") + router.ServeHTTP(rw, req) + if mw.timesCalled != 6 { + t.Fatalf("Expected %d calls, but got only %d", 6, mw.timesCalled) + } + }) + + t.Run("regular call using route middleware interface", func(t *testing.T) { + router.HandleFunc("/route", dummyHandler).useInterface(mw) + req = newRequest("GET", "/route") + router.ServeHTTP(rw, req) + if mw.timesCalled != 9 { + t.Fatalf("Expected %d calls, but got only %d", 9, mw.timesCalled) + } + }) } func TestMiddlewareSubrouter(t *testing.T) { @@ -156,13 +186,15 @@ func TestMiddlewareExecution(t *testing.T) { mwStr := []byte("Middleware\n") handlerStr := []byte("Logic\n") - router := NewRouter() - router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { + handlerFunc := func(w http.ResponseWriter, e *http.Request) { _, err := w.Write(handlerStr) if err != nil { t.Fatalf("Failed writing HTTP response: %v", err) } - }) + } + + router := NewRouter() + router.HandleFunc("/", handlerFunc) t.Run("responds normally without middleware", func(t *testing.T) { rw := NewRecorder() @@ -194,6 +226,29 @@ func TestMiddlewareExecution(t *testing.T) { t.Fatal("Middleware + handler response is not what it should be") } }) + + t.Run("responds with handler, middleware and route middleware response", func(t *testing.T) { + routeMwStr := []byte("Route Middleware\n") + rw := NewRecorder() + req := newRequest("GET", "/route") + + router.HandleFunc("/route", handlerFunc).Use(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write(routeMwStr) + if err != nil { + t.Fatalf("Failed writing HTTP response: %v", err) + } + h.ServeHTTP(w, r) + }) + }) + + router.ServeHTTP(rw, req) + expectedString := append(append(mwStr, routeMwStr...), handlerStr...) + if !bytes.Equal(rw.Body.Bytes(), expectedString) { + fmt.Println(rw.Body.String()) + t.Fatal("Middleware + handler response is not what it should be") + } + }) } func TestMiddlewareNotFound(t *testing.T) { diff --git a/route.go b/route.go index b6582dae..3480419a 100644 --- a/route.go +++ b/route.go @@ -27,6 +27,9 @@ type Route struct { // "global" reference to all named routes namedRoutes map[string]*Route + // route specific middleware + middlewares []middleware + // config possibly passed in from `Router` routeConf } @@ -99,7 +102,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { match.Route = r } if match.Handler == nil { - match.Handler = r.handler + match.Handler = r.GetHandlerWithMiddlewares() } // Set variables. @@ -142,6 +145,20 @@ func (r *Route) GetHandler() http.Handler { return r.handler } +// GetHandlerWithMiddleware returns the route handler wrapped in the assigned middlewares. +// If no middlewares are specified, just the handler, if any, is returned. +func (r *Route) GetHandlerWithMiddlewares() http.Handler { + handler := r.handler + + if handler != nil && len(r.middlewares) > 0 { + for i := len(r.middlewares) - 1; i >= 0; i-- { + handler = r.middlewares[i].Middleware(handler) + } + } + + return handler +} + // Name ----------------------------------------------------------------------- // Name sets the name for the route, used to build URLs.