From df87a82d5a650ab63c8772419a05d976d59c1d76 Mon Sep 17 00:00:00 2001 From: RW Date: Wed, 17 May 2023 10:51:05 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20mount=20route=20positionin?= =?UTF-8?q?g=20(#2463)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 [Bug-fix]: Wrong handlers execution order in some mount cases #2460 * 🐛 [Bug-fix]: Wrong handlers execution order in some mount cases #2460 * 🐛 [Bug-fix]: Wrong handlers execution order in some mount cases #2460 * [Bug-fix]: Wrong handlers execution order in some mount cases #2460 * [Bug-fix]: Wrong handlers execution order in some mount cases #2460 --- mount.go | 9 ++-- mount_test.go | 124 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 101 insertions(+), 32 deletions(-) diff --git a/mount.go b/mount.go index 71a7e61dc1..abb5695e9f 100644 --- a/mount.go +++ b/mount.go @@ -174,27 +174,24 @@ func (app *App) processSubAppsRoutes() { } } var handlersCount uint32 + var routePos uint32 // Iterate over the stack of the parent app for m := range app.stack { - // Keep track of the position shift caused by adding routes for mounted apps - var positionShift uint32 // Iterate over each route in the stack stackLen := len(app.stack[m]) for i := 0; i < stackLen; i++ { route := app.stack[m][i] // Check if the route has a mounted app if !route.mount { + routePos++ // If not, update the route's position and continue - route.pos += positionShift + route.pos = routePos if !route.use || (route.use && m == 0) { handlersCount += uint32(len(route.Handlers)) } continue } - // Update the position shift to account for the mounted app's routes - positionShift = route.pos - // Create a slice to hold the sub-app's routes subRoutes := make([]*Route, len(route.group.app.stack[m])) diff --git a/mount_test.go b/mount_test.go index 8481cf54e5..2ae67fb142 100644 --- a/mount_test.go +++ b/mount_test.go @@ -8,6 +8,7 @@ package fiber import ( "errors" "io" + "net/http" "net/http/httptest" "testing" @@ -25,7 +26,7 @@ func Test_App_Mount(t *testing.T) { app := New() app.Mount("/john", micro) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/john/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/john/doe", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, uint32(2), app.handlersCount) @@ -45,7 +46,7 @@ func Test_App_Mount_RootPath_Nested(t *testing.T) { dynamic.Mount("/api", apiserver) app.Mount("/", dynamic) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/v1/home", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/v1/home", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, uint32(2), app.handlersCount) @@ -75,15 +76,15 @@ func Test_App_Mount_Nested(t *testing.T) { return c.SendStatus(StatusOK) }) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/one/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/one/doe", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/nested", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/nested", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/three/test", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/three/test", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") @@ -99,12 +100,13 @@ func Test_App_Mount_Express_Behavior(t *testing.T) { return c.SendString(body) } } - testEndpoint := func(app *App, route, expectedBody string) { - resp, err := app.Test(httptest.NewRequest(MethodGet, route, nil)) + testEndpoint := func(app *App, route, expectedBody string, expectedStatusCode int) { + resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") body, err := io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, expectedBody, string(body), "Response body") + utils.AssertEqual(t, expectedStatusCode, resp.StatusCode, "Status code") + utils.AssertEqual(t, expectedBody, string(body), "Unexpected response body") } app := New() @@ -130,16 +132,86 @@ func Test_App_Mount_Express_Behavior(t *testing.T) { }) } // expectation check - testEndpoint(app, "/world", "subapp world!") - testEndpoint(app, "/hello", "app hello!") - testEndpoint(app, "/bar", "subapp bar!") - testEndpoint(app, "/foo", "subapp foo!") - testEndpoint(app, "/unknown", ErrNotFound.Message) + testEndpoint(app, "/world", "subapp world!", StatusOK) + testEndpoint(app, "/hello", "app hello!", StatusOK) + testEndpoint(app, "/bar", "subapp bar!", StatusOK) + testEndpoint(app, "/foo", "subapp foo!", StatusOK) + testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound) utils.AssertEqual(t, uint32(17), app.handlersCount) utils.AssertEqual(t, uint32(16+9), app.routesCount) } +// go test -run Test_App_Mount_RoutePositions +func Test_App_Mount_RoutePositions(t *testing.T) { + t.Parallel() + testEndpoint := func(app *App, route, expectedBody string) { + resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") + utils.AssertEqual(t, expectedBody, string(body), "Unexpected response body") + } + + app := New() + subApp1 := New() + subApp2 := New() + // app setup + { + app.Use(func(c *Ctx) error { + // set initial value + c.Locals("world", "world") + return c.Next() + }) + app.Mount("/subApp1", subApp1) + app.Use(func(c *Ctx) error { + return c.Next() + }) + app.Get("/bar", func(c *Ctx) error { + return c.SendString("ok") + }) + app.Use(func(c *Ctx) error { + // is overwritten in case the positioning is not correct + c.Locals("world", "hello") + return c.Next() + }) + methods := subApp2.Group("/subApp2") + methods.Get("/world", func(c *Ctx) error { + v, ok := c.Locals("world").(string) + if !ok { + panic("unexpected data type") + } + return c.SendString(v) + }) + app.Mount("", subApp2) + } + + testEndpoint(app, "/subApp2/world", "hello") + + routeStackGET := app.Stack()[0] + utils.AssertEqual(t, true, routeStackGET[0].use) + utils.AssertEqual(t, "/", routeStackGET[0].path) + + utils.AssertEqual(t, true, routeStackGET[1].use) + utils.AssertEqual(t, "/", routeStackGET[1].path) + utils.AssertEqual(t, true, routeStackGET[0].pos < routeStackGET[1].pos, "wrong position of route 0") + + utils.AssertEqual(t, false, routeStackGET[2].use) + utils.AssertEqual(t, "/bar", routeStackGET[2].path) + utils.AssertEqual(t, true, routeStackGET[1].pos < routeStackGET[2].pos, "wrong position of route 1") + + utils.AssertEqual(t, true, routeStackGET[3].use) + utils.AssertEqual(t, "/", routeStackGET[3].path) + utils.AssertEqual(t, true, routeStackGET[2].pos < routeStackGET[3].pos, "wrong position of route 2") + + utils.AssertEqual(t, false, routeStackGET[4].use) + utils.AssertEqual(t, "/subapp2/world", routeStackGET[4].path) + utils.AssertEqual(t, true, routeStackGET[3].pos < routeStackGET[4].pos, "wrong position of route 3") + + utils.AssertEqual(t, 5, len(routeStackGET)) +} + // go test -run Test_App_MountPath func Test_App_MountPath(t *testing.T) { t.Parallel() @@ -174,7 +246,7 @@ func Test_App_ErrorHandler_GroupMount(t *testing.T) { v1 := app.Group("/v1") v1.Mount("/john", micro) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) testErrorResponse(t, err, resp, "1: custom error") } @@ -194,7 +266,7 @@ func Test_App_ErrorHandler_GroupMountRootLevel(t *testing.T) { v1 := app.Group("/v1") v1.Mount("/", micro) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) testErrorResponse(t, err, resp, "1: custom error") } @@ -210,7 +282,7 @@ func Test_App_Group_Mount(t *testing.T) { v1 := app.Group("/v1") v1.Mount("/john", micro) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") utils.AssertEqual(t, uint32(2), app.handlersCount) @@ -231,7 +303,7 @@ func Test_App_UseParentErrorHandler(t *testing.T) { app.Mount("/api", fiber) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } @@ -250,7 +322,7 @@ func Test_App_UseMountedErrorHandler(t *testing.T) { app.Mount("/api", fiber) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } @@ -269,7 +341,7 @@ func Test_App_UseMountedErrorHandlerRootLevel(t *testing.T) { app.Mount("/", fiber) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } @@ -311,7 +383,7 @@ func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) { app.Mount("/api", fiber) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", http.NoBody)) utils.AssertEqual(t, nil, err, "/api/sub req") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") @@ -319,7 +391,7 @@ func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) { utils.AssertEqual(t, nil, err, "iotuil.ReadAll()") utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body") - resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil)) + resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", http.NoBody)) utils.AssertEqual(t, nil, err, "/api/sub/third req") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") @@ -345,7 +417,7 @@ func Test_Ctx_Render_Mount(t *testing.T) { app := New() app.Mount("/hello", sub) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/a", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/a", http.NoBody)) utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -397,7 +469,7 @@ func Test_Ctx_Render_Mount_ParentOrSubHasViews(t *testing.T) { sub.Mount("/bruh", sub2) app.Mount("/hello", sub) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/world/a", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/world/a", http.NoBody)) utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -405,7 +477,7 @@ func Test_Ctx_Render_Mount_ParentOrSubHasViews(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "

Hello a!

", string(body)) - resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody)) utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -413,7 +485,7 @@ func Test_Ctx_Render_Mount_ParentOrSubHasViews(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "

Hello, World!

", string(body)) - resp, err = app.Test(httptest.NewRequest(MethodGet, "/hello/bruh/moment", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/hello/bruh/moment", http.NoBody)) utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") utils.AssertEqual(t, nil, err, "app.Test(req)") @@ -439,7 +511,7 @@ func Test_Ctx_Render_MountGroup(t *testing.T) { v1 := app.Group("/v1") v1.Mount("/john", micro) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code")