diff --git a/fox.go b/fox.go index c82b525..afa4312 100644 --- a/fox.go +++ b/fox.go @@ -92,6 +92,7 @@ func New(opts ...Option) *Router { func (fox *Router) NewTree() *Tree { tree := new(Tree) tree.mws = fox.mws + tree.ignoreTrailingSlash = fox.ignoreTrailingSlash // Pre instantiate nodes for common http verb nds := make([]*node, len(commonVerbs)) @@ -315,17 +316,17 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr { - if fox.redirectTrailingSlash && target == CleanPath(target) { - // Reset params as it may have recorded wildcard segment (the context may still be used in a middleware) - *c.params = (*c.params)[:0] - fox.tsrRedirect(c) + if fox.ignoreTrailingSlash { + c.path = n.path + n.handler(c) c.Close() return } - if fox.ignoreTrailingSlash { - c.path = n.path - n.handler(c) + if fox.redirectTrailingSlash && target == CleanPath(target) { + // Reset params as it may have recorded wildcard segment (the context may still be used in a middleware) + *c.params = (*c.params)[:0] + fox.tsrRedirect(c) c.Close() return } diff --git a/fox_test.go b/fox_test.go index ffda22e..8d8c6fc 100644 --- a/fox_test.go +++ b/fox_test.go @@ -1539,34 +1539,39 @@ func TestParseRoute(t *testing.T) { func TestTree_LookupTsr(t *testing.T) { cases := []struct { - name string - paths []string - key string - want bool + name string + paths []string + key string + want bool + wantPath string }{ { - name: "match mid edge", - paths: []string{"/foo/bar/"}, - key: "/foo/bar", - want: true, + name: "match mid edge", + paths: []string{"/foo/bar/"}, + key: "/foo/bar", + want: true, + wantPath: "/foo/bar/", }, { - name: "incomplete match end of edge", - paths: []string{"/foo/bar"}, - key: "/foo/bar/", - want: true, + name: "incomplete match end of edge", + paths: []string{"/foo/bar"}, + key: "/foo/bar/", + want: true, + wantPath: "/foo/bar", }, { - name: "match mid edge with child node", - paths: []string{"/users/", "/users/{id}"}, - key: "/users", - want: true, + name: "match mid edge with child node", + paths: []string{"/users/", "/users/{id}"}, + key: "/users", + want: true, + wantPath: "/users/", }, { - name: "match mid edge in child node", - paths: []string{"/users", "/users/{id}"}, - key: "/users/", - want: true, + name: "match mid edge in child node", + paths: []string{"/users", "/users/{id}"}, + key: "/users/", + want: true, + wantPath: "/users", }, { name: "match mid edge in child node with invalid remaining prefix", @@ -1611,7 +1616,128 @@ func TestTree_LookupTsr(t *testing.T) { n, got := tree.lookup(nds[0], tc.key, c.params, c.skipNds, true) assert.Equal(t, tc.want, got) if tc.want { - assert.NotNil(t, n) + require.NotNil(t, n) + assert.Equal(t, tc.wantPath, n.path) + } + }) + } +} + +func TestRouterWithIgnoreTrailingSlash(t *testing.T) { + cases := []struct { + name string + paths []string + req string + method string + wantCode int + wantPath string + }{ + { + name: "current not a leaf with extra ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo", + }, + { + name: "current not a leaf and path does not end with ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with extra char and ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with ts but last is not a leaf", + paths: []string{"/foo/a/a", "/foo/a/b", "/foo/c/"}, + req: "/foo/a/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "mid edge key with extra ts", + paths: []string{"/foo/bar/"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar/", + }, + { + name: "mid edge key with without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, + req: "/foo/bar/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "mid edge key without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, + req: "/foo/bar/", + method: http.MethodPost, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "incomplete match end of edge", + paths: []string{"/foo/bar"}, + req: "/foo/bar/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "match mid edge with ts and more char after", + paths: []string{"/foo/bar/buzz"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "match mid edge with ts and more char before", + paths: []string{"/foo/barr/"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "incomplete match end of edge with ts and more char after", + paths: []string{"/foo/bar"}, + req: "/foo/bar/buzz", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "incomplete match end of edge with ts and more char before", + paths: []string{"/foo/bar"}, + req: "/foo/barr/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := New(WithIgnoreTrailingSlash(true)) + for _, path := range tc.paths { + require.NoError(t, r.Tree().Handle(tc.method, path, func(c Context) { + _ = c.String(http.StatusOK, c.Path()) + })) + } + + req := httptest.NewRequest(tc.method, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + if tc.wantPath != "" { + assert.Equal(t, tc.wantPath, w.Body.String()) } }) } @@ -1815,6 +1941,27 @@ func TestTree_Methods(t *testing.T) { methods = f.Tree().Methods("*") assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods) + + // Ignore trailing slash disable + methods = f.Tree().Methods("/gists/123/star/") + assert.Empty(t, methods) +} + +func TestTree_MethodsWithIgnoreTsEnable(t *testing.T) { + f := New(WithIgnoreTrailingSlash(true)) + for _, method := range []string{"DELETE", "GET", "PUT"} { + require.NoError(t, f.Handle(method, "/foo/bar", emptyHandler)) + require.NoError(t, f.Handle(method, "/john/doe/", emptyHandler)) + } + + methods := f.Tree().Methods("/foo/bar/") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("/john/doe") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("/foo/bar/baz") + assert.Empty(t, methods) } func TestRouterWithAllowedMethod(t *testing.T) { @@ -1864,6 +2011,91 @@ func TestRouterWithAllowedMethod(t *testing.T) { } } +func TestRouterWithAllowedMethodAndIgnoreTsEnable(t *testing.T) { + r := New(WithNoMethod(true), WithIgnoreTrailingSlash(true)) + + // Support for ignore Trailing slash + cases := []struct { + name string + target string + path string + req string + want string + methods []string + }{ + { + name: "all route except the last one", + methods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead}, + path: "/foo/bar/", + req: "/foo/bar", + target: http.MethodTrace, + want: "GET, POST, PUT, DELETE, PATCH, CONNECT, OPTIONS, HEAD", + }, + { + name: "all route except the first one", + methods: []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead, http.MethodTrace}, + path: "/foo/baz", + req: "/foo/baz/", + target: http.MethodGet, + want: "POST, PUT, DELETE, PATCH, CONNECT, OPTIONS, HEAD, TRACE", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) + } + req := httptest.NewRequest(tc.target, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + assert.Equal(t, tc.want, w.Header().Get("Allow")) + }) + } +} + +func TestRouterWithAllowedMethodAndIgnoreTsDisable(t *testing.T) { + r := New(WithNoMethod(true)) + + // Support for ignore Trailing slash + cases := []struct { + name string + target string + path string + req string + want int + methods []string + }{ + { + name: "all route except the last one", + methods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead}, + path: "/foo/bar/", + req: "/foo/bar", + target: http.MethodTrace, + }, + { + name: "all route except the first one", + methods: []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead, http.MethodTrace}, + path: "/foo/baz", + req: "/foo/baz/", + target: http.MethodGet, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) + } + req := httptest.NewRequest(tc.target, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + }) + } +} + func TestRouterWithMethodNotAllowedHandler(t *testing.T) { f := New(WithNoMethodHandler(func(c Context) { c.SetHeader("FOO", "BAR") @@ -1955,6 +2187,126 @@ func TestRouterWithAutomaticOptions(t *testing.T) { } } +func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { + f := New(WithAutoOptions(true), WithIgnoreTrailingSlash(true)) + + cases := []struct { + name string + target string + path string + want string + wantCode int + methods []string + }{ + { + name: "system-wide requests", + target: "*", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "system-wide with custom options registered", + target: "*", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "system-wide requests with empty router", + target: "*", + wantCode: http.StatusNotFound, + }, + { + name: "regular option request and ignore ts", + target: "/foo/", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "regular option request with handler priority and ignore ts", + target: "/foo", + path: "/foo/", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + want: "GET, OPTIONS, PUT, TRACE", + wantCode: http.StatusNoContent, + }, + { + name: "regular option request with no matching route", + target: "/bar", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + wantCode: http.StatusNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) + c.Writer().WriteHeader(http.StatusNoContent) + })) + } + req := httptest.NewRequest(http.MethodOptions, tc.target, nil) + w := httptest.NewRecorder() + f.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + assert.Equal(t, tc.want, w.Header().Get("Allow")) + // Reset + f.Swap(f.NewTree()) + }) + } +} + +func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { + f := New(WithAutoOptions(true)) + + cases := []struct { + name string + target string + path string + wantCode int + methods []string + }{ + { + name: "regular option request and ignore ts", + target: "/foo/", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + wantCode: http.StatusNotFound, + }, + { + name: "regular option request with handler priority and ignore ts", + target: "/foo", + path: "/foo/", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + wantCode: http.StatusNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) + c.Writer().WriteHeader(http.StatusNoContent) + })) + } + req := httptest.NewRequest(http.MethodOptions, tc.target, nil) + w := httptest.NewRecorder() + f.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + // Reset + f.Swap(f.NewTree()) + }) + } +} + func TestRouterWithOptionsHandler(t *testing.T) { f := New(WithOptionsHandler(func(c Context) { assert.Equal(t, "/foo/bar", c.Path()) @@ -2084,7 +2436,7 @@ func TestTree_Has(t *testing.T) { want: true, }, { - name: "no match static route", + name: "no match static route (no tsr)", path: "/foo/bar/", }, { @@ -2115,6 +2467,79 @@ func TestTree_Has(t *testing.T) { } } +func TestTree_HasWithIgnoreTrailingSlashEnable(t *testing.T) { + routes := []string{ + "/foo/bar", + "/welcome/{name}/", + "/users/uid_{id}", + } + + r := New(WithIgnoreTrailingSlash(true)) + for _, rte := range routes { + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) + } + + cases := []struct { + name string + path string + want bool + }{ + { + name: "strict match static route", + path: "/foo/bar", + want: true, + }, + { + name: "no match static route with tsr", + path: "/foo/bar/", + want: true, + }, + { + name: "strict match route params", + path: "/welcome/{name}/", + want: true, + }, + { + name: "strict match route params with tsr", + path: "/welcome/{name}", + want: true, + }, + { + name: "no match route params with ts", + path: "/welcome/fox", + }, + { + name: "no match route params without ts", + path: "/welcome/fox/", + }, + { + name: "strict match mid route params", + path: "/users/uid_{id}", + want: true, + }, + { + name: "strict match mid route params with tsr", + path: "/users/uid_{id}/", + want: true, + }, + { + name: "no match mid route params without ts", + path: "/users/uid_123", + }, + { + name: "no match mid route params with ts", + path: "/users/uid_123", + }, + } + + tree := r.Tree() + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tree.Has(http.MethodGet, tc.path)) + }) + } +} + func TestTree_Match(t *testing.T) { routes := []string{ "/foo/bar", @@ -2137,6 +2562,10 @@ func TestTree_Match(t *testing.T) { path: "/foo/bar", want: "/foo/bar", }, + { + name: "reverse static route with tsr disable", + path: "/foo/bar/", + }, { name: "reverse params route", path: "/welcome/fox", @@ -2160,6 +2589,66 @@ func TestTree_Match(t *testing.T) { } } +func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { + routes := []string{ + "/foo/bar", + "/welcome/{name}/", + "/users/uid_{id}", + } + + r := New(WithIgnoreTrailingSlash(true)) + for _, rte := range routes { + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) + } + + cases := []struct { + name string + path string + want string + }{ + { + name: "reverse static route", + path: "/foo/bar", + want: "/foo/bar", + }, + { + name: "reverse static route with tsr", + path: "/foo/bar/", + want: "/foo/bar", + }, + { + name: "reverse params route", + path: "/welcome/fox/", + want: "/welcome/{name}/", + }, + { + name: "reverse params route with tsr", + path: "/welcome/fox", + want: "/welcome/{name}/", + }, + { + name: "reverse mid params route", + path: "/users/uid_123", + want: "/users/uid_{id}", + }, + { + name: "reverse mid params route with tsr", + path: "/users/uid_123/", + want: "/users/uid_{id}", + }, + { + name: "reverse no match", + path: "/users/fox", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) + }) + } +} + func TestEncodedPath(t *testing.T) { encodedPath := "run/cmd/S123L%2FA" req := httptest.NewRequest(http.MethodGet, "/"+encodedPath, nil) diff --git a/options.go b/options.go index 4f9a60c..071a6c3 100644 --- a/options.go +++ b/options.go @@ -114,7 +114,7 @@ func WithAutoOptions(enable bool) Option { // another handler is found with/without an additional trailing slash. E.g. /foo/bar/ request does not match // but /foo/bar would match. The client is redirected with a http status code 301 for GET requests and 308 for // all other methods. Note that this option is mutually exclusive with WithIgnoreTrailingSlash, and if both are -// enabled, WithRedirectTrailingSlash takes precedence. +// enabled, WithIgnoreTrailingSlash takes precedence. func WithRedirectTrailingSlash(enable bool) Option { return optionFunc(func(r *Router) { r.redirectTrailingSlash = enable @@ -124,7 +124,7 @@ func WithRedirectTrailingSlash(enable bool) Option { // WithIgnoreTrailingSlash allows the router to match routes regardless of whether a trailing slash is present or not. // E.g. /foo/bar/ and /foo/bar would both match the same handler. This option prevents the router from issuing // a redirect and instead matches the request directly. Note that this option is mutually exclusive with -// WithRedirectTrailingSlash, and if both are enabled, WithRedirectTrailingSlash takes precedence. +// WithRedirectTrailingSlash, and if both are enabled, WithIgnoreTrailingSlash takes precedence. // This api is EXPERIMENTAL and is likely to change in future release. func WithIgnoreTrailingSlash(enable bool) Option { return optionFunc(func(r *Router) { diff --git a/tree.go b/tree.go index 307dcc2..bae4828 100644 --- a/tree.go +++ b/tree.go @@ -33,8 +33,9 @@ type Tree struct { nodes atomic.Pointer[[]*node] mws []middleware sync.Mutex - maxParams atomic.Uint32 - maxDepth atomic.Uint32 + maxParams atomic.Uint32 + maxDepth atomic.Uint32 + ignoreTrailingSlash bool } // Handle registers a new handler for the given method and path. This function return an error if the route @@ -92,8 +93,9 @@ func (t *Tree) Remove(method, path string) error { return nil } -// Has allows to check if the given method and path exactly match a registered route. This function is safe for -// concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. When WithIgnoreTrailingSlash +// is enabled, Has will match a registered route regardless of an extra or missing trailing slash. This function is +// safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Has(method, path string) bool { nds := *t.nodes.Load() @@ -106,10 +108,21 @@ func (t *Tree) Has(method, path string) bool { c.resetNil() n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() - return !tsr && n != nil && n.path == path + if n == nil { + return false + } + if n.path == path { + return !tsr + } + if tsr && t.ignoreTrailingSlash { + return n.path == fixTrailingSlash(path) + } + + return false } -// Match perform a lookup on the tree for the given method and path and return the matching registered route if any. +// Match perform a lookup on the tree for the given method and path and return the matching registered route if any. When +// WithIgnoreTrailingSlash is enabled, Match will match a registered route regardless of an extra or missing trailing slash. // This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Match(method, path string) string { @@ -123,10 +136,10 @@ func (t *Tree) Match(method, path string) string { c.resetNil() n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() - if tsr || n == nil { - return "" + if n != nil && (!tsr || t.ignoreTrailingSlash) { + return n.path } - return n.path + return "" } // Methods returns a sorted list of HTTP methods associated with a given path in the routing tree. If the path is "*", @@ -152,7 +165,7 @@ func (t *Tree) Methods(path string) []string { c.resetNil() for i := range nds { n, tsr := t.lookup(nds[i], path, c.params, c.skipNds, true) - if !tsr && n != nil { + if n != nil && (!tsr || t.ignoreTrailingSlash) { if methods == nil { methods = make([]string, 0) }