From 14f73cdfe67120b7459ed9e944f5aa3c3490d18e Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Sat, 8 Jul 2023 16:07:29 +0800 Subject: [PATCH] fix --- modules/web/route.go | 18 +++++++-- modules/web/routemock.go | 61 ++++++++++++++++++++++++++++++ modules/web/routemock_test.go | 70 +++++++++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 modules/web/routemock.go create mode 100644 modules/web/routemock_test.go diff --git a/modules/web/route.go b/modules/web/route.go index 8685062a8e5f9..dc87e112ec60c 100644 --- a/modules/web/route.go +++ b/modules/web/route.go @@ -50,7 +50,9 @@ func NewRoute() *Route { // Use supports two middlewares func (r *Route) Use(middlewares ...any) { for _, m := range middlewares { - r.R.Use(toHandlerProvider(m)) + if m != nil { + r.R.Use(toHandlerProvider(m)) + } } } @@ -79,15 +81,23 @@ func (r *Route) getPattern(pattern string) string { } func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { - handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)) + handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)+1) for _, m := range r.curMiddlewares { - handlerProviders = append(handlerProviders, toHandlerProvider(m)) + if m != nil { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } } for _, m := range h { - handlerProviders = append(handlerProviders, toHandlerProvider(m)) + if h != nil { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } } middlewares := handlerProviders[:len(handlerProviders)-1] handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP + mockPoint := RouteMockPoint(MockAfterMiddlewares) + if mockPoint != nil { + middlewares = append(middlewares, mockPoint) + } return middlewares, handlerFunc } diff --git a/modules/web/routemock.go b/modules/web/routemock.go new file mode 100644 index 0000000000000..cb41f63b91ab8 --- /dev/null +++ b/modules/web/routemock.go @@ -0,0 +1,61 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "net/http" + + "code.gitea.io/gitea/modules/setting" +) + +// MockAfterMiddlewares is a general mock point, it's between middlewares and the handler +const MockAfterMiddlewares = "MockAfterMiddlewares" + +var routeMockPoints = map[string]func(next http.Handler) http.Handler{} + +// RouteMockPoint registers a mock point as a middleware for testing, example: +// +// r.Use(web.RouteMockPoint("my-mock-point-1")) +// r.Get("/foo", middleware2, web.RouteMockPoint("my-mock-point-2"), middleware2, handler) +// +// Then use web.RouteMock to mock the route execution. +// It only takes effect in testing mode (setting.IsInTesting == true). +func RouteMockPoint(pointName string) func(next http.Handler) http.Handler { + if !setting.IsInTesting { + return nil + } + routeMockPoints[pointName] = nil + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h := routeMockPoints[pointName]; h != nil { + h(next).ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r) + } + }) + } +} + +// RouteMock uses the registered mock point to mock the route execution, example: +// +// defer web.RouteMockReset() +// web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) { +// ctx.WriteResponse(...) +// } +// +// Then the mock function will be executed as a middleware at the mock point. +// It only takes effect in testing mode (setting.IsInTesting == true). +func RouteMock(pointName string, h any) { + if _, ok := routeMockPoints[pointName]; !ok { + panic("route mock point not found: " + pointName) + } + routeMockPoints[pointName] = toHandlerProvider(h) +} + +// RouteMockReset resets all mock points (no mock anymore) +func RouteMockReset() { + for k := range routeMockPoints { + routeMockPoints[k] = nil // keep the keys because RouteMock will check the keys to make sure no misspelling + } +} diff --git a/modules/web/routemock_test.go b/modules/web/routemock_test.go new file mode 100644 index 0000000000000..04c6d1d82e576 --- /dev/null +++ b/modules/web/routemock_test.go @@ -0,0 +1,70 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "net/http" + "net/http/httptest" + "testing" + + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" +) + +func TestRouteMock(t *testing.T) { + setting.IsInTesting = true + + r := NewRoute() + middleware1 := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Middleware1", "m1") + } + middleware2 := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Middleware2", "m2") + } + handler := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Handler", "h") + } + r.Get("/foo", middleware1, RouteMockPoint("mock-point"), middleware2, handler) + + // normal request + recorder := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://localhost:8000/foo", nil) + assert.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 3) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) + assert.EqualValues(t, "h", recorder.Header().Get("X-Test-Handler")) + RouteMockReset() + + // mock at "mock-point" + RouteMock("mock-point", func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-MockPoint", "a") + resp.WriteHeader(http.StatusOK) + }) + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) + assert.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 2) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "a", recorder.Header().Get("X-Test-MockPoint")) + RouteMockReset() + + // mock at MockAfterMiddlewares + RouteMock(MockAfterMiddlewares, func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-MockPoint", "b") + resp.WriteHeader(http.StatusOK) + }) + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) + assert.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 3) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) + assert.EqualValues(t, "b", recorder.Header().Get("X-Test-MockPoint")) + RouteMockReset() +}