diff --git a/fox.go b/fox.go index a375c5a..e842782 100644 --- a/fox.go +++ b/fox.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "path" + "regexp" "strings" "sync" "sync/atomic" @@ -16,7 +17,12 @@ import ( const verb = 4 -var commonVerbs = [verb]string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} +var ( + // regEnLetter matches english letters for http method name. + regEnLetter = regexp.MustCompile("^[A-Z]+$") + // commonVerbs define http method for which node are pre instantiated. + commonVerbs = [verb]string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} +) // HandlerFunc is a function type that responds to an HTTP request. // It enforces the same contract as http.Handler but provides additional feature @@ -71,7 +77,7 @@ func New(opts ...Option) *Router { r.noRoute = applyMiddleware(NoRouteHandler, r.mws, r.noRoute) r.noMethod = applyMiddleware(NoMethodHandler, r.mws, r.noMethod) - r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlash) + r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlashHandler) r.autoOptions = applyMiddleware(OptionsHandler, r.mws, r.autoOptions) r.tree.Store(r.NewTree()) @@ -159,24 +165,38 @@ func (fox *Router) Remove(method, path string) error { return t.Remove(method, path) } -// Reverse perform a lookup on the tree for the given method and path and return the matching registered route if any. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser, +// and a boolean indicating if a trailing slash redirect is recommended. The ContextCloser should always be closed if non-nil. +// This method is primarily intended for integrating the fox router into custom routing solutions. It requires the use of the +// original http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine +// and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func Reverse(t *Tree, method, path string) string { - nds := *t.nodes.Load() - index := findRootNode(method, nds) +func (fox *Router) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { + tree := fox.tree.Load() + + nds := *tree.nodes.Load() + index := findRootNode(r.Method, nds) + if index < 0 { - return "" + return } - c := t.ctx.Get().(*context) - c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) - c.Close() - if n == nil { - return "" + c := tree.ctx.Get().(*context) + c.Reset(fox, w, r) + + target := r.URL.Path + if len(r.URL.RawPath) > 0 { + // Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1) + target = r.URL.RawPath } - return n.path + + n, tsr := tree.lookup(nds[index], target, c.params, c.skipNds, false) + if n != nil { + c.path = n.path + return n.handler, c, tsr + } + c.Close() + return nil, nil, tsr } // SkipMethod is used as a return value from WalkFunc to indicate that @@ -226,7 +246,7 @@ func DefaultMethodNotAllowedHandler() HandlerFunc { } } -func defaultRedirectTrailingSlash(c Context) { +func defaultRedirectTrailingSlashHandler(c Context) { req := c.Request() code := http.StatusMovedPermanently diff --git a/fox_test.go b/fox_test.go index 41f7b14..bafa77f 100644 --- a/fox_test.go +++ b/fox_test.go @@ -12,7 +12,6 @@ import ( "net/http/httptest" "reflect" "regexp" - "sort" "strings" "sync" "sync/atomic" @@ -1746,13 +1745,16 @@ func TestTree_RemoveRoot(t *testing.T) { } func TestTree_Methods(t *testing.T) { - tree := New().Tree() - methods := []string{"GET", "POST", "PATCH"} - for _, m := range methods { - require.NoError(t, tree.Handle(m, "/foo/bar", emptyHandler)) + f := New() + for _, rte := range githubAPI { + require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) } - sort.Strings(methods) - assert.Equal(t, methods, tree.Methods()) + + methods := f.Tree().Methods("/gists/123/star") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("*") + assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods) } func TestRouterWithAllowedMethod(t *testing.T) { @@ -1862,7 +1864,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) { target: "/foo", path: "/foo", methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, - want: "GET, PUT, TRACE, OPTIONS", + want: "GET, OPTIONS, PUT, TRACE", wantCode: http.StatusNoContent, }, { @@ -1878,7 +1880,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { - c.SetHeader("Allow", strings.Join(c.Tree().LookupMethods(c.Request().URL.Path), ", ")) + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) c.Writer().WriteHeader(http.StatusNoContent) })) } @@ -1887,7 +1889,6 @@ func TestRouterWithAutomaticOptions(t *testing.T) { f.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) assert.Equal(t, tc.want, w.Header().Get("Allow")) - // Reset f.Swap(f.NewTree()) }) @@ -1959,16 +1960,16 @@ func TestWithNotFoundHandler(t *testing.T) { assert.Equal(t, "NOT FOUND\n", w.Body.String()) } -func TestTree_Lookup(t *testing.T) { +func TestRouter_Lookup(t *testing.T) { rx := regexp.MustCompile("({|\\*{)[A-z]+[}]") f := New() for _, rte := range githubAPI { require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) } - tree := f.Tree() for _, rte := range githubAPI { - handler, cc, _ := tree.Lookup(rte.method, rte.path, false) + req := httptest.NewRequest(rte.method, rte.path, nil) + handler, cc, _ := f.Lookup(mockResponseWriter{}, req) require.NotNil(t, cc) assert.NotNil(t, handler) @@ -1986,19 +1987,21 @@ func TestTree_Lookup(t *testing.T) { cc.Close() } -} -func TestTree_LookupMethods(t *testing.T) { - f := New() - for _, rte := range githubAPI { - require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) - } + // No method match + req := httptest.NewRequest("ANY", "/bar", nil) + handler, cc, _ := f.Lookup(mockResponseWriter{}, req) + assert.Nil(t, handler) + assert.Nil(t, cc) - methods := f.Tree().LookupMethods("/gists/123/star") - assert.Equal(t, []string{"GET", "PUT", "DELETE"}, methods) + // No path match + req = httptest.NewRequest(http.MethodGet, "/bar", nil) + handler, cc, _ = f.Lookup(mockResponseWriter{}, req) + assert.Nil(t, handler) + assert.Nil(t, cc) } -func TestHas(t *testing.T) { +func TestTree_Has(t *testing.T) { routes := []string{ "/foo/bar", "/welcome/{name}", @@ -2052,7 +2055,7 @@ func TestHas(t *testing.T) { } } -func TestReverse(t *testing.T) { +func TestTree_Match(t *testing.T) { routes := []string{ "/foo/bar", "/welcome/{name}", @@ -2092,7 +2095,7 @@ func TestReverse(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, Reverse(r.Tree(), http.MethodGet, tc.path)) + assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) }) } } @@ -2446,20 +2449,14 @@ func ExampleRouter_Tree() { // This example demonstrates how to create a custom middleware that cleans the request path and performs a manual // lookup on the tree. If the cleaned path matches a registered route, the client is redirected with a 301 status // code (Moved Permanently). -func ExampleTree_Lookup() { +func ExampleTree_Match() { redirectFixedPath := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return func(c Context) { req := c.Request() cleanedPath := CleanPath(req.URL.Path) - handler, cc, _ := c.Tree().Lookup(req.Method, cleanedPath, true) - // You should always close a non-nil Context. - if cc != nil { - defer cc.Close() - } - - // 301 redirect and returns. - if handler != nil { + if match := c.Tree().Match(req.Method, cleanedPath); match != "" { + // 301 redirect and returns. req.URL.Path = cleanedPath http.Redirect(c.Writer(), req, req.URL.String(), http.StatusMovedPermanently) return diff --git a/recovery.go b/recovery.go index b2a0299..2c557f4 100644 --- a/recovery.go +++ b/recovery.go @@ -54,6 +54,7 @@ func recovery(c Context, handle RecoveryFunc) { } func connIsBroken(err any) bool { + //goland:noinspection GoTypeAssertionOnErrors if ne, ok := err.(*net.OpError); ok { var se *os.SyscallError if errors.As(ne, &se) { diff --git a/tree.go b/tree.go index 758a41e..e8c8f35 100644 --- a/tree.go +++ b/tree.go @@ -42,6 +42,10 @@ type Tree struct { // for serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To override an existing route, use Update. func (t *Tree) Handle(method, path string, handler HandlerFunc) error { + if matched := regEnLetter.MatchString(method); !matched { + return fmt.Errorf("%w: missing or invalid http method", ErrInvalidRoute) + } + p, catchAllKey, n, err := parseRoute(path) if err != nil { return err @@ -55,6 +59,10 @@ func (t *Tree) Handle(method, path string, handler HandlerFunc) error { // serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To add new handler, use Handle method. func (t *Tree) Update(method, path string, handler HandlerFunc) error { + if method == "" { + return fmt.Errorf("%w: missing http method", ErrInvalidRoute) + } + p, catchAllKey, _, err := parseRoute(path) if err != nil { return err @@ -68,104 +76,100 @@ func (t *Tree) Update(method, path string, handler HandlerFunc) error { // However, this function is NOT thread-safe and should be run serially, along with all other Tree APIs that perform // write operations. func (t *Tree) Remove(method, path string) error { + if method == "" { + return fmt.Errorf("%w: missing http method", ErrInvalidRoute) + } + path, _, _, err := parseRoute(path) if err != nil { return err } if !t.remove(method, path) { - return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) + return fmt.Errorf("%w: route %s %s is not registered", ErrRouteNotFound, method, path) } return nil } -// Methods returns a sorted slice of HTTP methods that are currently in use to route requests. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. This function is safe for +// concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Methods() []string { - var methods []string +func (t *Tree) Has(method, path string) bool { nds := *t.nodes.Load() - for i := range nds { - if len(nds[i].children) > 0 { - if methods == nil { - methods = make([]string, 0) - } - methods = append(methods, nds[i].key) - } + index := findRootNode(method, nds) + if index < 0 { + return false } - sort.Strings(methods) - return methods + + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() + return n != nil && n.path == path } -// Lookup allow to do manual lookup of a route for the given method and path and return the matched HandlerFunc -// along with a ContextCloser and trailing slash redirect recommendation. If lazy is set to true, wildcard parameter are -// not parsed. You should always close the ContextCloser if NOT nil by calling cc.Close(). Note that the returned -// ContextCloser does not have a router, request and response writer attached (use the Reset method). +// Match perform a lookup on the tree for the given method and path and return the matching registered route if any. // This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Lookup(method, path string, lazy bool) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (t *Tree) Match(method, path string) string { nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return + return "" } c := t.ctx.Get().(*context) c.resetNil() - n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, lazy) - if n != nil { - c.path = n.path - return n.handler, c, tsr + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() + if n == nil { + return "" } - return nil, c, tsr + return n.path } -// LookupMethods lookup and returns all HTTP methods associated with a route that match the given path. This function -// is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Methods returns a sorted list of HTTP methods associated with a given path in the routing tree. If the path is "*", +// it returns all HTTP methods that have at least one route registered in the tree. For a specific path, it returns the methods +// that can route requests to that path. +// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) LookupMethods(path string) (methods []string) { +func (t *Tree) Methods(path string) []string { + var methods []string nds := *t.nodes.Load() - c := t.ctx.Get().(*context) - c.resetNil() - - methods = make([]string, 0) - for i := 0; i < len(nds); i++ { - n, _ := t.lookup(nds[i], path, c.params, c.skipNds, true) - if n != nil { - methods = append(methods, nds[i].key) + if path == "*" { + for i := range nds { + if len(nds[i].children) > 0 { + if methods == nil { + methods = make([]string, 0) + } + methods = append(methods, nds[i].key) + } } - } - c.Close() - return methods -} - -// Has allows to check if the given method and path exactly match a registered route. This function is safe for -// concurrent use by multiple goroutine and while mutation on Tree are ongoing. -// This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Has(method, path string) bool { - nds := *t.nodes.Load() - index := findRootNode(method, nds) - if index < 0 { - return false + } else { + c := t.ctx.Get().(*context) + c.resetNil() + for i := range nds { + n, _ := t.lookup(nds[i], path, c.params, c.skipNds, true) + if n != nil { + if methods == nil { + methods = make([]string, 0) + } + methods = append(methods, nds[i].key) + } + } + c.Close() } - c := t.ctx.Get().(*context) - c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) - c.Close() - return n != nil && n.path == path + sort.Strings(methods) + return methods } // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use // parseRoute before. func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. - if method == "" { - return fmt.Errorf("%w: http method is missing", ErrInvalidRoute) - } - var rootNode *node nds := *t.nodes.Load() index := findRootNode(method, nds) @@ -190,7 +194,7 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler if result.matched.isCatchAll() && isCatchAll { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } - return fmt.Errorf("%w: new route [%s] %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) + return fmt.Errorf("%w: new route %s %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) } // We are updating an existing node. We only need to create a new node from @@ -347,12 +351,12 @@ func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFu nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) + return fmt.Errorf("%w: route %s %s is not registered", ErrRouteNotFound, method, path) } result := t.search(nds[index], path) if !result.isExactMatch() || !result.matched.isLeaf() { - return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) + return fmt.Errorf("%w: route %s %s is not registered", ErrRouteNotFound, method, path) } if catchAllKey != result.matched.catchAllKey {