diff --git a/LICENSE.txt b/LICENSE.txt index 261eeb9..bb9595b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2022 Sylvain Müller Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 62b1d3f..0e792a9 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,10 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/tigerwill90/fox)](https://goreportcard.com/report/github.com/tigerwill90/fox) [![codecov](https://codecov.io/gh/tigerwill90/fox/branch/master/graph/badge.svg?token=09nfd7v0Bl)](https://codecov.io/gh/tigerwill90/fox) # Fox -Fox is a lightweight high performance HTTP request router for [Go](https://go.dev/). The main difference with other routers is +Fox is a zero allocation, lightweight, high performance HTTP request router for [Go](https://go.dev/). The main difference with other routers is that it supports **mutation on its routing tree while handling request concurrently**. Internally, Fox use a [Concurrent Radix Tree](https://github.com/npgall/concurrent-trees/blob/master/documentation/TreeDesign.md) that support **lock-free -reads** while allowing **concurrent writes**. - -The router tree is optimized for high-concurrency and high performance reads, and low-concurrency write. Fox has a small memory footprint, and -in many case, it does not do a single heap allocation while handling request. +reads** while allowing **concurrent writes**. The router tree is optimized for high-concurrency and high performance reads, and low-concurrency write. Fox supports various use cases, but it is especially designed for applications that require changes at runtime to their routing structure based on user input, configuration changes, or other runtime events. @@ -24,13 +21,14 @@ request! **Wildcard pattern:** Route can be registered using wildcard parameters. The matched path segment can then be easily retrieved by name. Due to Fox design, wildcard route are cheap and scale really well. -**Detect panic:** You can register a fallback handler that is fire in case of panics occurring during handling an HTTP request. +**Detect panic:** Comes with a ready-to-use, efficient Recovery middleware that gracefully handles panics. **Get the current route:** You can easily retrieve the route of the matched request. This actually makes it easier to integrate -observability middleware like open telemetry (disable by default). +observability middleware like open telemetry. -**Only explicit matches:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), a request can only match -exactly one or no route. As a result there are no unintended matches, and it also encourages good RESTful api design. +**Only explicit matches:** A request can only match exactly one route or no route at all. Fox strikes a balance between routing flexibility, +performance and clarity by enforcing clear priority rules, ensuring that there are no unintended matches and maintaining high performance +even for complex routing pattern. **Redirect trailing slashes:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), the router automatically redirects the client, at no extra cost, if another route match with or without a trailing slash (disable by default). @@ -51,40 +49,38 @@ go get -u github.com/tigerwill90/fox package main import ( - "fmt" "github.com/tigerwill90/fox" "log" "net/http" ) -var WelcomeHandler = fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprint(w, "Welcome!\n") -}) - -type HelloHandler struct{} +type Greeting struct { + Say string +} -func (h *HelloHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) +func (h *Greeting) Greet(c fox.Context) { + _ = c.String(http.StatusOK, "%s %s\n", h.Say, c.Param("name")) } func main() { - r := fox.New() - - Must(r.Handler(http.MethodGet, "/", WelcomeHandler)) - Must(r.Handler(http.MethodGet, "/hello/{name}", new(HelloHandler))) - - log.Fatalln(http.ListenAndServe(":8080", r)) -} + r := fox.New(fox.DefaultOptions()) -func Must(err error) { + err := r.Handle(http.MethodGet, "/", func(c fox.Context) { + _ = c.String(http.StatusOK, "Welcome\n") + }) if err != nil { panic(err) } + + h := Greeting{Say: "Hello"} + r.MustHandle(http.MethodGet, "/hello/{name}", h.Greet) + + log.Fatalln(http.ListenAndServe(":8080", r)) } ```` #### Error handling Since new route may be added at any given time, Fox, unlike other router, does not panic when a route is malformed or conflicts with another. -Instead, it returns the following error values +Instead, it returns the following error values: ```go ErrRouteExist = errors.New("route already registered") ErrRouteConflict = errors.New("route conflict") @@ -102,8 +98,8 @@ if errors.Is(err, fox.ErrRouteConflict) { ``` #### Named parameters -A route can be defined using placeholder (e.g `{name}`). The values are accessible via `fox.Params`, which is just a slice of `fox.Param`. -The `Get` method is a helper to retrieve the value using the placeholder name. +A route can be defined using placeholder (e.g `{name}`). The matching segment are recorder into the `fox.Params` slice accessible +via `fox.Context`. The `Param` and `Get` methods are helpers to retrieve the value using the placeholder name. ``` Pattern /avengers/{name} @@ -119,7 +115,7 @@ Pattern /users/uuid:{id} /users/uuid no match ``` -### Catch all parameter +#### Catch all parameter Catch-all parameters can be used to match everything at the end of a route. The placeholder start with `*` followed by a regular named parameter (e.g. `*{name}`). ``` @@ -136,19 +132,36 @@ Patter /src/file=*{path} /src/file=/dir/config.txt match ``` +#### Priority rules +Routes are prioritized based on specificity, with static segments taking precedence over wildcard segments. +A wildcard segment (named parameter or catch all) can only overlap with static segments, for the same HTTP method. +For instance, `GET /users/{id}` and `GET /users/{name}/profile` cannot coexist, as the `{id}` and `{name}` segments +are overlapping. These limitations help to minimize the number of branches that need to be evaluated in order to find +the right match, thereby maintaining high-performance routing. + +For example, the followings route are allowed: +```` +GET /*{filepath} +GET /users/{id} +GET /users/{id}/emails +GET /users/{id}/{actions} +POST /users/{name}/emails +```` + #### Warning about params slice -`fox.Params` slice is freed once ServeHTTP returns and may be reused later to save resource. Therefore, if you need to hold `fox.Params` +`fox.Context` is freed once ServeHTTP returns and may be reused later to save resource. Therefore, if you need to hold `fox.Params` longer, use the `Clone` methods. -```go -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - p := params.Clone() - go func(){ - time.Sleep(1 * time.Second) - log.Println(p.Get("name")) // Safe - }() - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) +````go +func Hello(c fox.Context) { + cc := c.Clone() + // cp := c.Params().Clone() + go func() { + time.Sleep(2 * time.Second) + log.Println(cc.Param("name")) // Safe + }() + _ = c.String(http.StatusOK, "Hello %s\n", c.Param("name")) } -``` +```` ## Concurrency Fox implements a [Concurrent Radix Tree](https://github.com/npgall/concurrent-trees/blob/master/documentation/TreeDesign.md) that supports **lock-free** @@ -178,7 +191,7 @@ As such threads that route requests should never encounter latency due to ongoin In this example, the handler for `routes/{action}` allow to dynamically register, update and remove handler for the given route and method. Thanks to Fox's design, those actions are perfectly safe and may be executed concurrently. -```go +````go package main import ( @@ -190,14 +203,10 @@ import ( "strings" ) -type ActionHandler struct { - fox *fox.Router -} - -func (h *ActionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { +func Action(c fox.Context) { var data map[string]string - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if err := json.NewDecoder(c.Request().Body).Decode(&data); err != nil { + http.Error(c.Writer(), err.Error(), http.StatusBadRequest) return } @@ -206,51 +215,46 @@ func (h *ActionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params text := data["text"] if path == "" || method == "" { - http.Error(w, "missing method or path", http.StatusBadRequest) + http.Error(c.Writer(), "missing method or path", http.StatusBadRequest) return } var err error - action := params.Get("action") + action := c.Param("action") switch action { case "add": - err = h.fox.Handler(method, path, fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintln(w, text) - })) + err = c.Fox().Handle(method, path, func(c fox.Context) { + _, _ = fmt.Fprintln(c.Writer(), text) + }) case "update": - err = h.fox.Update(method, path, fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintln(w, text) - })) + err = c.Fox().Update(method, path, func(c fox.Context) { + _, _ = fmt.Fprintln(c.Writer(), text) + }) case "delete": - err = h.fox.Remove(method, path) + err = c.Fox().Remove(method, path) default: - http.Error(w, fmt.Sprintf("action %q is not allowed", action), http.StatusBadRequest) + http.Error(c.Writer(), fmt.Sprintf("action %q is not allowed", action), http.StatusBadRequest) return } if err != nil { - http.Error(w, err.Error(), http.StatusConflict) + http.Error(c.Writer(), err.Error(), http.StatusConflict) return } - _, _ = fmt.Fprintf(w, "%s route [%s] %s: success\n", action, method, path) + _, _ = fmt.Fprintf(c.Writer(), "%s route [%s] %s: success\n", action, method, path) } func main() { r := fox.New() - Must(r.Handler(http.MethodPost, "/routes/{action}", &ActionHandler{fox: r})) + r.MustHandle(http.MethodPost, "/routes/{action}", Action) log.Fatalln(http.ListenAndServe(":8080", r)) } - -func Must(err error) { - if err != nil { - panic(err) - } -} -``` +```` #### Tree swapping -Fox also enables you to replace the entire tree in a single atomic operation using the `Use` and `Swap` methods. +Fox also enables you to replace the entire tree in a single atomic operation using the `Swap` methods. Note that router's options apply automatically on the new tree. + ````go package main @@ -269,18 +273,20 @@ type HtmlRenderer struct { Template template.HTML } -func (h *HtmlRenderer) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - log.Printf("matched route: %s", params.Get(fox.RouteKey)) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, _ = io.Copy(w, strings.NewReader(string(h.Template))) +func (h *HtmlRenderer) Render(c fox.Context) { + log.Printf("matched handler path: %s", c.Path()) + c.Writer().Header().Set(fox.HeaderContentType, fox.MIMETextHTMLCharsetUTF8) + _, _ = io.Copy(c.Writer(), strings.NewReader(string(h.Template))) } func main() { - r := fox.New(fox.WithSaveMatchedRoute(true)) + r := fox.New() routes := db.GetRoutes() + for _, rte := range routes { - Must(r.Handler(rte.Method, rte.Path, &HtmlRenderer{Template: rte.Template})) + h := HtmlRenderer{Template: rte.Template} + r.MustHandle(rte.Method, rte.Path, h.Render) } go Reload(r) @@ -293,22 +299,17 @@ func Reload(r *fox.Router) { routes := db.GetRoutes() tree := r.NewTree() for _, rte := range routes { - if err := tree.Handler(rte.Method, rte.Path, &HtmlRenderer{Template: rte.Template}); err != nil { + h := HtmlRenderer{Template: rte.Template} + if err := tree.Handle(rte.Method, rte.Path, h.Render); err != nil { log.Printf("error reloading route: %s\n", err) continue } } - // Replace the currently in-use routing tree with the new provided. - r.Use(tree) + // Swap the currently in-use routing tree with the new provided. + r.Swap(tree) log.Println("route reloaded") } } - -func Must(err error) { - if err != nil { - panic(err) - } -} ```` #### Advanced usage: consistent view updates @@ -321,21 +322,20 @@ is already registered for the provided method and path. By locking the `Tree`, t atomicity, as it prevents other threads from modifying the tree between the lookup and the write operation. Note that all read operation on the tree remain lock-free. ````go -func Upsert(t *fox.Tree, method, path string, handler fox.Handler) error { +func Upsert(t *fox.Tree, method, path string, handler fox.HandlerFunc) error { t.Lock() defer t.Unlock() if fox.Has(t, method, path) { return t.Update(method, path, handler) } - return t.Handler(method, path, handler) + return t.Handle(method, path, handler) } ```` #### Concurrent safety and proper usage of Tree APIs -Some important consideration to keep in mind when using `Tree` API. Each instance as its own `sync.Mutex` and `sync.Pool` -that may be used to serialize write and reduce memory allocation. Since the router tree may be swapped at any -given time, you **MUST always copy the pointer locally** to avoid inadvertently releasing Params to the wrong pool -or worst, causing a deadlock by locking/unlocking the wrong `Tree`. +Some important consideration to keep in mind when using `Tree` API. Each instance as its own `sync.Mutex` that may be +used to serialize write . Since the router tree may be swapped at any given time, you **MUST always copy the pointer +locally** to avoid inadvertently causing a deadlock by locking/unlocking the wrong `Tree`. ````go // Good @@ -343,25 +343,94 @@ t := r.Tree() t.Lock() defer t.Unlock() -// Dramatically bad, may cause deadlock: +// Dramatically bad, may cause deadlock r.Tree().Lock() defer r.Tree().Unlock() + +// Dramatically bad, may cause deadlock +func handle(c fox.Context) { + c.Fox().Tree().Lock() + defer c.Fox().Tree().Unlock() +} ```` -This principle also applies to the `fox.Lookup` function, which requires releasing the `fox.Params` slice by calling `params.Free(tree)`. -Always ensure that the `Tree` pointer passed as a parameter to `params.Free` is the same as the one passed to the `fox.Lookup` function. +Note that `fox.Context` carries a local copy of the `Tree` that is being used to serve the handler, thereby eliminating +the risk of deadlock when using the `Tree` within the context. +````go +// Ok +func handle(c fox.Context) { + c.Tree().Lock() + defer c.Tree().Unlock() +} +```` ## Working with http.Handler Fox itself implements the `http.Handler` interface which make easy to chain any compatible middleware before the router. Moreover, the router -provides convenient `fox.WrapF` and `fox.WrapH` adapter to be use with `http.Handler`. Named and catch all parameters are forwarded via the -request context +provides convenient `fox.WrapF`, `fox.WrapH` and `fox.WrapM` adapter to be use with `http.Handler`. + +Wrapping an http.Handler ```go -_ = r.Handler(http.MethodGet, "/users/{id}", fox.WrapF(func(w http.ResponseWriter, r *http.Request) { - params := fox.ParamsFromContext(r.Context()) - _, _ = fmt.Fprintf(w, "user id: %s\n", params.Get("id")) -})) +articles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "get articles") +}) + +r := fox.New(fox.DefaultOptions()) +r.MustHandle(http.MethodGet, "/articles", fox.WrapH(httpRateLimiter.RateLimit(articles))) ``` +Wrapping an http.Handler compatible middleware +````go +r := fox.New(fox.DefaultOptions(), fox.WithMiddleware(fox.WrapM(httpRateLimiter.RateLimit))) +r.MustHandle(http.MethodGet, "/articles/{id}", func(c fox.Context) { + _ = c.String(http.StatusOK, "Article id: %s\n", c.Param("id")) +}) +```` + +## Middleware +Middlewares can be registered globally using the `fox.WithMiddleware` option. The example below demonstrates how +to create and apply automatically a simple logging middleware to all route. + +````go +package main + +import ( + "github.com/tigerwill90/fox" + "log" + "net/http" + "time" +) + +var logger = fox.MiddlewareFunc(func(next fox.HandlerFunc) fox.HandlerFunc { + return func(c fox.Context) { + start := time.Now() + next(c) + log.Printf( + "route: %s, latency: %s, status: %d, size: %d", + c.Path(), + time.Since(start), + c.Writer().Status(), + c.Writer().Size(), + ) + } +}) + +func main() { + r := fox.New(fox.WithMiddleware(logger)) + + r.MustHandle(http.MethodGet, "/", func(c fox.Context) { + resp, err := http.Get("https://api.coindesk.com/v1/bpi/currentprice.json") + if err != nil { + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + _ = c.Stream(http.StatusOK, fox.MIMEApplicationJSON, resp.Body) + }) + + log.Fatalln(http.ListenAndServe(":8080", r)) +} +```` + ## Benchmark The primary goal of Fox is to be a lightweight, high performance router which allow routes modification while in operation. The following benchmarks attempt to compare Fox to various popular alternatives. Some are fully featured web framework, and other @@ -509,15 +578,14 @@ BenchmarkPat_GithubAll 550 21177 ## Road to v1 - [x] [Update route syntax](https://github.com/tigerwill90/fox/pull/10#issue-1643728309) @v0.6.0 -- [ ] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 -- [ ] Collect feedback and polishing +- [x] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 +- [ ] Improving performance and polishing ## Contributions This project aims to provide a lightweight, high performance and easy to use http router. It purposely has a limited set of features and exposes a relatively low-level api. -The intention behind these choices is that it can serve as a building block for more "batteries included" frameworks. Feature requests and PRs along these lines are welcome. +The intention behind these choices is that it can serve as a building block for implementing your own "batteries included" frameworks. Feature requests and PRs along these lines are welcome. ## Acknowledgements - [npgall/concurrent-trees](https://github.com/npgall/concurrent-trees): Fox design is largely inspired from Niall Gallagher's Concurrent Trees design. -- [julienschmidt/httprouter](https://github.com/julienschmidt/httprouter): a lot of feature that implements Fox are inspired from Julien Schmidt's router. - -## RFC route overlapping enhancement \ No newline at end of file +- [julienschmidt/httprouter](https://github.com/julienschmidt/httprouter): some feature that implements Fox are inspired from Julien Schmidt's router. Most notably, +this package uses the optimized [httprouter.Cleanpath](https://github.com/julienschmidt/httprouter/blob/master/path.go) function. diff --git a/context.go b/context.go new file mode 100644 index 0000000..78a90cb --- /dev/null +++ b/context.go @@ -0,0 +1,269 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + +package fox + +import ( + "fmt" + "io" + "net/http" + "net/url" +) + +// ContextCloser extends Context for manually created instances, adding a Close method +// to release resources after use. +type ContextCloser interface { + Context + Close() +} + +// Context represents the context of the current HTTP request. +// It provides methods to access request data and to write a response. +type Context interface { + // Done returns a channel that closes when the request's context is + // cancelled or times out. + Done() <-chan struct{} + // Request returns the current *http.Request. + Request() *http.Request + // SetRequest sets the *http.Request. + SetRequest(r *http.Request) + // Writer returns the ResponseWriter. + Writer() ResponseWriter + // SetWriter sets the ResponseWriter. + SetWriter(w ResponseWriter) + // Path returns the registered path for the handler. + Path() string + // Params returns a Params slice containing the matched + // wildcard parameters. + Params() Params + // Param retrieve a matching wildcard parameter by name. + Param(name string) string + // QueryParams parses the Request RawQuery and returns the corresponding values. + QueryParams() url.Values + // QueryParam returns the first query value associated with the given key. + QueryParam(name string) string + // String sends a formatted string with the specified status code. + String(code int, format string, values ...any) error + // Blob sends a byte slice with the specified status code and content type. + Blob(code int, contentType string, buf []byte) error + // Stream sends data from an io.Reader with the specified status code and content type. + Stream(code int, contentType string, r io.Reader) error + // Redirect sends an HTTP redirect response with the given status code and URL. + Redirect(code int, url string) error + // Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. + Clone() Context + // Tree is a local copy of the Tree in use to serve the request. + Tree() *Tree + // Fox returns the Router in use to serve the request. + Fox() *Router +} + +// context holds request-related information and allows interaction with the ResponseWriter. +type context struct { + w ResponseWriter + req *http.Request + params *Params + skipNds *skippedNodes + + // tree at allocation (read-only, no reset) + tree *Tree + fox *Router + + cachedQuery url.Values + path string + rec recorder +} + +func (c *context) reset(fox *Router, w http.ResponseWriter, r *http.Request) { + c.rec.reset(w) + c.req = r + if r.ProtoMajor == 2 { + c.w = h2Writer{&c.rec} + } else { + c.w = h1Writer{&c.rec} + } + c.fox = fox + c.path = "" + c.cachedQuery = nil + *c.params = (*c.params)[:0] +} + +func (c *context) resetNil() { + c.req = nil + c.w = nil + c.fox = nil + c.path = "" + c.cachedQuery = nil + *c.params = (*c.params)[:0] +} + +// Request returns the *http.Request. +func (c *context) Request() *http.Request { + return c.req +} + +// SetRequest sets the *http.Request. +func (c *context) SetRequest(r *http.Request) { + c.req = r +} + +// Writer returns the ResponseWriter. +func (c *context) Writer() ResponseWriter { + return c.w +} + +// SetWriter sets the ResponseWriter. +func (c *context) SetWriter(w ResponseWriter) { + c.w = w +} + +// Done returns a channel that closes when the request's context is +// cancelled or times out. +func (c *context) Done() <-chan struct{} { + return c.req.Context().Done() +} + +// Params returns a Params slice containing the matched +// wildcard parameters. +func (c *context) Params() Params { + return *c.params +} + +// Param retrieve a matching wildcard segment by name. +// It's a helper for c.Params.Get(name). +func (c *context) Param(name string) string { + for _, p := range c.Params() { + if p.Key == name { + return p.Value + } + } + return "" +} + +// QueryParams parses RawQuery and returns the corresponding values. +// It's a helper for c.Request.URL.Query(). Note that the parsed +// result is cached. +func (c *context) QueryParams() url.Values { + c.req.URL.Query() + return c.getQueries() +} + +// QueryParam returns the first value associated with the given key. +// It's a helper for c.QueryParams().Get(name). +func (c *context) QueryParam(name string) string { + return c.getQueries().Get(name) +} + +// Path returns the registered path for the handler. +func (c *context) Path() string { + return c.path +} + +// String sends a formatted string with the specified status code. +func (c *context) String(code int, format string, values ...any) (err error) { + c.w.Header().Set(HeaderContentType, MIMETextPlainCharsetUTF8) + c.w.WriteHeader(code) + _, err = fmt.Fprintf(c.w, format, values...) + return +} + +// Blob sends a byte slice with the specified status code and content type. +func (c *context) Blob(code int, contentType string, buf []byte) (err error) { + c.w.Header().Set(HeaderContentType, contentType) + c.w.WriteHeader(code) + _, err = c.w.Write(buf) + return +} + +// Stream sends data from an io.Reader with the specified status code and content type. +func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { + c.w.Header().Set(HeaderContentType, contentType) + c.w.WriteHeader(code) + _, err = io.Copy(c.w, r) + return +} + +// Redirect sends an HTTP redirect response with the given status code and URL. +func (c *context) Redirect(code int, url string) error { + if code < http.StatusMultipleChoices || code > http.StatusPermanentRedirect { + return ErrInvalidRedirectCode + } + http.Redirect(c.w, c.req, url, code) + return nil +} + +// Tree is a local copy of the Tree in use to serve the request. +func (c *context) Tree() *Tree { + return c.tree +} + +// Fox returns the Router in use to serve the request. +func (c *context) Fox() *Router { + return c.fox +} + +// Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. +func (c *context) Clone() Context { + cp := context{ + rec: c.rec, + req: c.req.Clone(c.req.Context()), + fox: c.fox, + tree: c.tree, + } + cp.rec.ResponseWriter = noopWriter{} + cp.w = &cp.rec + params := make(Params, len(*c.params)) + copy(params, *c.params) + cp.params = ¶ms + cp.cachedQuery = nil + return &cp +} + +// Close releases the context to be reused later. +func (c *context) Close() { + // Put back the context, if not extended more than max params or max depth, allowing + // the slice to naturally grow within the constraint. + if cap(*c.params) > int(c.tree.maxParams.Load()) || cap(*c.skipNds) > int(c.tree.maxDepth.Load()) { + return + } + c.tree.ctx.Put(c) +} + +func (c *context) getQueries() url.Values { + if c.cachedQuery == nil { + if c.req != nil { + c.cachedQuery = c.req.URL.Query() + } else { + c.cachedQuery = url.Values{} + } + } + return c.cachedQuery +} + +// WrapF is an adapter for wrapping http.HandlerFunc and returns a HandlerFunc function. +func WrapF(f http.HandlerFunc) HandlerFunc { + return func(c Context) { + f.ServeHTTP(c.Writer(), c.Request()) + } +} + +// WrapH is an adapter for wrapping http.Handler and returns a HandlerFunc function. +func WrapH(h http.Handler) HandlerFunc { + return func(c Context) { + h.ServeHTTP(c.Writer(), c.Request()) + } +} + +// WrapM is an adapter for wrapping http.Handler middleware and returns a +// MiddlewareFunc function. +func WrapM(m func(handler http.Handler) http.Handler) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c Context) { + adapter := m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next(c) + })) + adapter.ServeHTTP(c.Writer(), c.Request()) + } + } +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..667ea7a --- /dev/null +++ b/context_test.go @@ -0,0 +1,147 @@ +package fox + +import ( + "bytes" + netcontext "context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func TestContext_QueryParams(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTestContextTree(New().Tree()) + c.req = req + values := c.QueryParams() + assert.Equal(t, wantValues, values) + assert.Equal(t, wantValues, c.cachedQuery) +} + +func TestContext_QueryParam(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTestContextTree(New().Tree()) + c.req = req + assert.Equal(t, "b", c.QueryParam("a")) + assert.Equal(t, "d", c.QueryParam("c")) + assert.Equal(t, wantValues, c.cachedQuery) +} + +func TestContext_Clone(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTextContextOnly(New(), httptest.NewRecorder(), req) + + buf := []byte("foo bar") + _, err := c.w.Write(buf) + require.NoError(t, err) + + cc := c.Clone() + assert.Equal(t, http.StatusOK, cc.Writer().Status()) + assert.Equal(t, len(buf), cc.Writer().Size()) + assert.Equal(t, wantValues, c.QueryParams()) + _, err = cc.Writer().Write([]byte("invalid")) + assert.ErrorIs(t, err, ErrDiscardedResponseWriter) +} + +func TestContext_Done(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + ctx, cancel := netcontext.WithCancel(netcontext.Background()) + cancel() + req = req.WithContext(ctx) + _, c := NewTestContext(httptest.NewRecorder(), req) + select { + case <-c.Done(): + require.ErrorIs(t, c.Request().Context().Err(), netcontext.Canceled) + case <-time.After(1): + t.FailNow() + } +} + +func TestContext_Redirect(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + require.NoError(t, c.Redirect(http.StatusTemporaryRedirect, "https://example.com/foo/bar")) + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://example.com/foo/bar", w.Header().Get(HeaderLocation)) +} + +func TestContext_Blob(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + require.NoError(t, c.Blob(http.StatusCreated, MIMETextPlain, buf)) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlain, w.Header().Get(HeaderContentType)) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_Stream(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + require.NoError(t, c.Stream(http.StatusCreated, MIMETextPlain, bytes.NewBuffer(buf))) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlain, w.Header().Get(HeaderContentType)) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_String(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + s := "foobar" + require.NoError(t, c.String(http.StatusCreated, s)) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlainCharsetUTF8, w.Header().Get(HeaderContentType)) + assert.Equal(t, s, w.Body.String()) + assert.Equal(t, len(s), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_Writer(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + c.Writer().WriteHeader(http.StatusCreated) + assert.Equal(t, 0, c.Writer().Size()) + n, err := c.Writer().Write(buf) + require.NoError(t, err) + assert.Equal(t, len(buf), n) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} diff --git a/error.go b/error.go index 59deb8b..efcaa28 100644 --- a/error.go +++ b/error.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -7,17 +11,20 @@ import ( ) var ( - ErrRouteNotFound = errors.New("route not found") - ErrRouteExist = errors.New("route already registered") - ErrRouteConflict = errors.New("route conflict") - ErrInvalidRoute = errors.New("invalid route") + ErrRouteNotFound = errors.New("route not found") + ErrRouteExist = errors.New("route already registered") + ErrRouteConflict = errors.New("route conflict") + ErrInvalidRoute = errors.New("invalid route") + ErrDiscardedResponseWriter = errors.New("discarded response writer") + ErrInvalidRedirectCode = errors.New("invalid redirect code") ) type RouteConflictError struct { - err error - Method string - Path string - Matched []string + err error + Method string + Path string + Matched []string + isUpdate bool } func newConflictErr(method, path, catchAllKey string, matched []string) *RouteConflictError { @@ -33,8 +40,19 @@ func newConflictErr(method, path, catchAllKey string, matched []string) *RouteCo } func (e *RouteConflictError) Error() string { - path := e.Path - return fmt.Sprintf("new route [%s] %s conflicts with %s", e.Method, path, strings.Join(e.Matched, ", ")) + if !e.isUpdate { + return e.insertError() + } + return e.updateError() +} + +func (e *RouteConflictError) insertError() string { + return fmt.Sprintf("%s: new route [%s] %s conflicts with %s", e.err, e.Method, e.Path, strings.Join(e.Matched, ", ")) +} + +func (e *RouteConflictError) updateError() string { + return fmt.Sprintf("wildcard conflict: updated route [%s] %s conflicts with %s", e.Method, e.Path, strings.Join(e.Matched, ", ")) + } func (e *RouteConflictError) Unwrap() error { diff --git a/router.go b/fox.go similarity index 65% rename from router.go rename to fox.go index 2453360..a9fa19b 100644 --- a/router.go +++ b/fox.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -13,71 +17,51 @@ const verb = 4 var commonVerbs = [verb]string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} -// Handler respond to an HTTP request. +// HandlerFunc is a function type that responds to an HTTP request. +// It enforces the same contract as http.Handler but provides additional feature +// like matched wildcard route segments via the Context type. The Context is freed once +// the HandlerFunc returns and may be reused later to save resources. If you need +// to hold the context longer, you have to copy it (see Clone method). // -// This interface enforce the same contract as http.Handler except that matched wildcard route segment -// are accessible via params. Params slice is freed once ServeHTTP returns and may be reused later to -// save resource. Therefore, if you need to hold params slice longer, you have to copy it (see Clone method). +// Similar to http.Handler, to abort a HandlerFunc so the client sees an interrupted +// response, panic with the value http.ErrAbortHandler. // -// As for http.Handler interface, to abort a handler so the client sees an interrupted response, panic with -// the value http.ErrAbortHandler. -type Handler interface { - ServeHTTP(http.ResponseWriter, *http.Request, Params) -} +// HandlerFunc functions should be thread-safe, as they will be called concurrently. +type HandlerFunc func(c Context) -// HandlerFunc is an adapter to allow the use of ordinary functions as HTTP handlers. If f is a function with the -// appropriate signature, HandlerFunc(f) is a Handler that calls f. -type HandlerFunc func(http.ResponseWriter, *http.Request, Params) - -// ServerHTTP calls f(w, r, params) -func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, params Params) { - f(w, r, params) -} +// MiddlewareFunc is a function type for implementing HandlerFunc middleware. +// The returned HandlerFunc usually wraps the input HandlerFunc, allowing you to perform operations +// before and/or after the wrapped HandlerFunc is executed. MiddlewareFunc functions should +// be thread-safe, as they will be called concurrently. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc // Router is a lightweight high performance HTTP request router that support mutation on its routing tree // while handling request concurrently. type Router struct { - // User-configurable http.Handler which is called when no matching route is found. - // By default, http.NotFound is used. - notFound http.Handler - - // User-configurable http.Handler which is called when the request cannot be routed, - // but the same route exist for other methods. The "Allow" header it automatically set - // before calling the handler. Set HandleMethodNotAllowed to true to enable this option. By default, - // http.Error with http.StatusMethodNotAllowed is used. - methodNotAllowed http.Handler - - // Register a function to handle panics recovered from http handlers. - panicHandler func(http.ResponseWriter, *http.Request, interface{}) - - tree atomic.Pointer[Tree] - - // If enabled, fox return a 405 Method Not Allowed instead of 404 Not Found when the route exist for another http verb. + noRoute HandlerFunc + noMethod HandlerFunc + tree atomic.Pointer[Tree] + mws []MiddlewareFunc handleMethodNotAllowed bool - - // Enable automatic redirection fallback when the current request does not match but another handler is found - // after cleaning up superfluous path elements (see CleanPath). E.g. /../foo/bar request does not match but /foo/bar would. - // The client is redirected with a http status code 301 for GET requests and 308 for all other methods. - redirectFixedPath bool - - // Enable automatic redirection fallback when the current request does not match but another handler is found - // with/without an additional trailing slash. E.g. /foo/bar/ request does not match but /foo/bar would match. - // The client is redirected with a http status code 301 for GET requests and 308 for all other methods. - redirectTrailingSlash bool - - // If enabled, the matched route will be accessible as a Handler parameter. - // Usage: p.Get(fox.RouteKey) - saveMatchedRoute bool + redirectFixedPath bool + redirectTrailingSlash bool } var _ http.Handler = (*Router)(nil) -// New returns a ready to use Router. +// New returns a ready to use instance of Fox router. func New(opts ...Option) *Router { r := new(Router) + + r.noRoute = func(c Context) { http.Error(c.Writer(), "404 page not found", http.StatusNotFound) } + r.noMethod = func(c Context) { + http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + for _, opt := range opts { opt.apply(r) } + r.tree.Store(r.NewTree()) return r } @@ -89,7 +73,8 @@ func New(opts ...Option) *Router { // This api is EXPERIMENTAL and is likely to change in future release. func (fox *Router) NewTree() *Tree { tree := new(Tree) - tree.saveRoute = fox.saveMatchedRoute + tree.mws = fox.mws + // Pre instantiate nodes for common http verb nds := make([]*node, len(commonVerbs)) for i := range commonVerbs { @@ -99,39 +84,54 @@ func (fox *Router) NewTree() *Tree { } tree.nodes.Store(&nds) - tree.pp = sync.Pool{ + tree.ctx = sync.Pool{ New: func() any { - params := make(Params, 0, tree.maxParams.Load()) - return ¶ms - }, - } - - tree.np = sync.Pool{ - New: func() any { - skpNds := make(skippedNodes, 0, tree.maxDepth.Load()) - return &skpNds + return tree.allocateContext() }, } return tree } -// Handler registers a new handler for the given method and path. This function return an error if the route +// Tree atomically loads and return the currently in-use routing tree. +// This API is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) Tree() *Tree { + return fox.tree.Load() +} + +// Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. +// This API is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) Swap(new *Tree) (old *Tree) { + return fox.tree.Swap(new) +} + +// Handle registers a new handler for the given method and path. This function return an error if the route // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. This function is safe for concurrent use by multiple goroutine. // To override an existing route, use Update. -func (fox *Router) Handler(method, path string, handler Handler) error { +func (fox *Router) Handle(method, path string, handler HandlerFunc) error { t := fox.Tree() t.Lock() defer t.Unlock() - return t.Handler(method, path, handler) + return t.Handle(method, path, handler) +} + +// MustHandle registers a new handler for the given method and path. This function is a convenience +// wrapper for the Handle function. It will panic if the route is already registered or conflicts +// with another route. It's perfectly safe to add a new handler while the tree is in use for serving +// requests. This function is safe for concurrent use by multiple goroutines. +// To override an existing route, use Update. +func (fox *Router) MustHandle(method, path string, handler HandlerFunc) { + if err := fox.Handle(method, path, handler); err != nil { + panic(err) + } } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. This function is safe for concurrent use by multiple goroutine. -// To add new handler, use Handler method. -func (fox *Router) Update(method, path string, handler Handler) error { +// To add new handler, use Handle method. +func (fox *Router) Update(method, path string, handler HandlerFunc) error { t := fox.Tree() t.Lock() defer t.Unlock() @@ -148,53 +148,20 @@ func (fox *Router) Remove(method, path string) error { return t.Remove(method, path) } -// Tree atomically loads and return the currently in-use routing tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Tree() *Tree { - return fox.tree.Load() -} - -// Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Swap(new *Tree) (old *Tree) { - return fox.tree.Swap(new) -} - -// Use atomically replaces the currently in-use routing tree with the provided new tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Use(new *Tree) { - fox.tree.Store(new) -} - -// Lookup allow to do manual lookup of a route and return the matched handler along with parsed params and -// trailing slash redirect recommendation. Note that you should always free Params if NOT nil by calling -// params.Free(t). If lazy is set to true, route params are not parsed. This function is safe for concurrent use -// by multiple goroutine and while mutation on Tree are ongoing. -func Lookup(t *Tree, method, path string, lazy bool) (handler Handler, params *Params, tsr bool) { - nds := t.load() - index := findRootNode(method, nds) - if index < 0 { - return nil, nil, false - } - - n, ps, tsr := t.lookup(nds[index], path, lazy) - if n != nil { - return n.handler, ps, tsr - } - return nil, nil, tsr -} - // Has allows to check if the given method and path exactly match a registered route. This function is safe for // concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Has(t *Tree, method, path string) bool { - nds := t.load() + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { return false } - n, _, _ := t.lookup(nds[index], path, true) + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() return n != nil && n.path == path } @@ -202,12 +169,16 @@ func Has(t *Tree, method, path string) bool { // This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Reverse(t *Tree, method, path string) string { - nds := t.load() + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { return "" } - n, _, _ := t.lookup(nds[index], path, true) + + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() if n == nil { return "" } @@ -219,15 +190,15 @@ func Reverse(t *Tree, method, path string) string { var SkipMethod = errors.New("skip method") // WalkFunc is the type of the function called by Walk to visit each registered routes. -type WalkFunc func(method, path string, handler Handler) error +type WalkFunc func(method, path string, handler HandlerFunc) error // Walk allow to walk over all registered route in lexicographical order. If the function // return the special value SkipMethod, Walk skips the current method. This function is // safe for concurrent use by multiple goroutine and while mutation are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Walk(tree *Tree, fn WalkFunc) error { - nds := tree.load() -NEXT: + nds := *tree.nodes.Load() +Next: for i := range nds { method := nds[i].key it := newRawIterator(nds[i]) @@ -235,7 +206,7 @@ NEXT: err := fn(method, it.path, it.current.handler) if err != nil { if errors.Is(err, SkipMethod) { - continue NEXT + continue Next } return err } @@ -246,34 +217,38 @@ NEXT: } func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if fox.panicHandler != nil { - defer fox.recover(w, r) - } var ( - n *node - params *Params - tsr bool + n *node + tsr bool ) tree := fox.Tree() - nds := tree.load() + + c := tree.ctx.Get().(*context) + c.reset(fox, w, r) + + nds := *tree.nodes.Load() index := findRootNode(r.Method, nds) if index < 0 { - goto NO_METHOD_FALLBACK + goto NoMethodFallback } - n, params, tsr = tree.lookup(nds[index], r.URL.Path, false) + n, tsr = tree.lookup(nds[index], r.URL.Path, c.params, c.skipNds, false) if n != nil { - if params != nil { - n.handler.ServeHTTP(w, r, *params) - params.Free(tree) - return + c.path = n.path + n.handler(c) + // Put back the context, if not extended more than max params or max depth, allowing + // the slice to naturally grow within the constraint. + if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) { + c.tree.ctx.Put(c) } - n.handler.ServeHTTP(w, r, nil) return } + // Reset params as it may have recorded wildcard segment + *c.params = (*c.params)[:0] + if r.Method != http.MethodConnect && r.URL.Path != "/" { code := http.StatusMovedPermanently @@ -285,32 +260,35 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tsr && fox.redirectTrailingSlash { r.URL.Path = fixTrailingSlash(r.URL.Path) http.Redirect(w, r, r.URL.String(), code) + c.Close() return } if fox.redirectFixedPath { cleanedPath := CleanPath(r.URL.Path) - n, _, tsr := tree.lookup(nds[index], cleanedPath, true) + n, tsr := tree.lookup(nds[index], cleanedPath, c.params, c.skipNds, true) if n != nil { r.URL.Path = cleanedPath http.Redirect(w, r, r.URL.String(), code) + c.Close() return } if tsr && fox.redirectTrailingSlash { r.URL.Path = fixTrailingSlash(cleanedPath) http.Redirect(w, r, r.URL.String(), code) + c.Close() return } } } -NO_METHOD_FALLBACK: +NoMethodFallback: if fox.handleMethodNotAllowed { var sb strings.Builder for i := 0; i < len(nds); i++ { if nds[i].key != r.Method { - if n, _, _ := tree.lookup(nds[i], r.URL.Path, true); n != nil { + if n, _ := tree.lookup(nds[i], r.URL.Path, c.params, c.skipNds, true); n != nil { if sb.Len() > 0 { sb.WriteString(", ") } @@ -321,29 +299,14 @@ NO_METHOD_FALLBACK: allowed := sb.String() if allowed != "" { w.Header().Set("Allow", allowed) - if fox.methodNotAllowed != nil { - fox.methodNotAllowed.ServeHTTP(w, r) - return - } - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + fox.noMethod(c) + c.Close() return } } - if fox.notFound != nil { - fox.notFound.ServeHTTP(w, r) - return - } - http.NotFound(w, r) -} - -func (fox *Router) recover(w http.ResponseWriter, r *http.Request) { - if val := recover(); val != nil { - if abortErr, ok := val.(error); ok && errors.Is(abortErr, http.ErrAbortHandler) { - panic(abortErr) - } - fox.panicHandler(w, r, val) - } + fox.noRoute(c) + c.Close() } type resultType int @@ -445,7 +408,7 @@ const ( func parseRoute(path string) (string, string, int, error) { if !strings.HasPrefix(path, "/") { - return "", "", -1, fmt.Errorf("path must start with '/': %w", ErrInvalidRoute) + return "", "", -1, fmt.Errorf("%w: path must start with '/'", ErrInvalidRoute) } state := stateDefault diff --git a/router_test.go b/fox_test.go similarity index 67% rename from router_test.go rename to fox_test.go index d671088..5a6e50a 100644 --- a/router_test.go +++ b/fox_test.go @@ -1,8 +1,11 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( "fmt" - "github.com/gin-gonic/gin" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,7 +21,9 @@ import ( "time" ) -var emptyHandler = HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) +var emptyHandler = HandlerFunc(func(c Context) {}) +var pathHanlder = HandlerFunc(func(c Context) { _, _ = c.Writer().Write([]byte(c.Request().URL.Path)) }) +var routeHandler = HandlerFunc(func(c Context) { _, _ = c.Writer().Write([]byte(c.Path())) }) type mockResponseWriter struct{} @@ -41,212 +46,12 @@ type route struct { path string } -var ginOverlappingRoutes = []route{ - {"GET", "/foo/abc/id_:id/xyz"}, - {"GET", "/foo/:name/id_:id/:name"}, - {"GET", "/foo/:name/id_:id/xyz"}, -} - var overlappingRoutes = []route{ {"GET", "/foo/abc/id:{id}/xyz"}, {"GET", "/foo/{name}/id:{id}/{name}"}, {"GET", "/foo/{name}/id:{id}/xyz"}, } -var ginGithubRoutes = []route{ - {"GET", "/repos/:owner/:repo/subscription"}, - {"PUT", "/repos/:owner/:repo/subscription"}, - {"DELETE", "/repos/:owner/:repo/subscription"}, - {"GET", "/user/subscriptions/:owner/:repo"}, - {"PUT", "/user/subscriptions/:owner/:repo"}, - {"DELETE", "/user/subscriptions/:owner/:repo"}, - - // Gists - {"GET", "/users/:user/gists"}, - {"GET", "/gists"}, - {"GET", "/gists/:id"}, - {"POST", "/gists"}, - {"PUT", "/gists/:id/star"}, - {"DELETE", "/gists/:id/star"}, - {"GET", "/gists/:id/star"}, - {"POST", "/gists/:id/forks"}, - {"DELETE", "/gists/:id"}, - - // Git Data - {"GET", "/repos/:owner/:repo/git/blobs/:sha"}, - {"POST", "/repos/:owner/:repo/git/blobs"}, - {"GET", "/repos/:owner/:repo/git/commits/:sha"}, - {"POST", "/repos/:owner/:repo/git/commits"}, - {"GET", "/repos/:owner/:repo/git/refs/*ref"}, - {"GET", "/repos/:owner/:repo/git/refs"}, - {"POST", "/repos/:owner/:repo/git/refs"}, - {"DELETE", "/repos/:owner/:repo/git/refs/*ref"}, - {"GET", "/repos/:owner/:repo/git/tags/:sha"}, - {"POST", "/repos/:owner/:repo/git/tags"}, - {"GET", "/repos/:owner/:repo/git/trees/:sha"}, - {"POST", "/repos/:owner/:repo/git/trees"}, - - // Issues - {"GET", "/issues"}, - {"GET", "/user/issues"}, - {"GET", "/orgs/:org/issues"}, - {"GET", "/repos/:owner/:repo/issues"}, - {"GET", "/repos/:owner/:repo/issues/:number"}, - {"POST", "/repos/:owner/:repo/issues"}, - {"GET", "/repos/:owner/:repo/assignees"}, - {"GET", "/repos/:owner/:repo/assignees/:assignee"}, - {"GET", "/repos/:owner/:repo/issues/:number/comments"}, - {"POST", "/repos/:owner/:repo/issues/:number/comments"}, - {"GET", "/repos/:owner/:repo/issues/:number/events"}, - {"GET", "/repos/:owner/:repo/labels"}, - {"GET", "/repos/:owner/:repo/labels/:name"}, - {"POST", "/repos/:owner/:repo/labels"}, - {"DELETE", "/repos/:owner/:repo/labels/:name"}, - {"GET", "/repos/:owner/:repo/issues/:number/labels"}, - {"POST", "/repos/:owner/:repo/issues/:number/labels"}, - {"DELETE", "/repos/:owner/:repo/issues/:number/labels/:name"}, - {"PUT", "/repos/:owner/:repo/issues/:number/labels"}, - {"DELETE", "/repos/:owner/:repo/issues/:number/labels"}, - {"GET", "/repos/:owner/:repo/milestones/:number/labels"}, - {"GET", "/repos/:owner/:repo/milestones"}, - {"GET", "/repos/:owner/:repo/milestones/:number"}, - {"POST", "/repos/:owner/:repo/milestones"}, - {"DELETE", "/repos/:owner/:repo/milestones/:number"}, - - // Miscellaneous - {"GET", "/emojis"}, - {"GET", "/gitignore/templates"}, - {"GET", "/gitignore/templates/:name"}, - {"POST", "/markdown"}, - {"POST", "/markdown/raw"}, - {"GET", "/meta"}, - {"GET", "/rate_limit"}, - - // Organizations - {"GET", "/users/:user/orgs"}, - {"GET", "/user/orgs"}, - {"GET", "/orgs/:org"}, - {"GET", "/orgs/:org/members"}, - {"GET", "/orgs/:org/members/:user"}, - {"DELETE", "/orgs/:org/members/:user"}, - {"GET", "/orgs/:org/public_members"}, - {"GET", "/orgs/:org/public_members/:user"}, - {"PUT", "/orgs/:org/public_members/:user"}, - {"DELETE", "/orgs/:org/public_members/:user"}, - {"GET", "/orgs/:org/teams"}, - {"GET", "/teams/:id"}, - {"POST", "/orgs/:org/teams"}, - {"DELETE", "/teams/:id"}, - {"GET", "/teams/:id/members"}, - {"GET", "/teams/:id/members/:user"}, - {"PUT", "/teams/:id/members/:user"}, - {"DELETE", "/teams/:id/members/:user"}, - {"GET", "/teams/:id/repos"}, - {"GET", "/teams/:id/repos/:owner/:repo"}, - {"PUT", "/teams/:id/repos/:owner/:repo"}, - {"DELETE", "/teams/:id/repos/:owner/:repo"}, - {"GET", "/user/teams"}, - - // Pull Requests - {"GET", "/repos/:owner/:repo/pulls"}, - {"GET", "/repos/:owner/:repo/pulls/:number"}, - {"POST", "/repos/:owner/:repo/pulls"}, - {"GET", "/repos/:owner/:repo/pulls/:number/commits"}, - {"GET", "/repos/:owner/:repo/pulls/:number/files"}, - {"GET", "/repos/:owner/:repo/pulls/:number/merge"}, - {"PUT", "/repos/:owner/:repo/pulls/:number/merge"}, - {"GET", "/repos/:owner/:repo/pulls/:number/comments"}, - {"PUT", "/repos/:owner/:repo/pulls/:number/comments"}, - - // Repositories - {"GET", "/user/repos"}, - {"GET", "/users/:user/repos"}, - {"GET", "/orgs/:org/repos"}, - {"GET", "/repositories"}, - {"POST", "/user/repos"}, - {"POST", "/orgs/:org/repos"}, - {"GET", "/repos/:owner/:repo"}, - {"GET", "/repos/:owner/:repo/contributors"}, - {"GET", "/repos/:owner/:repo/languages"}, - {"GET", "/repos/:owner/:repo/teams"}, - {"GET", "/repos/:owner/:repo/tags"}, - {"GET", "/repos/:owner/:repo/branches"}, - {"GET", "/repos/:owner/:repo/branches/:branch"}, - {"DELETE", "/repos/:owner/:repo"}, - {"GET", "/repos/:owner/:repo/collaborators"}, - {"GET", "/repos/:owner/:repo/collaborators/:user"}, - {"PUT", "/repos/:owner/:repo/collaborators/:user"}, - {"DELETE", "/repos/:owner/:repo/collaborators/:user"}, - {"GET", "/repos/:owner/:repo/comments"}, - {"GET", "/repos/:owner/:repo/commits/:sha/comments"}, - {"POST", "/repos/:owner/:repo/commits/:sha/comments"}, - {"GET", "/repos/:owner/:repo/comments/:id"}, - {"DELETE", "/repos/:owner/:repo/comments/:id"}, - {"GET", "/repos/:owner/:repo/commits"}, - {"GET", "/repos/:owner/:repo/commits/:sha"}, - {"GET", "/repos/:owner/:repo/readme"}, - {"GET", "/repos/:owner/:repo/contents/*path"}, - {"DELETE", "/repos/:owner/:repo/contents/*path"}, - {"GET", "/repos/:owner/:repo/keys"}, - {"GET", "/repos/:owner/:repo/keys/:id"}, - {"POST", "/repos/:owner/:repo/keys"}, - {"DELETE", "/repos/:owner/:repo/keys/:id"}, - {"GET", "/repos/:owner/:repo/downloads"}, - {"GET", "/repos/:owner/:repo/downloads/:id"}, - {"DELETE", "/repos/:owner/:repo/downloads/:id"}, - {"GET", "/repos/:owner/:repo/forks"}, - {"POST", "/repos/:owner/:repo/forks"}, - {"GET", "/repos/:owner/:repo/hooks"}, - {"GET", "/repos/:owner/:repo/hooks/:id"}, - {"POST", "/repos/:owner/:repo/hooks"}, - {"POST", "/repos/:owner/:repo/hooks/:id/tests"}, - {"DELETE", "/repos/:owner/:repo/hooks/:id"}, - {"POST", "/repos/:owner/:repo/merges"}, - {"GET", "/repos/:owner/:repo/releases"}, - {"GET", "/repos/:owner/:repo/releases/:id"}, - {"POST", "/repos/:owner/:repo/releases"}, - {"DELETE", "/repos/:owner/:repo/releases/:id"}, - {"GET", "/repos/:owner/:repo/releases/:id/assets"}, - {"GET", "/repos/:owner/:repo/stats/contributors"}, - {"GET", "/repos/:owner/:repo/stats/commit_activity"}, - {"GET", "/repos/:owner/:repo/stats/code_frequency"}, - {"GET", "/repos/:owner/:repo/stats/participation"}, - {"GET", "/repos/:owner/:repo/stats/punch_card"}, - {"GET", "/repos/:owner/:repo/statuses/:ref"}, - {"POST", "/repos/:owner/:repo/statuses/:ref"}, - - // Search - {"GET", "/search/repositories"}, - {"GET", "/search/code"}, - {"GET", "/search/issues"}, - {"GET", "/search/users"}, - {"GET", "/legacy/issues/search/:owner/:repository/:state/:keyword"}, - {"GET", "/legacy/repos/search/:keyword"}, - {"GET", "/legacy/user/search/:keyword"}, - {"GET", "/legacy/user/email/:email"}, - - // Users - {"GET", "/users/:user"}, - {"GET", "/user"}, - {"GET", "/users"}, - {"GET", "/user/emails"}, - {"POST", "/user/emails"}, - {"DELETE", "/user/emails"}, - {"GET", "/users/:user/followers"}, - {"GET", "/user/followers"}, - {"GET", "/users/:user/following"}, - {"GET", "/user/following"}, - {"GET", "/user/following/:user"}, - {"GET", "/users/:user/following/:target_user"}, - {"PUT", "/user/following/:user"}, - {"DELETE", "/user/following/:user"}, - {"GET", "/users/:user/keys"}, - {"GET", "/user/keys"}, - {"GET", "/user/keys/:id"}, - {"POST", "/user/keys"}, - {"DELETE", "/user/keys/:id"}, -} - // From https://github.com/julienschmidt/go-http-routing-benchmark var staticRoutes = []route{ {"GET", "/"}, @@ -677,64 +482,16 @@ func benchRouteParallel(b *testing.B, router http.Handler, rte route) { func BenchmarkStaticAll(b *testing.B) { r := New() for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) - } - - benchRoutes(b, r, staticRoutes) -} - -func BenchmarkGinStaticAll(b *testing.B) { - gin.SetMode(gin.ReleaseMode) - r := gin.New() - for _, route := range staticRoutes { - r.GET(route.path, func(context *gin.Context) {}) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } benchRoutes(b, r, staticRoutes) } -func BenchmarkLookup(b *testing.B) { - r := New() - for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) - } - - b.ReportAllocs() - b.ResetTimer() - - tree := r.Tree() - for i := 0; i < b.N; i++ { - for _, route := range staticRoutes { - _, p, _ := Lookup(tree, route.method, route.path, false) - if p != nil { - p.Free(tree) - } - } - } -} - func BenchmarkGithubParamsAll(b *testing.B) { r := New() for _, route := range githubAPI { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) - } - - req := httptest.NewRequest("GET", "/repos/sylvain/fox/hooks/1500", nil) - w := new(mockResponseWriter) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.ServeHTTP(w, req) - } -} - -func BenchmarkGinGithubParamsAll(b *testing.B) { - gin.SetMode(gin.ReleaseMode) - r := gin.New() - for _, route := range ginGithubRoutes { - r.Handle(route.method, route.path, func(context *gin.Context) {}) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } req := httptest.NewRequest("GET", "/repos/sylvain/fox/hooks/1500", nil) @@ -751,7 +508,7 @@ func BenchmarkGinGithubParamsAll(b *testing.B) { func BenchmarkOverlappingRoute(b *testing.B) { r := New() for _, route := range overlappingRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } req := httptest.NewRequest("GET", "/foo/abc/id:123/xy", nil) @@ -765,36 +522,17 @@ func BenchmarkOverlappingRoute(b *testing.B) { } } -func BenchmarkGinOverlappingRoute(b *testing.B) { - gin.SetMode(gin.ReleaseMode) - r := gin.New() - - for _, route := range ginOverlappingRoutes { - r.Handle(route.method, route.path, func(context *gin.Context) {}) - } - - req := httptest.NewRequest("GET", "/foo/abc/id_123/xy", nil) - w := new(mockResponseWriter) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.ServeHTTP(w, req) - } -} - func BenchmarkStaticParallel(b *testing.B) { r := New() for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(_ http.ResponseWriter, _ *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } benchRouteParallel(b, r, route{"GET", "/progs/image_package4.out"}) } func BenchmarkCatchAll(b *testing.B) { r := New() - require.NoError(b, r.Tree().Handler(http.MethodGet, "/something/*{args}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(http.MethodGet, "/something/*{args}", emptyHandler)) w := new(mockResponseWriter) req := httptest.NewRequest("GET", "/something/awesome", nil) @@ -808,7 +546,7 @@ func BenchmarkCatchAll(b *testing.B) { func BenchmarkCatchAllParallel(b *testing.B) { r := New() - require.NoError(b, r.Tree().Handler(http.MethodGet, "/something/*{args}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(http.MethodGet, "/something/*{args}", emptyHandler)) w := new(mockResponseWriter) req := httptest.NewRequest("GET", "/something/awesome", nil) @@ -824,10 +562,9 @@ func BenchmarkCatchAllParallel(b *testing.B) { func TestStaticRoute(t *testing.T) { r := New() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { _, _ = w.Write([]byte(r.URL.Path)) }) for _, route := range staticRoutes { - require.NoError(t, r.Tree().Handler(route.method, route.path, h)) + require.NoError(t, r.Tree().Handle(route.method, route.path, pathHanlder)) } for _, route := range staticRoutes { @@ -839,11 +576,26 @@ func TestStaticRoute(t *testing.T) { } } +func TestStaticRouteMalloc(t *testing.T) { + r := New() + + for _, route := range staticRoutes { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + + for _, route := range staticRoutes { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + func TestParamsRoute(t *testing.T) { - rx := regexp.MustCompile("({|\\*{)[A-z_]+[}]") - r := New(WithSaveMatchedRoute(true)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - matches := rx.FindAllString(r.URL.Path, -1) + rx := regexp.MustCompile("({|\\*{)[A-z]+[}]") + r := New() + h := func(c Context) { + matches := rx.FindAllString(c.Request().URL.Path, -1) for _, match := range matches { var key string if strings.HasPrefix(match, "*") { @@ -852,13 +604,13 @@ func TestParamsRoute(t *testing.T) { key = match[1 : len(match)-1] } value := match - assert.Equal(t, value, params.Get(key)) + assert.Equal(t, value, c.Param(key)) } - assert.Equal(t, r.URL.Path, params.Get(RouteKey)) - _, _ = w.Write([]byte(r.URL.Path)) - }) + assert.Equal(t, c.Request().URL.Path, c.Path()) + _, _ = c.Writer().Write([]byte(c.Request().URL.Path)) + } for _, route := range githubAPI { - require.NoError(t, r.Tree().Handler(route.method, route.path, h)) + require.NoError(t, r.Tree().Handle(route.method, route.path, h)) } for _, route := range githubAPI { req := httptest.NewRequest(route.method, route.path, nil) @@ -869,9 +621,34 @@ func TestParamsRoute(t *testing.T) { } } +func TestParamsRouteMalloc(t *testing.T) { + r := New() + for _, route := range githubAPI { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + for _, route := range githubAPI { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + +func TestOverlappingRouteMalloc(t *testing.T) { + r := New() + for _, route := range overlappingRoutes { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + for _, route := range overlappingRoutes { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + func TestRouterWildcard(t *testing.T) { r := New() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { _, _ = w.Write([]byte(r.URL.Path)) }) routes := []struct { path string @@ -884,7 +661,7 @@ func TestRouterWildcard(t *testing.T) { } for _, route := range routes { - require.NoError(t, r.Tree().Handler(http.MethodGet, route.path, h)) + require.NoError(t, r.Tree().Handle(http.MethodGet, route.path, pathHanlder)) } for _, route := range routes { @@ -915,12 +692,13 @@ func TestRouteWithParams(t *testing.T) { "/info/{user}/project/{project}", } for _, rte := range routes { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) } - nds := tree.load() + nds := *tree.nodes.Load() for _, rte := range routes { - n, _, _ := tree.lookup(nds[0], rte, false) + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], rte, c.params, c.skipNds, false) require.NotNil(t, n) assert.Equal(t, rte, n.path) } @@ -951,22 +729,23 @@ func TestRoutParamEmptySegment(t *testing.T) { } for _, tc := range cases { - require.NoError(t, tree.Handler(http.MethodGet, tc.route, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, tc.route, emptyHandler)) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - nds := tree.load() - n, ps, tsr := tree.lookup(nds[0], tc.path, false) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + n, tsr := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) assert.Nil(t, n) - assert.Nil(t, ps) + assert.Empty(t, c.Params()) assert.False(t, tsr) }) } } func TestOverlappingRoute(t *testing.T) { - r := New(WithSaveMatchedRoute(true)) + r := New() cases := []struct { name string path string @@ -981,8 +760,7 @@ func TestOverlappingRoute(t *testing.T) { "/products/{id}", "/products/new", }, - wantMatch: "/products/new", - wantParams: Params{{RouteKey, "/products/new"}}, + wantMatch: "/products/new", }, { name: "basic test less specific", @@ -992,7 +770,7 @@ func TestOverlappingRoute(t *testing.T) { "/products/new", }, wantMatch: "/products/{id}", - wantParams: Params{{Key: "id", Value: "123"}, {RouteKey, "/products/{id}"}}, + wantParams: Params{{Key: "id", Value: "123"}}, }, { name: "ieof+backtrack to {id} wildcard while deleting {a}", @@ -1017,10 +795,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "name", Value: "barr", }, - { - Key: RouteKey, - Value: "/{base}/val1/{id}/new/{name}", - }, }, }, { @@ -1046,10 +820,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "name", Value: "ba", }, - { - Key: RouteKey, - Value: "/{base}/val1/{id}/new/{name}", - }, }, }, { @@ -1075,10 +845,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "name", Value: "bx", }, - { - Key: RouteKey, - Value: "/{base}/val1/{id}/new/{name}", - }, }, }, { @@ -1100,10 +866,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "all", Value: "1/123/new/bar/", }, - { - Key: RouteKey, - Value: "/{base}/val*{all}", - }, }, }, { @@ -1125,10 +887,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "all", Value: "1/123/new", }, - { - Key: RouteKey, - Value: "/{base}/val*{all}", - }, }, }, { @@ -1155,10 +913,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "de", Value: "3", }, - { - Key: RouteKey, - Value: "/foo/{ab}/{bc}/{de}/bar", - }, }, }, { @@ -1189,10 +943,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "fg", Value: "john", }, - { - Key: RouteKey, - Value: "/foo/{ab}/{bc}/{de}/{fg}", - }, }, }, { @@ -1208,10 +958,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "name", Value: "abc", }, - { - Key: RouteKey, - Value: "/foo/{name}/bar", - }, }, }, { @@ -1228,10 +974,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "id", Value: "123", }, - { - Key: RouteKey, - Value: "/foo/{id}", - }, }, }, { @@ -1248,10 +990,6 @@ func TestOverlappingRoute(t *testing.T) { Key: "args", Value: "bc", }, - { - Key: RouteKey, - Value: "/foo/a*{args}", - }, }, }, } @@ -1260,29 +998,34 @@ func TestOverlappingRoute(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := r.NewTree() for _, rte := range tc.routes { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) } + nds := *tree.nodes.Load() - nds := tree.load() - n, ps, _ := tree.lookup(nds[0], tc.path, false) + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) require.NotNil(t, n) require.NotNil(t, n.handler) - if ps != nil { - defer ps.Free(tree) - } assert.Equal(t, tc.wantMatch, n.path) if len(tc.wantParams) == 0 { - assert.Nil(t, ps) - return + assert.Empty(t, c.Params()) + } else { + assert.Equal(t, tc.wantParams, c.Params()) } - assert.Equal(t, tc.wantParams, *ps) + // Test with lazy + c = newTestContextTree(tree) + n, _ = tree.lookup(nds[0], tc.path, c.params, c.skipNds, true) + require.NotNil(t, n) + require.NotNil(t, n.handler) + assert.Empty(t, c.Params()) + assert.Equal(t, tc.wantMatch, n.path) }) } } -func TestRouteConflict(t *testing.T) { +func TestInsertConflict(t *testing.T) { cases := []struct { name string routes []struct { @@ -1300,6 +1043,7 @@ func TestRouteConflict(t *testing.T) { }{ {path: "/john/*{x}", wantErr: nil, wantMatch: nil}, {path: "/john/*{y}", wantErr: ErrRouteConflict, wantMatch: []string{"/john/*{x}"}}, + {path: "/john/", wantErr: ErrRouteExist, wantMatch: nil}, {path: "/foo/baz", wantErr: nil, wantMatch: nil}, {path: "/foo/bar", wantErr: nil, wantMatch: nil}, {path: "/foo/{id}", wantErr: nil, wantMatch: nil}, @@ -1308,6 +1052,8 @@ func TestRouteConflict(t *testing.T) { {path: "/avengers/{id}/bar", wantErr: nil, wantMatch: nil}, {path: "/avengers/{id}/foo", wantErr: nil, wantMatch: nil}, {path: "/avengers/*{args}", wantErr: ErrRouteConflict, wantMatch: []string{"/avengers/{id}/bar", "/avengers/{id}/foo"}}, + {path: "/fox/", wantErr: nil, wantMatch: nil}, + {path: "/fox/*{args}", wantErr: ErrRouteExist, wantMatch: nil}, }, }, { @@ -1366,7 +1112,7 @@ func TestRouteConflict(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, rte := range tc.routes { - err := tree.Handler(http.MethodGet, rte.path, emptyHandler) + err := tree.Handle(http.MethodGet, rte.path, emptyHandler) assert.ErrorIs(t, err, rte.wantErr) if cErr, ok := err.(*RouteConflictError); ok { assert.Equal(t, rte.wantMatch, cErr.Matched) @@ -1376,54 +1122,52 @@ func TestRouteConflict(t *testing.T) { } } -func TestSwapWildcardConflict(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) +func TestUpdateConflict(t *testing.T) { cases := []struct { - wantErr error - name string - path string - routes []struct { - path string - wildcard bool - } + name string + routes []string + update string + wantErr error wantMatch []string - wildcard string }{ { - name: "replace existing static node with wildcard", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/bar", wildcard: false}, - {path: "/foo/baz", wildcard: false}, - {path: "/foo/", wildcard: false}, - }, - path: "/foo/", - wildcard: "args", + name: "wildcard parameter route not registered", + routes: []string{"/foo/{bar}"}, + update: "/foo/{baz}", + wantErr: ErrRouteNotFound, + }, + { + name: "wildcard catch all route not registered", + routes: []string{"/foo/{bar}"}, + update: "/foo/*{baz}", + wantErr: ErrRouteNotFound, + }, + { + name: "route match but not a leaf", + routes: []string{"/foo/bar/baz"}, + update: "/foo/bar", + wantErr: ErrRouteNotFound, + }, + { + name: "wildcard have different name", + routes: []string{"/foo/bar", "/foo/*{args}"}, + update: "/foo/*{all}", wantErr: ErrRouteConflict, - wantMatch: []string{"/foo/bar", "/foo/baz"}, + wantMatch: []string{"/foo/*{args}"}, }, { - name: "replace existing wildcard node with static", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/", wildcard: true}, - }, - path: "/foo/", + name: "replacing non wildcard by wildcard", + routes: []string{"/foo/bar", "/foo/"}, + update: "/foo/*{all}", + wantErr: ErrRouteConflict, + wantMatch: []string{"/foo/"}, }, { - name: "replace existing wildcard node with another wildcard", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/", wildcard: true}, - }, - path: "/foo/", - wildcard: "new", + name: "replacing wildcard by non wildcard", + routes: []string{"/foo/bar", "/foo/*{args}"}, + update: "/foo/", + wantErr: ErrRouteConflict, + wantMatch: []string{"/foo/*{args}"}, }, } @@ -1431,13 +1175,9 @@ func TestSwapWildcardConflict(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, rte := range tc.routes { - var catchAllKey string - if rte.wildcard { - catchAllKey = "args" - } - require.NoError(t, tree.insert(http.MethodGet, rte.path, catchAllKey, 0, h)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) } - err := tree.update(http.MethodGet, tc.path, tc.wildcard, h) + err := tree.Update(http.MethodGet, tc.update, emptyHandler) assert.ErrorIs(t, err, tc.wantErr) if cErr, ok := err.(*RouteConflictError); ok { assert.Equal(t, tc.wantMatch, cErr.Matched) @@ -1447,63 +1187,55 @@ func TestSwapWildcardConflict(t *testing.T) { } func TestUpdateRoute(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }) - cases := []struct { - newHandler Handler - name string - path string - newPath string - newWildcardKey string + name string + routes []string + update string }{ { - name: "update wildcard with another wildcard", - path: "/foo/bar/*{args}", - newPath: "/foo/bar/", - newWildcardKey: "*{new}", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(params.Get(RouteKey))) - }), + name: "replacing ending static node", + routes: []string{"/foo/", "/foo/bar", "/foo/baz"}, + update: "/foo/bar", }, { - name: "update wildcard with non wildcard", - path: "/foo/bar/*{args}", - newPath: "/foo/bar/", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }), + name: "replacing middle static node", + routes: []string{"/foo/", "/foo/bar", "/foo/baz"}, + update: "/foo/", }, { - name: "update non wildcard with wildcard", - path: "/foo/bar/", - newPath: "/foo/bar/", - newWildcardKey: "*{foo}", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(params.Get(RouteKey))) - }), + name: "replacing ending wildcard node", + routes: []string{"/foo/", "/foo/bar", "/foo/{baz}"}, + update: "/foo/{baz}", }, { - name: "update non wildcard with non wildcard", - path: "/foo/bar", - newPath: "/foo/bar", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }), + name: "replacing ending inflight wildcard node", + routes: []string{"/foo/", "/foo/bar_xyz", "/foo/bar_{baz}"}, + update: "/foo/bar_{baz}", + }, + { + name: "replacing middle wildcard node", + routes: []string{"/foo/{bar}", "/foo/{bar}/baz", "/foo/{bar}/xyz"}, + update: "/foo/{bar}", + }, + { + name: "replacing middle inflight wildcard node", + routes: []string{"/foo/id:{bar}", "/foo/id:{bar}/baz", "/foo/id:{bar}/xyz"}, + update: "/foo/id:{bar}", + }, + { + name: "replacing catch all node", + routes: []string{"/foo/*{bar}", "/foo", "/foo/bar"}, + update: "/foo/*{bar}", }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - r := New(WithSaveMatchedRoute(true)) - require.NoError(t, r.Tree().Handler(http.MethodGet, tc.path, h)) - require.NoError(t, r.Tree().Update(http.MethodGet, tc.newPath+tc.newWildcardKey, tc.newHandler)) - req := httptest.NewRequest(http.MethodGet, tc.newPath, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - require.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, tc.newPath+tc.newWildcardKey, w.Body.String()) + tree := New().Tree() + for _, rte := range tc.routes { + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + } + assert.NoError(t, tree.Update(http.MethodGet, tc.update, emptyHandler)) }) } } @@ -1648,8 +1380,6 @@ func TestParseRoute(t *testing.T) { } func TestTree_LookupTsr(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) - cases := []struct { name string path string @@ -1693,16 +1423,16 @@ func TestTree_LookupTsr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() - require.NoError(t, tree.insert(http.MethodGet, tc.path, "", 0, h)) - nds := tree.load() - _, _, got := tree.lookup(nds[0], tc.key, true) + require.NoError(t, tree.insert(http.MethodGet, tc.path, "", 0, emptyHandler)) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + _, got := tree.lookup(nds[0], tc.key, c.params, c.skipNds, true) assert.Equal(t, tc.want, got) }) } } func TestRedirectTrailingSlash(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string @@ -1772,7 +1502,7 @@ func TestRedirectTrailingSlash(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := New(WithRedirectTrailingSlash(true)) - require.NoError(t, r.Tree().Handler(tc.method, tc.path, h)) + require.NoError(t, r.Tree().Handle(tc.method, tc.path, emptyHandler)) req := httptest.NewRequest(tc.method, tc.key, nil) w := httptest.NewRecorder() @@ -1787,7 +1517,6 @@ func TestRedirectTrailingSlash(t *testing.T) { } func TestRedirectFixedPath(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string path string @@ -1825,7 +1554,7 @@ func TestRedirectFixedPath(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := New(WithRedirectFixedPath(true), WithRedirectTrailingSlash(tc.tsr)) - require.NoError(t, r.Tree().Handler(http.MethodGet, tc.path, h)) + require.NoError(t, r.Tree().Handle(http.MethodGet, tc.path, emptyHandler)) req, _ := http.NewRequest(http.MethodGet, tc.key, nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -1839,12 +1568,11 @@ func TestRedirectFixedPath(t *testing.T) { func TestTree_Remove(t *testing.T) { tree := New().Tree() - routes := make([]route, len(githubAPI)) copy(routes, githubAPI) for _, rte := range routes { - require.NoError(t, tree.Handler(rte.method, rte.path, emptyHandler)) + require.NoError(t, tree.Handle(rte.method, rte.path, emptyHandler)) } rand.Shuffle(len(routes), func(i, j int) { routes[i], routes[j] = routes[j], routes[i] }) @@ -1854,17 +1582,24 @@ func TestTree_Remove(t *testing.T) { } cnt := 0 - _ = Walk(tree, func(method, path string, handler Handler) error { + _ = Walk(tree, func(method, path string, handler HandlerFunc) error { cnt++ return nil }) assert.Equal(t, 0, cnt) + assert.Equal(t, 4, len(*tree.nodes.Load())) +} + +func TestTree_RemoveRoot(t *testing.T) { + tree := New().Tree() + require.NoError(t, tree.Handle(http.MethodOptions, "/foo/bar", emptyHandler)) + require.NoError(t, tree.Remove(http.MethodOptions, "/foo/bar")) + assert.Equal(t, 4, len(*tree.nodes.Load())) } func TestRouterWithAllowedMethod(t *testing.T) { r := New(WithHandleMethodNotAllowed(true)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string @@ -1899,7 +1634,7 @@ func TestRouterWithAllowedMethod(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { - require.NoError(t, r.Tree().Handler(method, tc.path, h)) + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) } req := httptest.NewRequest(tc.target, tc.path, nil) w := httptest.NewRecorder() @@ -1910,19 +1645,21 @@ func TestRouterWithAllowedMethod(t *testing.T) { } } -func TestPanicHandler(t *testing.T) { - r := New(WithPanicHandler(func(w http.ResponseWriter, r *http.Request, i interface{}) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(i.(string))) - })) +func TestRecoveryMiddleware(t *testing.T) { + m := Recovery(func(c Context, err any) { + c.Writer().WriteHeader(http.StatusInternalServerError) + _, _ = c.Writer().Write([]byte(err.(string))) + }) + + r := New(WithMiddleware(m)) const errMsg = "unexpected error" - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { + h := func(c Context) { func() { panic(errMsg) }() - w.Write([]byte("foo")) - }) + _, _ = c.Writer().Write([]byte("foo")) + } - require.NoError(t, r.Tree().Handler(http.MethodPost, "/", h)) + require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) req := httptest.NewRequest(http.MethodPost, "/", nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -1939,7 +1676,7 @@ func TestHas(t *testing.T) { r := New() for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) } cases := []struct { @@ -1992,7 +1729,7 @@ func TestReverse(t *testing.T) { r := New() for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) } cases := []struct { @@ -2028,103 +1765,20 @@ func TestReverse(t *testing.T) { } } -func TestLookup(t *testing.T) { - routes := []string{ - "/foo/bar", - "/welcome/{name}", - "/users/uid_{id}", - "/john/doe/", - } - - r := New() - for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) - } - - cases := []struct { - name string - path string - paramKey string - wantHandler bool - wantParamValue string - wantTsr bool - }{ - { - name: "matching static route", - path: "/foo/bar", - wantHandler: true, - }, - { - name: "tsr remove slash for static route", - path: "/foo/bar/", - wantTsr: true, - }, - { - name: "tsr add slash for static route", - path: "/john/doe", - wantTsr: true, - }, - { - name: "tsr for static route", - path: "/foo/bar/", - wantTsr: true, - }, - { - name: "matching params route", - path: "/welcome/fox", - wantHandler: true, - paramKey: "name", - wantParamValue: "fox", - }, - { - name: "tsr for params route", - path: "/welcome/fox/", - wantTsr: true, - }, - { - name: "matching mid route params", - path: "/users/uid_123", - wantHandler: true, - paramKey: "id", - wantParamValue: "123", - }, - { - name: "matching mid route params", - path: "/users/uid_123/", - wantTsr: true, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - handler, params, tsr := Lookup(r.Tree(), http.MethodGet, tc.path, false) - if params != nil { - defer params.Free(r.Tree()) - } - if tc.wantHandler { - assert.NotNil(t, handler) - } - assert.Equal(t, tc.wantTsr, tsr) - if tc.paramKey != "" { - require.NotNil(t, params) - assert.Equal(t, tc.wantParamValue, params.Get(tc.paramKey)) - } - }) - } -} - func TestAbortHandler(t *testing.T) { - r := New(WithPanicHandler(func(w http.ResponseWriter, r *http.Request, i interface{}) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(i.(error).Error())) - })) + m := Recovery(func(c Context, err any) { + c.Writer().WriteHeader(http.StatusInternalServerError) + _, _ = c.Writer().Write([]byte(err.(error).Error())) + }) + + r := New(WithMiddleware(m)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { + h := func(c Context) { func() { panic(http.ErrAbortHandler) }() - w.Write([]byte("foo")) - }) + _, _ = c.Writer().Write([]byte("foo")) + } - require.NoError(t, r.Tree().Handler(http.MethodPost, "/", h)) + require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) req := httptest.NewRequest(http.MethodPost, "/", nil) w := httptest.NewRecorder() @@ -2149,7 +1803,6 @@ func TestFuzzInsertLookupParam(t *testing.T) { } tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) f := fuzz.New().NilChance(0).Funcs(unicodeRanges.CustomStringFuzzFunc()) routeFormat := "/%s/{%s}/%s/{%s}/{%s}" reqFormat := "/%s/%s/%s/%s/%s" @@ -2163,15 +1816,16 @@ func TestFuzzInsertLookupParam(t *testing.T) { if s1 == "" || s2 == "" || e1 == "" || e2 == "" || e3 == "" { continue } - if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, h); err == nil { - nds := tree.load() + if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, emptyHandler); err == nil { + nds := *tree.nodes.Load() - n, params, _ := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), false) + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c.params, c.skipNds, false) require.NotNil(t, n) assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path) - assert.Equal(t, "xxxx", params.Get(e1)) - assert.Equal(t, "xxxx", params.Get(e2)) - assert.Equal(t, "xxxx", params.Get(e3)) + assert.Equal(t, "xxxx", c.Param(e1)) + assert.Equal(t, "xxxx", c.Param(e2)) + assert.Equal(t, "xxxx", c.Param(e3)) } } } @@ -2179,7 +1833,6 @@ func TestFuzzInsertLookupParam(t *testing.T) { func TestFuzzInsertNoPanics(t *testing.T) { f := fuzz.New().NilChance(0).NumElements(5000, 10000) tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) routes := make(map[string]struct{}) f.Fuzz(&routes) @@ -2191,7 +1844,7 @@ func TestFuzzInsertNoPanics(t *testing.T) { continue } require.NotPanicsf(t, func() { - _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, h) + _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, emptyHandler) }, fmt.Sprintf("rte: %s, catch all: %s", rte, catchAllKey)) } } @@ -2207,30 +1860,30 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { f := fuzz.New().NilChance(0).NumElements(1000, 2000).Funcs(unicodeRanges.CustomStringFuzzFunc()) tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) routes := make(map[string]struct{}) f.Fuzz(&routes) for rte := range routes { - err := tree.insert(http.MethodGet, "/"+rte, "", 0, h) + err := tree.insert(http.MethodGet, "/"+rte, "", 0, emptyHandler) require.NoError(t, err) } countPath := 0 - require.NoError(t, Walk(tree, func(method, path string, handler Handler) error { + require.NoError(t, Walk(tree, func(method, path string, handler HandlerFunc) error { countPath++ return nil })) assert.Equal(t, len(routes), countPath) for rte := range routes { - nds := tree.load() - n, _, _ := tree.lookup(nds[0], "/"+rte, true) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], "/"+rte, c.params, c.skipNds, true) require.NotNilf(t, n, "route /%s", rte) require.Truef(t, n.isLeaf(), "route /%s", rte) require.Equal(t, "/"+rte, n.path) - require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", h)) + require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler)) } for rte := range routes { @@ -2239,7 +1892,7 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { } countPath = 0 - require.NoError(t, Walk(tree, func(method, path string, handler Handler) error { + require.NoError(t, Walk(tree, func(method, path string, handler HandlerFunc) error { countPath++ return nil })) @@ -2250,8 +1903,8 @@ func TestDataRace(t *testing.T) { var wg sync.WaitGroup start, wait := atomicSync() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) - newH := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) + h := HandlerFunc(func(c Context) {}) + newH := HandlerFunc(func(c Context) {}) r := New() @@ -2269,8 +1922,8 @@ func TestDataRace(t *testing.T) { assert.NoError(t, tree.Update(method, route, h)) return } - assert.NoError(t, tree.Handler(method, route, h)) - // assert.NoError(t, r.Handler("PING", route, h)) + assert.NoError(t, tree.Handle(method, route, h)) + // assert.NoError(t, r.Handle("PING", route, h)) }(rte.method, rte.path) go func(method, route string) { @@ -2283,7 +1936,7 @@ func TestDataRace(t *testing.T) { assert.NoError(t, tree.Remove(method, route)) return } - assert.NoError(t, tree.Handler(method, route, newH)) + assert.NoError(t, tree.Handle(method, route, newH)) }(rte.method, rte.path) go func(method, route string) { @@ -2300,32 +1953,32 @@ func TestDataRace(t *testing.T) { } func TestConcurrentRequestHandling(t *testing.T) { - r := New(WithSaveMatchedRoute(true)) + r := New() // /repos/{owner}/{repo}/keys - h1 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "john", params.Get("owner")) - assert.Equal(t, "fox", params.Get("repo")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h1 := HandlerFunc(func(c Context) { + assert.Equal(t, "john", c.Param("owner")) + assert.Equal(t, "fox", c.Param("repo")) + _, _ = fmt.Fprint(c.Writer(), c.Path()) }) // /repos/{owner}/{repo}/contents/*{path} - h2 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "alex", params.Get("owner")) - assert.Equal(t, "vault", params.Get("repo")) - assert.Equal(t, "file.txt", params.Get("path")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h2 := HandlerFunc(func(c Context) { + assert.Equal(t, "alex", c.Param("owner")) + assert.Equal(t, "vault", c.Param("repo")) + assert.Equal(t, "file.txt", c.Param("path")) + _, _ = fmt.Fprint(c.Writer(), c.Path()) }) // /users/{user}/received_events/public - h3 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "go", params.Get("user")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h3 := HandlerFunc(func(c Context) { + assert.Equal(t, "go", c.Param("user")) + _, _ = fmt.Fprint(c.Writer(), c.Path()) }) - require.NoError(t, r.Handler(http.MethodGet, "/repos/{owner}/{repo}/keys", h1)) - require.NoError(t, r.Handler(http.MethodGet, "/repos/{owner}/{repo}/contents/*{path}", h2)) - require.NoError(t, r.Handler(http.MethodGet, "/users/{user}/received_events/public", h3)) + require.NoError(t, r.Handle(http.MethodGet, "/repos/{owner}/{repo}/keys", h1)) + require.NoError(t, r.Handle(http.MethodGet, "/repos/{owner}/{repo}/contents/*{path}", h2)) + require.NoError(t, r.Handle(http.MethodGet, "/users/{user}/received_events/public", h3)) r1 := httptest.NewRequest(http.MethodGet, "/repos/john/fox/keys", nil) r2 := httptest.NewRequest(http.MethodGet, "/repos/alex/vault/contents/file.txt", nil) @@ -2380,51 +2033,54 @@ func atomicSync() (start func(), wait func()) { return } -// When WithSaveMatchedRoute is enabled, the route matching the current request will be available in parameters. +// This example demonstrates how to create a simple router using the default options, +// which include the Recovery middleware. A basic route is defined, along with a +// custom middleware to log the request metrics. func ExampleNew() { - r := New(WithSaveMatchedRoute(true)) - metrics := func(next HandlerFunc) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { + // Create a new router with default options, which include the Recovery middleware + r := New(DefaultOptions()) + + // Define a custom middleware to measure the time taken for request processing and + // log the URL, route, time elapsed, and status code + metrics := func(next HandlerFunc) HandlerFunc { + return func(c Context) { start := time.Now() - next.ServeHTTP(w, r, params) - log.Printf("url=%s; route=%s; time=%d", r.URL, params.Get(RouteKey), time.Since(start)) - }) + next(c) + log.Printf("url=%s; route=%s; time=%d; status=%d", c.Request().URL, c.Path(), time.Since(start), c.Writer().Status()) + } } - _ = r.Handler(http.MethodGet, "/hello/{name}", metrics(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) + // Define a route with the path "/hello/{name}", apply the custom "metrics" middleware, + // and set a simple handler that greets the user by their name + r.MustHandle(http.MethodGet, "/hello/{name}", metrics(func(c Context) { + _ = c.String(200, "Hello %s\n", c.Param("name")) })) + + // Start the HTTP server using the router as the handler and listen on port 8080 + log.Fatalln(http.ListenAndServe(":8080", r)) } -// This example demonstrates some important considerations when using the Lookup function. -func ExampleLookup() { - r := New() - _ = r.Handler(http.MethodGet, "/hello/{name}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintf(w, "Hello, %s\n", params.Get("name")) - })) +// This example demonstrates how to register a global middleware that will be +// applied to all routes. - req := httptest.NewRequest(http.MethodGet, "/hello/fox", nil) +func ExampleWithMiddleware() { - // Each tree as its own sync.Pool that is used to reuse Params slice. Since the router tree may be swapped at - // any given time, it's recommended to copy the pointer locally so when the params is released, - // it returns to the correct pool. - tree := r.Tree() - handler, params, _ := Lookup(tree, http.MethodGet, req.URL.Path, false) - // If not nit, Params should be freed to reduce memory allocation. - if params != nil { - defer params.Free(tree) + // Define a custom middleware to measure the time taken for request processing and + // log the URL, route, time elapsed, and status code + metrics := func(next HandlerFunc) HandlerFunc { + return func(c Context) { + start := time.Now() + next(c) + log.Printf("url=%s; route=%s; time=%d; status=%d", c.Request().URL, c.Path(), time.Since(start), c.Writer().Status()) + } } - // Bad, instead make a local copy of the tree! - handler, params, _ = Lookup(r.Tree(), http.MethodGet, req.URL.Path, false) - if params != nil { - defer params.Free(r.Tree()) - } + r := New(WithMiddleware(metrics)) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req, nil) - fmt.Print(w.Body.String()) + r.MustHandle(http.MethodGet, "/hello/{name}", func(c Context) { + _ = c.String(200, "Hello %s\n", c.Param("name")) + }) } // This example demonstrates some important considerations when using the Tree API. @@ -2435,26 +2091,29 @@ func ExampleRouter_Tree() { // any given time, you MUST always copy the pointer locally, This ensures that you do not inadvertently cause a // deadlock by locking/unlocking the wrong tree. tree := r.Tree() - upsert := func(method, path string, handler Handler) error { + upsert := func(method, path string, handler HandlerFunc) error { tree.Lock() defer tree.Unlock() if Has(tree, method, path) { return tree.Update(method, path, handler) } - return tree.Handler(method, path, handler) + return tree.Handle(method, path, handler) } - _ = upsert(http.MethodGet, "/foo/bar", HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintln(w, "foo bar") - })) + _ = upsert(http.MethodGet, "/foo/bar", func(c Context) { + // Note the tree accessible from fox.Context is already a local copy so the golden rule above does not apply. + c.Tree().Lock() + defer c.Tree().Unlock() + _, _ = fmt.Fprintln(c.Writer(), "foo bar") + }) // Bad, instead make a local copy of the tree! - upsert = func(method, path string, handler Handler) error { + upsert = func(method, path string, handler HandlerFunc) error { r.Tree().Lock() defer r.Tree().Unlock() if Has(r.Tree(), method, path) { return r.Tree().Update(method, path, handler) } - return r.Tree().Handler(method, path, handler) + return r.Tree().Handle(method, path, handler) } } diff --git a/go.mod b/go.mod index f7a9b6e..5e7e85a 100644 --- a/go.mod +++ b/go.mod @@ -3,38 +3,15 @@ module github.com/tigerwill90/fox go 1.19 require ( - github.com/gin-gonic/gin v1.9.0 github.com/google/gofuzz v1.2.0 github.com/stretchr/testify v1.8.2 ) require ( - github.com/bytedance/sonic v1.8.6 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.12.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect - github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/kr/pretty v0.3.0 // indirect - github.com/leodido/go-urn v1.2.2 // indirect - github.com/mattn/go-isatty v0.0.18 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.8.0 // indirect - golang.org/x/sys v0.6.0 // indirect - golang.org/x/text v0.8.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bc886df..66c3718 100644 --- a/go.sum +++ b/go.sum @@ -1,37 +1,9 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.8.6 h1:aUgO9S8gvdN6SyW2EhIpAw5E4ChworywIEndZCkCVXk= -github.com/bytedance/sonic v1.8.6/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= -github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.12.0 h1:E4gtWgxWxp8YSxExrQFv5BpCahla0PVF2oTTEYaWQGI= -github.com/go-playground/validator/v10 v10.12.0/go.mod h1:hCAPuzYvKdP33pxWa+2+6AIKXEKqjIUyqsNCtbsSJrA= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -40,55 +12,19 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.2 h1:7z68G0FCGvDk646jz1AelTYNYWrTNm0bEcFAo147wt4= -github.com/leodido/go-urn v1.2.2/go.mod h1:kUaIbLZWttglzwNuG0pgsh5vuV6u2YcGBYz1hIPjtOQ= -github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= -github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us= -github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/rwtodd/Go.Sed v0.0.0-20210816025313-55464686f9ef/go.mod h1:8AEUvGVi2uQ5b24BIhcr0GCcpd/RNAFWaN2CJFrWIIQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -97,4 +33,3 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/http_consts.go b/http_consts.go new file mode 100644 index 0000000..764d0bf --- /dev/null +++ b/http_consts.go @@ -0,0 +1,86 @@ +// Copyright (c) 2021 LabStack, see https://github.com/labstack/echo/blob/master/LICENSE. +// Portions of this code were derived from the Echo project (https://github.com/labstack/echo) +// under the MIT License. + +package fox + +// MIME types +const ( + charsetUTF8 = "charset=UTF-8" + MIMEApplicationJSON = "application/json" + MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 + MIMEApplicationJavaScript = "application/javascript" + MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 + MIMEApplicationXML = "application/xml" + MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8 + MIMETextXML = "text/xml" + MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8 + MIMEApplicationForm = "application/x-www-form-urlencoded" + MIMEApplicationProtobuf = "application/protobuf" + MIMEApplicationMsgpack = "application/msgpack" + MIMETextHTML = "text/html" + MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8 + MIMETextPlain = "text/plain" + MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8 + MIMEMultipartForm = "multipart/form-data" + MIMEOctetStream = "application/octet-stream" +) + +// Headers +const ( + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + // HeaderAllow is the name of the "Allow" header field used to list the set of methods + // advertised as supported by the target resource. Returning an Allow header is mandatory + // for status 405 (method not found) and useful for the OPTIONS method in responses. + // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + HeaderAllow = "Allow" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderCookie = "Cookie" + HeaderSetCookie = "Set-Cookie" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXUrlScheme = "X-Url-Scheme" + HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" + HeaderXRequestedWith = "X-Requested-With" + HeaderServer = "Server" + HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" + + // Access control + HeaderAccessControlRequestMethod = "Access-Control-Request-Method" + HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" + HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" + HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" + HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" + HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" + HeaderAccessControlMaxAge = "Access-Control-Max-Age" + + // Security + HeaderStrictTransportSecurity = "Strict-Transport-Security" + HeaderXContentTypeOptions = "X-Content-Type-Options" + HeaderXXSSProtection = "X-XSS-Protection" + HeaderXFrameOptions = "X-Frame-Options" + HeaderContentSecurityPolicy = "Content-Security-Policy" + HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" + HeaderXCSRFToken = "X-CSRF-Token" + HeaderReferrerPolicy = "Referrer-Policy" +) diff --git a/iter.go b/iter.go index 541c571..9344deb 100644 --- a/iter.go +++ b/iter.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -24,7 +28,7 @@ func NewIterator(t *Tree) *Iterator { } func (it *Iterator) methods() map[string]*node { - nds := it.tree.load() + nds := *it.tree.nodes.Load() m := make(map[string]*node, len(nds)) for i := range nds { if len(nds[i].children) > 0 { @@ -174,7 +178,7 @@ func (it *Iterator) Method() string { } // Handler return the registered handler for the current route. -func (it *Iterator) Handler() Handler { +func (it *Iterator) Handler() HandlerFunc { if it.current != nil { return it.current.handler } diff --git a/iter_test.go b/iter_test.go index 67479b5..8e79e59 100644 --- a/iter_test.go +++ b/iter_test.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -13,9 +17,9 @@ var routesCases = []string{"/fox/router", "/foo/bar/{baz}", "/foo/bar/{baz}/{nam func TestIterator_Rewind(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } results := make(map[string][]string) @@ -34,9 +38,9 @@ func TestIterator_Rewind(t *testing.T) { func TestIterator_SeekMethod(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } results := make(map[string][]string) @@ -54,9 +58,9 @@ func TestIterator_SeekMethod(t *testing.T) { func TestIterator_SeekPrefix(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } want := []string{"/foo/bar/{baz}", "/foo/bar/{baz}/{name}"} @@ -76,9 +80,9 @@ func TestIterator_SeekPrefix(t *testing.T) { func TestIterator_SeekMethodPrefix(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } want := []string{"/foo/bar/{baz}", "/foo/bar/{baz}/{name}"} diff --git a/node.go b/node.go index 15422f0..01ebe06 100644 --- a/node.go +++ b/node.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -10,7 +14,7 @@ import ( type node struct { // The registered handler matching the full path. Nil if the node is not a leaf. // Once assigned, handler is immutable. - handler Handler + handler HandlerFunc // key represent a segment of a route which share a common prefix with it parent. key string @@ -36,7 +40,7 @@ type node struct { paramChildIndex int } -func newNode(key string, handler Handler, children []*node, catchAllKey string, path string) *node { +func newNode(key string, handler HandlerFunc, children []*node, catchAllKey string, path string) *node { sort.Slice(children, func(i, j int) bool { return children[i].key < children[j].key }) @@ -55,23 +59,17 @@ func newNode(key string, handler Handler, children []*node, catchAllKey string, return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, childIndex, path) } -func newNodeFromRef(key string, handler Handler, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node { +func newNodeFromRef(key string, handler HandlerFunc, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node { n := &node{ key: key, childKeys: childKeys, children: children, handler: handler, catchAllKey: catchAllKey, - path: path, + path: appendCatchAll(path, catchAllKey), paramChildIndex: childIndex, } - // TODO find a better way - if catchAllKey != "" { - suffix := "*{" + catchAllKey + "}" - if !strings.HasSuffix(path, suffix) { - n.path += suffix - } - } + return n } @@ -209,14 +207,6 @@ func (n *node) string(space int) string { type skippedNodes []skippedNode -func (n *skippedNodes) free(t *Tree) { - if cap(*n) < int(t.maxDepth.Load()) { - return - } - *n = (*n)[:0] - t.np.Put(n) -} - func (n *skippedNodes) pop() skippedNode { skipped := (*n)[len(*n)-1] *n = (*n)[:len(*n)-1] @@ -227,3 +217,13 @@ type skippedNode struct { node *node pathIndex int } + +func appendCatchAll(path, catchAllKey string) string { + if catchAllKey != "" { + suffix := "*{" + catchAllKey + "}" + if !strings.HasSuffix(path, suffix) { + return path + suffix + } + } + return path +} diff --git a/options.go b/options.go index 2fd2f2e..635705d 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,8 @@ -package fox +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. -import "net/http" +package fox type Option interface { apply(*Router) @@ -12,34 +14,33 @@ func (o optionFunc) apply(r *Router) { o(r) } -// WithNotFoundHandler register a http.Handler which is called when no matching route is found. +// WithNoRouteHandler register a http.Handler which is called when no matching route is found. // By default, http.NotFound is used. -func WithNotFoundHandler(handler http.Handler) Option { +func WithNoRouteHandler(handler HandlerFunc) Option { return optionFunc(func(r *Router) { if handler != nil { - r.notFound = handler + r.noRoute = handler } }) } -// WithNotAllowedHandler register a http.Handler which is called when the request cannot be routed, +// WithNoMethodHandler register a http.Handler which is called when the request cannot be routed, // but the same route exist for other methods. The "Allow" header it automatically set -// before calling the handler. Mount WithHandleMethodNotAllowed to enable this option. By default, +// before calling the handler. Set WithHandleMethodNotAllowed to enable this option. By default, // http.Error with http.StatusMethodNotAllowed is used. -func WithNotAllowedHandler(handler http.Handler) Option { +func WithNoMethodHandler(handler HandlerFunc) Option { return optionFunc(func(r *Router) { if handler != nil { - r.methodNotAllowed = handler + r.noMethod = handler } }) } -// WithPanicHandler register a function to handle panics recovered from http handlers. -func WithPanicHandler(fn func(http.ResponseWriter, *http.Request, interface{})) Option { +// WithMiddleware attaches a global middleware to the router. Middlewares provided will be chained +// in the order they were added. Note that it does NOT apply the middlewares to the NotFound and MethodNotAllowed handlers. +func WithMiddleware(middlewares ...MiddlewareFunc) Option { return optionFunc(func(r *Router) { - if fn != nil { - r.panicHandler = fn - } + r.mws = append(r.mws, middlewares...) }) } @@ -71,10 +72,9 @@ func WithRedirectTrailingSlash(enable bool) Option { }) } -// WithSaveMatchedRoute configure the router to make the matched route accessible as a Handler parameter. -// Usage: p.Get(fox.RouteKey) -func WithSaveMatchedRoute(enable bool) Option { +// DefaultOptions configure the router to use the Recovery middleware. +func DefaultOptions() Option { return optionFunc(func(r *Router) { - r.saveMatchedRoute = enable + r.mws = append(r.mws, Recovery(defaultHandleRecovery)) }) } diff --git a/params.go b/params.go index 6715b12..95830f5 100644 --- a/params.go +++ b/params.go @@ -1,15 +1,8 @@ -package fox - -import ( - "context" - "net/http" -) - -const RouteKey = "$k/fox" - -var ParamsKey = key{} +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. -type key struct{} +package fox type Param struct { Key string @@ -19,57 +12,18 @@ type Param struct { type Params []Param // Get the matching wildcard segment by name. -func (p *Params) Get(name string) string { - for i := range *p { - if (*p)[i].Key == name { - return (*p)[i].Value +func (p Params) Get(name string) string { + for i := range p { + if p[i].Key == name { + return p[i].Value } } return "" } // Clone make a copy of Params. -func (p *Params) Clone() Params { - cloned := make(Params, len(*p)) - copy(cloned, *p) +func (p Params) Clone() Params { + cloned := make(Params, len(p)) + copy(cloned, p) return cloned } - -// Free release the params to be reused later. -func (p *Params) Free(t *Tree) { - if cap(*p) < int(t.maxParams.Load()) { - return - } - *p = (*p)[:0] - t.pp.Put(p) -} - -// ParamsFromContext is a helper function to retrieve parameters from the request context. -func ParamsFromContext(ctx context.Context) Params { - p, _ := ctx.Value(ParamsKey).(Params) - return p -} - -// WrapF is a helper function for wrapping http.HandlerFunc and returns a Fox Handler. -// Params are forwarded via the request context. See ParamsFromContext to retrieve parameters. -func WrapF(f http.HandlerFunc) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - if len(params) > 0 { - f.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ParamsKey, params))) - return - } - f.ServeHTTP(w, r) - }) -} - -// WrapH is a helper function for wrapping http.Handler and returns a Fox Handler. -// Params are forwarded via the request context. See ParamsFromContext to retrieve parameters. -func WrapH(h http.Handler) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - if len(params) > 0 { - h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ParamsKey, params))) - return - } - h.ServeHTTP(w, r) - }) -} diff --git a/params_test.go b/params_test.go index e79d3e9..06aa842 100644 --- a/params_test.go +++ b/params_test.go @@ -1,91 +1,33 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" "testing" ) -func TestWrapHandler(t *testing.T) { - tree := New().Tree() - - cases := []struct { - name string - h Handler - params Params - }{ - { - name: "wrapf with params", - h: WrapF(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Equal(t, "bar", params.Get("foo")) - assert.Equal(t, "doe", params.Get("john")) - }), - params: Params{ - Param{ - Key: "foo", - Value: "bar", - }, - Param{ - Key: "john", - Value: "doe", - }, - }, - }, - { - name: "wraph with params", - h: WrapH(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Equal(t, "bar", params.Get("foo")) - assert.Equal(t, "doe", params.Get("john")) - })), - params: Params{ - Param{ - Key: "foo", - Value: "bar", - }, - Param{ - Key: "john", - Value: "doe", - }, - }, - }, - { - name: "wrapf no params", - h: WrapF(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Nil(t, params) - }), - params: nil, +func TestParams_Get(t *testing.T) { + params := make(Params, 0, 2) + params = append(params, + Param{ + Key: "foo", + Value: "bar", }, - { - name: "wraph no params", - h: WrapH(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Nil(t, params) - })), - params: nil, + Param{ + Key: "john", + Value: "doe", }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - params := tree.newParams() - defer params.Free(tree) - *params = append(*params, tc.params...) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - tc.h.ServeHTTP(nil, req, *params) - }) - } + ) + assert.Equal(t, "bar", params.Get("foo")) + assert.Equal(t, "doe", params.Get("john")) } -func TestParamsClone(t *testing.T) { - tree := New().Tree() - params := tree.newParams() - defer params.Free(tree) - *params = append(*params, +func TestParams_Clone(t *testing.T) { + params := make(Params, 0, 2) + params = append(params, Param{ Key: "foo", Value: "bar", @@ -95,5 +37,5 @@ func TestParamsClone(t *testing.T) { Value: "doe", }, ) - assert.Equal(t, *params, params.Clone()) + assert.Equal(t, params, params.Clone()) } diff --git a/path_test.go b/path_test.go new file mode 100644 index 0000000..748ceb9 --- /dev/null +++ b/path_test.go @@ -0,0 +1,149 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Based on the path package, Copyright 2009 The Go Authors. +// Use of this source code is governed by a BSD-style license that can be found +// in the LICENSE file. + +package fox + +import ( + "strings" + "testing" +) + +type cleanPathTest struct { + path, result string +} + +var cleanTests = []cleanPathTest{ + // Already clean + {"/", "/"}, + {"/abc", "/abc"}, + {"/a/b/c", "/a/b/c"}, + {"/abc/", "/abc/"}, + {"/a/b/c/", "/a/b/c/"}, + + // missing root + {"", "/"}, + {"a/", "/a/"}, + {"abc", "/abc"}, + {"abc/def", "/abc/def"}, + {"a/b/c", "/a/b/c"}, + + // Remove doubled slash + {"//", "/"}, + {"/abc//", "/abc/"}, + {"/abc/def//", "/abc/def/"}, + {"/a/b/c//", "/a/b/c/"}, + {"/abc//def//ghi", "/abc/def/ghi"}, + {"//abc", "/abc"}, + {"///abc", "/abc"}, + {"//abc//", "/abc/"}, + + // Remove . elements + {".", "/"}, + {"./", "/"}, + {"/abc/./def", "/abc/def"}, + {"/./abc/def", "/abc/def"}, + {"/abc/.", "/abc/"}, + + // Remove .. elements + {"..", "/"}, + {"../", "/"}, + {"../../", "/"}, + {"../..", "/"}, + {"../../abc", "/abc"}, + {"/abc/def/ghi/../jkl", "/abc/def/jkl"}, + {"/abc/def/../ghi/../jkl", "/abc/jkl"}, + {"/abc/def/..", "/abc"}, + {"/abc/def/../..", "/"}, + {"/abc/def/../../..", "/"}, + {"/abc/def/../../..", "/"}, + {"/abc/def/../../../ghi/jkl/../../../mno", "/mno"}, + + // Combinations + {"abc/./../def", "/def"}, + {"abc//./../def", "/def"}, + {"abc/../../././../def", "/def"}, +} + +func TestPathClean(t *testing.T) { + for _, test := range cleanTests { + if s := CleanPath(test.path); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) + } + if s := CleanPath(test.result); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) + } + } +} + +func TestPathCleanMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + + for _, test := range cleanTests { + test := test + allocs := testing.AllocsPerRun(100, func() { CleanPath(test.result) }) + if allocs > 0 { + t.Errorf("CleanPath(%q): %v allocs, want zero", test.result, allocs) + } + } +} + +func BenchmarkPathClean(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range cleanTests { + CleanPath(test.path) + } + } +} + +func genLongPaths() (testPaths []cleanPathTest) { + for i := 1; i <= 1234; i++ { + ss := strings.Repeat("a", i) + + correctPath := "/" + ss + testPaths = append(testPaths, cleanPathTest{ + path: correctPath, + result: correctPath, + }, cleanPathTest{ + path: ss, + result: correctPath, + }, cleanPathTest{ + path: "//" + ss, + result: correctPath, + }, cleanPathTest{ + path: "/" + ss + "/b/..", + result: correctPath, + }) + } + return testPaths +} + +func TestPathCleanLong(t *testing.T) { + cleanTests := genLongPaths() + + for _, test := range cleanTests { + if s := CleanPath(test.path); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) + } + if s := CleanPath(test.result); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) + } + } +} + +func BenchmarkPathCleanLong(b *testing.B) { + cleanTests := genLongPaths() + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range cleanTests { + CleanPath(test.path) + } + } +} diff --git a/recovery.go b/recovery.go new file mode 100644 index 0000000..5380ca0 --- /dev/null +++ b/recovery.go @@ -0,0 +1,61 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + +package fox + +import ( + "errors" + "log" + "net" + "net/http" + "os" + "runtime/debug" + "strings" +) + +var stdErr = log.New(os.Stderr, "", log.LstdFlags) + +// RecoveryFunc is a function type that defines how to handle panics that occur during the +// handling of an HTTP request. +type RecoveryFunc func(c Context, err any) + +// Recovery is a middleware that captures panics and recovers from them. It takes a custom handle function +// that will be called with the Context and the value recovered from the panic. +// Note that the middleware check if the panic is caused by http.ErrAbortHandler and re-panic if true +// allowing the http server to handle it as an abort. +func Recovery(handle RecoveryFunc) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c Context) { + defer recovery(c, handle) + next(c) + } + } +} + +func recovery(c Context, handle RecoveryFunc) { + if err := recover(); err != nil { + if abortErr, ok := err.(error); ok && errors.Is(abortErr, http.ErrAbortHandler) { + panic(abortErr) + } + handle(c, err) + } +} + +func defaultHandleRecovery(c Context, err any) { + stdErr.Printf("[PANIC] %q panic recovered\n%s", err, debug.Stack()) + if !c.Writer().Written() && !connIsBroken(err) { + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } +} + +func connIsBroken(err any) bool { + if ne, ok := err.(*net.OpError); ok { + var se *os.SyscallError + if errors.As(ne, &se) { + seStr := strings.ToLower(se.Error()) + return strings.Contains(seStr, "broken pipe") || strings.Contains(seStr, "connection reset by peer") + } + } + return false +} diff --git a/response_writer.go b/response_writer.go new file mode 100644 index 0000000..5d256cd --- /dev/null +++ b/response_writer.go @@ -0,0 +1,177 @@ +// ResponseRecorder is influenced by the work done by goji and chi libraries, +// with additional optimizations to avoid unnecessary memory allocations. +// See their respective licenses for more information: +// https://github.com/zenazn/goji/blob/master/LICENSE +// https://github.com/go-chi/chi/blob/master/LICENSE + +package fox + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" +) + +var _ http.Flusher = (*h1Writer)(nil) +var _ http.Hijacker = (*h1Writer)(nil) +var _ io.ReaderFrom = (*h1Writer)(nil) +var _ http.Pusher = (*h2Writer)(nil) +var _ http.Flusher = (*h2Writer)(nil) + +// ResponseWriter extends http.ResponseWriter and provides +// methods to retrieve the recorded status code, written state, and response size. +type ResponseWriter interface { + http.ResponseWriter + // Status recorded after Write and WriteHeader. + Status() int + // Written returns true if the response has been written. + Written() bool + // Size returns the size of the written response. + Size() int + // Unwrap returns the underlying http.ResponseWriter. + Unwrap() http.ResponseWriter +} + +const notWritten = -1 + +type recorder struct { + http.ResponseWriter + size int + status int +} + +func (r *recorder) reset(w http.ResponseWriter) { + r.ResponseWriter = w + r.size = notWritten + r.status = http.StatusOK +} + +func (r *recorder) Status() int { + return r.status +} + +func (r *recorder) Written() bool { + return r.size != notWritten +} + +func (r *recorder) Size() int { + return r.size +} + +func (r *recorder) Unwrap() http.ResponseWriter { + return r.ResponseWriter +} + +func (r *recorder) WriteHeader(code int) { + if !r.Written() { + r.size = 0 + r.status = code + } + r.ResponseWriter.WriteHeader(code) +} + +func (r *recorder) Write(buf []byte) (n int, err error) { + if !r.Written() { + r.size = 0 + r.ResponseWriter.WriteHeader(r.status) + } + n, err = r.ResponseWriter.Write(buf) + r.size += n + return +} + +type hijackWriter struct { + *recorder +} + +func (w *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +type flushHijackWriter struct { + *recorder +} + +func (w *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +func (w *flushHijackWriter) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type flushWriter struct { + *recorder +} + +func (w *flushWriter) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type h1Writer struct { + *recorder +} + +func (w *h1Writer) ReadFrom(r io.Reader) (n int64, err error) { + rf := w.recorder.ResponseWriter.(io.ReaderFrom) + // If not written, status is OK + w.recorder.WriteHeader(w.recorder.status) + n, err = rf.ReadFrom(r) + w.recorder.size += int(n) + return +} + +func (w *h1Writer) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +func (w *h1Writer) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type h2Writer struct { + *recorder +} + +func (w *h2Writer) Push(target string, opts *http.PushOptions) error { + return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (w *h2Writer) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type noopWriter struct{} + +func (n noopWriter) Header() http.Header { + return make(http.Header) +} + +func (n noopWriter) Write(bytes []byte) (int, error) { + return 0, fmt.Errorf("%w: writing on a clone", ErrDiscardedResponseWriter) +} + +func (n noopWriter) WriteHeader(statusCode int) {} diff --git a/test_helpers.go b/test_helpers.go new file mode 100644 index 0000000..00588d3 --- /dev/null +++ b/test_helpers.go @@ -0,0 +1,32 @@ +package fox + +import ( + "net/http" +) + +// NewTestContext returns a new Router and its associated Context, designed only for testing purpose. +func NewTestContext(w http.ResponseWriter, r *http.Request) (*Router, Context) { + fox := New() + c := NewTestContextOnly(fox, w, r) + return fox, c +} + +func NewTestContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) Context { + return newTextContextOnly(fox, w, r) +} + +func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *context { + c := fox.Tree().allocateContext() + c.resetNil() + c.rec.reset(w) + c.w = flushWriter{&c.rec} + c.fox = fox + c.req = r + return c +} + +func newTestContextTree(t *Tree) *context { + c := t.allocateContext() + c.resetNil() + return c +} diff --git a/tree.go b/tree.go index f9ab089..b96079f 100644 --- a/tree.go +++ b/tree.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -11,10 +15,9 @@ import ( // The caller is responsible for ensuring that all writes are run serially. // // IMPORTANT: -// Each tree as its own sync.Mutex and sync.Pool that may be used to serialize write and reduce memory allocation. -// Since the router tree may be swapped at any given time, you MUST always copy the pointer locally -// to avoid inadvertently releasing Params to the wrong pool or worst, causing a deadlock by locking/unlocking the -// wrong Tree. +// Each tree as its own sync.Mutex that may be used to serialize write. Since the router tree may be swapped at any +// given time, you MUST always copy the pointer locally to avoid inadvertently causing a deadlock by locking/unlocking +// the wrong Tree. // // Good: // t := r.Tree() @@ -24,48 +27,39 @@ import ( // Dramatically bad, may cause deadlock // r.Tree().Lock() // defer r.Tree().Unlock() -// -// This principle also applies to the Lookup function, which requires releasing the Params slice, if not nil, by -// calling params.Free(tree). Always ensure that the Tree pointer passed as a parameter to params.Free is the same -// as the one passed to the Lookup function. type Tree struct { - pp sync.Pool - np sync.Pool + ctx sync.Pool nodes atomic.Pointer[[]*node] + mws []MiddlewareFunc sync.Mutex maxParams atomic.Uint32 maxDepth atomic.Uint32 - saveRoute bool } -// Handler registers a new handler for the given method and path. This function return an error if the route +// Handle registers a new handler for the given method and path. This function return an error if the route // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. However, this function is NOT thread safe and should be run serially, along with // all other Tree's APIs. To override an existing route, use Update. -func (t *Tree) Handler(method, path string, handler Handler) error { +func (t *Tree) Handle(method, path string, handler HandlerFunc) error { p, catchAllKey, n, err := parseRoute(path) if err != nil { return err } - if t.saveRoute { - n += 1 - } - - return t.insert(method, p, catchAllKey, uint32(n), handler) + return t.insert(method, p, catchAllKey, uint32(n), t.apply(handler)) } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. However, this function is NOT thread safe and should be run serially, along with -// all other Tree's APIs. To add new handler, use Handler method. -func (t *Tree) Update(method, path string, handler Handler) error { +// all other Tree's APIs. To add new handler, use Handle method. +func (t *Tree) Update(method, path string, handler HandlerFunc) error { p, catchAllKey, _, err := parseRoute(path) if err != nil { return err } - return t.update(method, p, catchAllKey, handler) + return t.update(method, p, catchAllKey, t.apply(handler)) } // Remove delete an existing handler for the given method and path. If the route does not exist, the function @@ -78,7 +72,7 @@ func (t *Tree) Remove(method, path string) error { } if !t.remove(method, path) { - return ErrRouteNotFound + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } return nil @@ -86,10 +80,10 @@ func (t *Tree) Remove(method, path string) error { // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use // parseRoute before. -func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler Handler) error { +func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. if method == "" { - return fmt.Errorf("http method is missing: %w", ErrInvalidRoute) + return fmt.Errorf("%w: http method is missing", ErrInvalidRoute) } var rootNode *node @@ -113,10 +107,10 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // └── am // Create a new node from "st" reference and update the "te" (parent) reference to "st" node. if result.matched.isLeaf() { - if isCatchAll && result.matched.isCatchAll() { + if result.matched.isCatchAll() && isCatchAll { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } - return fmt.Errorf("route [%s] %s conflict: %w", method, result.matched.path, ErrRouteExist) + return fmt.Errorf("%w: new route [%s] %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) } // The matched node can only be the result of a previous split and therefore has children. @@ -146,10 +140,6 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // 3. Update the "te" (parent) reference to the new "s" node (we are swapping old "st" to new "s" node, first // char remain the same). - /* if isCatchAll { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - }*/ - keyCharsFromStartOfNodeFound := path[result.charsMatched-result.charsMatchedInNodeFound:] cPrefix := commonPrefix(keyCharsFromStartOfNodeFound, result.matched.key) suffixFromExistingEdge := strings.TrimPrefix(result.matched.key, cPrefix) @@ -194,10 +184,6 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // 2. Recreate the "st" node and link it to it's existing children and the new "ify" node. // 3. Update the "te" (parent) node to the new "st" node. - /* if result.matched.isCatchAll() { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - }*/ - keySuffix := path[result.charsMatched:] if strings.HasPrefix(keySuffix, "{") && result.matched.isCatchAll() { @@ -259,21 +245,8 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } } - /* if result.matched.isCatchAll() { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - }*/ - suffixFromExistingEdge := strings.TrimPrefix(result.matched.key, cPrefix) - // Rule: parent's of a node with {param} have only one node or are prefixed by a char (e.g /{param}) - /* if strings.HasPrefix(suffixFromExistingEdge, "{") { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - }*/ - keySuffix := path[result.charsMatched:] - // Rule: parent's of a node with {param} have only one node or are prefixed by a char (e.g /{param}) - /* if strings.HasPrefix(keySuffix, "{") { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - }*/ // No children, so no paramChild n1 := newNodeFromRef(keySuffix, handler, nil, nil, catchAllKey, -1, path) // inserted node @@ -301,21 +274,23 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } // update is not safe for concurrent use. -func (t *Tree) update(method string, path, catchAllKey string, handler Handler) error { +func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFunc) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return fmt.Errorf("route [%s] %s is not registered: %w", method, path, ErrRouteNotFound) + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } result := t.search(nds[index], path) if !result.isExactMatch() || !result.matched.isLeaf() { - return fmt.Errorf("route [%s] %s is not registered: %w", method, path, ErrRouteNotFound) + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } - if catchAllKey != "" && len(result.matched.children) > 0 { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)[1:]) + if catchAllKey != result.matched.catchAllKey { + err := newConflictErr(method, path, catchAllKey, []string{result.matched.path}) + err.isUpdate = true + return err } // We are updating an existing node (could be a leaf or not). We only need to create a new node from @@ -431,27 +406,26 @@ func (t *Tree) remove(method, path string) bool { return true } -func (t *Tree) lookup(rootNode *node, path string, lazy bool) (n *node, params *Params, tsr bool) { +const ( + slashDelim = '/' + bracketDelim = '{' +) + +func (t *Tree) lookup(rootNode *node, path string, params *Params, skipNds *skippedNodes, lazy bool) (n *node, tsr bool) { if len(rootNode.children) == 0 { - return nil, nil, false + return nil, false } var ( charsMatched int charsMatchedInNodeFound int - paramCnt int - skpNds *skippedNodes + paramCnt uint32 ) current := rootNode.children[0].Load() + *skipNds = (*skipNds)[:0] - defer func() { - if skpNds != nil { - skpNds.free(t) - } - }() - -walk: +Walk: for charsMatched < len(path) { charsMatchedInNodeFound = 0 for i := 0; charsMatched < len(path); i++ { @@ -459,10 +433,10 @@ walk: break } - if current.key[i] != path[charsMatched] || path[charsMatched] == '{' { + if current.key[i] != path[charsMatched] || path[charsMatched] == bracketDelim { if current.key[i] == '{' { startPath := charsMatched - idx := strings.IndexByte(path[charsMatched:], '/') + idx := strings.IndexByte(path[charsMatched:], slashDelim) if idx > 0 { // There is another path segment (e.g. /foo/{bar}/baz) charsMatched += idx @@ -471,11 +445,11 @@ walk: charsMatched += len(path[charsMatched:]) } else { // segment is empty - break walk + break Walk } startKey := charsMatchedInNodeFound - idx = strings.IndexByte(current.key[startKey:], '/') + idx = strings.IndexByte(current.key[startKey:], slashDelim) if idx >= 0 { // -1 since on the next incrementation, if any, 'i' are going to be incremented i += idx - 1 @@ -487,18 +461,14 @@ walk: } if !lazy { - if params == nil { - params = t.newParams() - } paramCnt++ - // :n where n > 0 *params = append(*params, Param{Key: current.key[startKey+1 : charsMatchedInNodeFound-1], Value: path[startPath:charsMatched]}) } continue } - break walk + break Walk } charsMatched++ @@ -527,67 +497,46 @@ walk: } if current.paramChildIndex >= 0 || current.isCatchAll() { - if skpNds == nil { - skpNds = t.newSkippedNods() - } - *skpNds = append(*skpNds, skippedNode{current, charsMatched}) + *skipNds = append(*skipNds, skippedNode{current, charsMatched}) paramCnt = 0 } current = current.children[idx].Load() } } - hasSkpNds := skpNds != nil && len(*skpNds) > 0 + hasSkpNds := len(*skipNds) > 0 if !current.isLeaf() { if hasSkpNds { - goto backtrack + goto Backtrack } - if params != nil { - params.Free(t) - } - return nil, nil, false + return nil, false } if charsMatched == len(path) { if charsMatchedInNodeFound == len(current.key) { - // Exact match, note that if we match a wildcard node, the param value is always '/' - if !lazy && (t.saveRoute || current.isCatchAll()) { - if params == nil { - params = t.newParams() - } - - if t.saveRoute { - *params = append(*params, Param{Key: RouteKey, Value: current.path}) - } - - if current.isCatchAll() { - *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) - } - - return current, params, false + // Exact match, note that if we match a catch all node + if !lazy && current.isCatchAll() { + *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) + return current, false } - return current, params, false - } else if charsMatchedInNodeFound < len(current.key) { + return current, false + } + if charsMatchedInNodeFound < len(current.key) { // Key end mid-edge // Tsr recommendation: add an extra trailing slash (got an exact match) - if !tsr { remainingSuffix := current.key[charsMatchedInNodeFound:] - tsr = len(remainingSuffix) == 1 && remainingSuffix[0] == '/' + tsr = len(remainingSuffix) == 1 && remainingSuffix[0] == slashDelim } if hasSkpNds { - goto backtrack + goto Backtrack } - if params != nil { - params.Free(t) - } - - return nil, nil, tsr + return nil, tsr } } @@ -595,71 +544,53 @@ walk: if charsMatched < len(path) && charsMatchedInNodeFound == len(current.key) { if current.isCatchAll() { if !lazy { - if params == nil { - params = t.newParams() - } *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) - if t.saveRoute { - *params = append(*params, Param{Key: RouteKey, Value: current.path}) - } - return current, params, false + return current, false } // Same as exact match, no tsr recommendation - return current, params, false + return current, false } // Tsr recommendation: remove the extra trailing slash (got an exact match) if !tsr { remainingKeySuffix := path[charsMatched:] - tsr = len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == '/' + tsr = len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == slashDelim } if hasSkpNds { - goto backtrack + goto Backtrack } - return nil, nil, tsr + return nil, tsr } // Finally incomplete match to middle of ege -backtrack: +Backtrack: if hasSkpNds { - skipped := skpNds.pop() + skipped := skipNds.pop() if skipped.node.paramChildIndex < 0 { + // skipped is catch all current = skipped.node - if params != nil { - *params = (*params)[:len(*params)-paramCnt] - } + *params = (*params)[:len(*params)-int(paramCnt)] + if !lazy { - if params == nil { - params = t.newParams() - } *params = append(*params, Param{Key: current.catchAllKey, Value: path[skipped.pathIndex:]}) - if t.saveRoute { - *params = append(*params, Param{Key: RouteKey, Value: current.path}) - } - return current, params, false + + return current, false } - // Same as exact match, no tsr recommendation - return current, params, false + return current, false } current = skipped.node.children[skipped.node.paramChildIndex].Load() - if params != nil { - *params = (*params)[:len(*params)-paramCnt] - } + *params = (*params)[:len(*params)-int(paramCnt)] charsMatched = skipped.pathIndex paramCnt = 0 - goto walk - } - - if params != nil { - params.Free(t) + goto Walk } - return nil, nil, tsr + return nil, tsr } func (t *Tree) search(rootNode *node, path string) searchResult { @@ -710,6 +641,18 @@ STOP: } } +func (t *Tree) allocateContext() *context { + params := make(Params, 0, t.maxParams.Load()) + skipNds := make(skippedNodes, 0, t.maxDepth.Load()) + return &context{ + params: ¶ms, + skipNds: &skipNds, + // This is a read only value, no reset, it's always the + // owner of the pool. + tree: t, + } +} + // addRoot is not safe for concurrent use. func (t *Tree) addRoot(n *node) { nds := *t.nodes.Load() @@ -764,14 +707,10 @@ func (t *Tree) updateMaxDepth(max uint32) { } } -func (t *Tree) load() []*node { - return *t.nodes.Load() -} - -func (t *Tree) newParams() *Params { - return t.pp.Get().(*Params) -} - -func (t *Tree) newSkippedNods() *skippedNodes { - return t.np.Get().(*skippedNodes) +func (t *Tree) apply(h HandlerFunc) HandlerFunc { + m := h + for i := len(t.mws) - 1; i >= 0; i-- { + m = t.mws[i](m) + } + return m }