diff --git a/LICENSE.txt b/LICENSE.txt index 261eeb9..bb9595b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2022 Sylvain Müller Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 7db9617..65ca29a 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,10 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/tigerwill90/fox)](https://goreportcard.com/report/github.com/tigerwill90/fox) [![codecov](https://codecov.io/gh/tigerwill90/fox/branch/master/graph/badge.svg?token=09nfd7v0Bl)](https://codecov.io/gh/tigerwill90/fox) # Fox -Fox is a lightweight high performance HTTP request router for [Go](https://go.dev/). The main difference with other routers is +Fox is a zero allocation, lightweight, high performance HTTP request router for [Go](https://go.dev/). The main difference with other routers is that it supports **mutation on its routing tree while handling request concurrently**. Internally, Fox use a [Concurrent Radix Tree](https://github.com/npgall/concurrent-trees/blob/master/documentation/TreeDesign.md) that support **lock-free -reads** while allowing **concurrent writes**. - -The router tree is optimized for high-concurrency and high performance reads, and low-concurrency write. Fox has a small memory footprint, and -in many case, it does not do a single heap allocation while handling request. +reads** while allowing **concurrent writes**. The router tree is optimized for high-concurrency and high performance reads, and low-concurrency write. Fox supports various use cases, but it is especially designed for applications that require changes at runtime to their routing structure based on user input, configuration changes, or other runtime events. @@ -24,13 +21,14 @@ request! **Wildcard pattern:** Route can be registered using wildcard parameters. The matched path segment can then be easily retrieved by name. Due to Fox design, wildcard route are cheap and scale really well. -**Detect panic:** You can register a fallback handler that is fire in case of panics occurring during handling an HTTP request. +**Detect panic:** Comes with a ready-to-use, efficient Recovery middleware that gracefully handles panics. **Get the current route:** You can easily retrieve the route of the matched request. This actually makes it easier to integrate -observability middleware like open telemetry (disable by default). +observability middleware like open telemetry. -**Only explicit matches:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), a request can only match -exactly one or no route. As a result there are no unintended matches, and it also encourages good RESTful api design. +**Only explicit matches:** A request can only match exactly one route or no route at all. Fox strikes a balance between routing flexibility, +performance and clarity by enforcing clear priority rules, ensuring that there are no unintended matches and maintaining high performance +even for complex routing pattern. **Redirect trailing slashes:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), the router automatically redirects the client, at no extra cost, if another route match with or without a trailing slash (disable by default). @@ -51,40 +49,38 @@ go get -u github.com/tigerwill90/fox package main import ( - "fmt" "github.com/tigerwill90/fox" "log" "net/http" ) -var WelcomeHandler = fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprint(w, "Welcome!\n") -}) - -type HelloHandler struct{} +type Greeting struct { + Say string +} -func (h *HelloHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) +func (h *Greeting) Greet(c fox.Context) error { + return c.String(http.StatusOK, "%s %s\n", h.Say, c.Param("name")) } func main() { - r := fox.New() - - Must(r.Handler(http.MethodGet, "/", WelcomeHandler)) - Must(r.Handler(http.MethodGet, "/hello/{name}", new(HelloHandler))) + r := fox.New(fox.DefaultOptions()) - log.Fatalln(http.ListenAndServe(":8080", r)) -} - -func Must(err error) { + err := r.Handle(http.MethodGet, "/", func(c fox.Context) error { + return c.String(http.StatusOK, "Welcome\n") + }) if err != nil { panic(err) } + + h := Greeting{Say: "Hello"} + r.MustHandle(http.MethodGet, "/hello/{name}", h.Greet) + + log.Fatalln(http.ListenAndServe(":8080", r)) } ```` #### Error handling Since new route may be added at any given time, Fox, unlike other router, does not panic when a route is malformed or conflicts with another. -Instead, it returns the following error values +Instead, it returns the following error values: ```go ErrRouteExist = errors.New("route already registered") ErrRouteConflict = errors.New("route conflict") @@ -101,11 +97,34 @@ if errors.Is(err, fox.ErrRouteConflict) { } ``` +In addition, Fox also provides a centralized way to handle errors that may occur during the execution of a HandlerFunc. + +````go +var MyCustomError = errors.New("my custom error") + +r := fox.New( + fox.WithRouteError(func(c fox.Context, err error) { + if !c.Writer().Written() { + if errors.Is(err, MyCustomError) { + http.Error(c.Writer(), err.Error(), http.StatusInternalServerError) + return + } + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + }), +) + +r.MustHandle(http.MethodGet, "/hello/{name}", func(c fox.Context) error { + return MyCustomError +}) +```` + #### Named parameters -A route can be defined using placeholder (e.g `{name}`). The values are accessible via `fox.Params`, which is just a slice of `fox.Param`. -The `Get` method is a helper to retrieve the value using the placeholder name. +A route can be defined using placeholder (e.g `{name}`). The matching segment are recorder into the `fox.Params` slice accessible +via `fox.Context`. The `Param` and `Get` methods are helpers to retrieve the value using the placeholder name. -``` +```` Pattern /avengers/{name} /avengers/ironman match @@ -117,12 +136,12 @@ Pattern /users/uuid:{id} /users/uuid:123 match /users/uuid no match -``` +```` -### Catch all parameter +#### Catch all parameter Catch-all parameters can be used to match everything at the end of a route. The placeholder start with `*` followed by a regular named parameter (e.g. `*{name}`). -``` +```` Pattern /src/*{filepath} /src/ match @@ -134,21 +153,38 @@ Patter /src/file=*{path} /src/file= match /src/file=config.txt match /src/file=/dir/config.txt match -``` +```` -#### Warning about params slice -`fox.Params` slice is freed once ServeHTTP returns and may be reused later to save resource. Therefore, if you need to hold `fox.Params` -longer, use the `Clone` methods. -```go -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - p := params.Clone() - go func(){ - time.Sleep(1 * time.Second) - log.Println(p.Get("name")) // Safe - }() - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) +#### Priority rules +Routes are prioritized based on specificity, with static segments taking precedence over wildcard segments. +A wildcard segment (named parameter or catch all) can only overlap with static segments, for the same HTTP method. +For instance, `GET /users/{id}` and `GET /users/{name}/profile` cannot coexist, as the `{id}` and `{name}` segments +are overlapping. These limitations help to minimize the number of branches that need to be evaluated in order to find +the right match, thereby maintaining high-performance routing. + +For example, the followings route are allowed: +```` +GET /*{filepath} +GET /users/{id} +GET /users/{id}/emails +GET /users/{id}/{actions} +POST /users/{name}/emails +```` + +#### Warning about context +The `fox.Context` instance is freed once the request handler function returns to optimize resource allocation. +If you need to retain `fox.Context` or `fox.Params` beyond the scope of the handler, use the `Clone` methods. +````go +func Hello(c fox.Context) error { + cc := c.Clone() + // cp := c.Params().Clone() + go func() { + time.Sleep(2 * time.Second) + log.Println(cc.Param("name")) // Safe + }() + return c.String(http.StatusOK, "Hello %s\n", c.Param("name")) } -``` +```` ## Concurrency Fox implements a [Concurrent Radix Tree](https://github.com/npgall/concurrent-trees/blob/master/documentation/TreeDesign.md) that supports **lock-free** @@ -158,7 +194,7 @@ into a **patch**, which is then applied to the tree in a **single atomic operati For example, here we are inserting the new key `toast` into to the tree which require an existing node to be split:

- +

When traversing the tree during a patch, reading threads will either see the **old version** or the **new version** of the (sub-)tree, but both version are @@ -178,11 +214,12 @@ As such threads that route requests should never encounter latency due to ongoin In this example, the handler for `routes/{action}` allow to dynamically register, update and remove handler for the given route and method. Thanks to Fox's design, those actions are perfectly safe and may be executed concurrently. -```go +````go package main import ( "encoding/json" + "errors" "fmt" "github.com/tigerwill90/fox" "log" @@ -190,15 +227,10 @@ import ( "strings" ) -type ActionHandler struct { - fox *fox.Router -} - -func (h *ActionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { +func Action(c fox.Context) error { var data map[string]string - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + if err := json.NewDecoder(c.Request().Body).Decode(&data); err != nil { + return fox.NewHTTPError(http.StatusBadRequest, err) } method := strings.ToUpper(data["method"]) @@ -206,51 +238,43 @@ func (h *ActionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, params text := data["text"] if path == "" || method == "" { - http.Error(w, "missing method or path", http.StatusBadRequest) - return + return fox.NewHTTPError(http.StatusBadRequest, errors.New("missing method or path")) } var err error - action := params.Get("action") + action := c.Param("action") switch action { case "add": - err = h.fox.Handler(method, path, fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintln(w, text) - })) + err = c.Fox().Handle(method, path, func(c fox.Context) error { + return c.String(http.StatusOK, text) + }) case "update": - err = h.fox.Update(method, path, fox.HandlerFunc(func(w http.ResponseWriter, r *http.Request, params fox.Params) { - _, _ = fmt.Fprintln(w, text) - })) + err = c.Fox().Update(method, path, func(c fox.Context) error { + return c.String(http.StatusOK, text) + }) case "delete": - err = h.fox.Remove(method, path) + err = c.Fox().Remove(method, path) default: - http.Error(w, fmt.Sprintf("action %q is not allowed", action), http.StatusBadRequest) - return + return fox.NewHTTPError(http.StatusBadRequest, fmt.Errorf("action %q is not allowed", action)) } if err != nil { - http.Error(w, err.Error(), http.StatusConflict) - return + return fox.NewHTTPError(http.StatusConflict, err) } - _, _ = fmt.Fprintf(w, "%s route [%s] %s: success\n", action, method, path) + return c.String(http.StatusOK, "%s route [%s] %s: success\n", action, method, path) } func main() { r := fox.New() - Must(r.Handler(http.MethodPost, "/routes/{action}", &ActionHandler{fox: r})) + r.MustHandle(http.MethodPost, "/routes/{action}", Action) log.Fatalln(http.ListenAndServe(":8080", r)) } - -func Must(err error) { - if err != nil { - panic(err) - } -} -``` +```` #### Tree swapping -Fox also enables you to replace the entire tree in a single atomic operation using the `Use` and `Swap` methods. +Fox also enables you to replace the entire tree in a single atomic operation using the `Swap` methods. Note that router's options apply automatically on the new tree. + ````go package main @@ -258,7 +282,6 @@ import ( "fox-by-example/db" "github.com/tigerwill90/fox" "html/template" - "io" "log" "net/http" "strings" @@ -269,18 +292,23 @@ type HtmlRenderer struct { Template template.HTML } -func (h *HtmlRenderer) ServeHTTP(w http.ResponseWriter, r *http.Request, params fox.Params) { - log.Printf("matched route: %s", params.Get(fox.RouteKey)) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, _ = io.Copy(w, strings.NewReader(string(h.Template))) +func (h *HtmlRenderer) Render(c fox.Context) error { + log.Printf("matched handler path: %s", c.Path()) + return c.Stream( + http.StatusInternalServerError, + fox.MIMETextHTMLCharsetUTF8, + strings.NewReader(string(h.Template)), + ) } func main() { - r := fox.New(fox.WithSaveMatchedRoute(true)) + r := fox.New() routes := db.GetRoutes() + for _, rte := range routes { - Must(r.Handler(rte.Method, rte.Path, &HtmlRenderer{Template: rte.Template})) + h := HtmlRenderer{Template: rte.Template} + r.MustHandle(rte.Method, rte.Path, h.Render) } go Reload(r) @@ -293,22 +321,17 @@ func Reload(r *fox.Router) { routes := db.GetRoutes() tree := r.NewTree() for _, rte := range routes { - if err := tree.Handler(rte.Method, rte.Path, &HtmlRenderer{Template: rte.Template}); err != nil { + h := HtmlRenderer{Template: rte.Template} + if err := tree.Handle(rte.Method, rte.Path, h.Render); err != nil { log.Printf("error reloading route: %s\n", err) continue } } - // Replace the currently in-use routing tree with the new provided. - r.Use(tree) + // Swap the currently in-use routing tree with the new provided. + r.Swap(tree) log.Println("route reloaded") } } - -func Must(err error) { - if err != nil { - panic(err) - } -} ```` #### Advanced usage: consistent view updates @@ -321,21 +344,24 @@ is already registered for the provided method and path. By locking the `Tree`, t atomicity, as it prevents other threads from modifying the tree between the lookup and the write operation. Note that all read operation on the tree remain lock-free. ````go -func Upsert(t *fox.Tree, method, path string, handler fox.Handler) error { +func Upsert(t *fox.Tree, method, path string, handler fox.HandlerFunc) error { t.Lock() defer t.Unlock() if fox.Has(t, method, path) { return t.Update(method, path, handler) } - return t.Handler(method, path, handler) + return t.Handle(method, path, handler) } ```` #### Concurrent safety and proper usage of Tree APIs -Some important consideration to keep in mind when using `Tree` API. Each instance as its own `sync.Mutex` and `sync.Pool` -that may be used to serialize write and reduce memory allocation. Since the router tree may be swapped at any -given time, you **MUST always copy the pointer locally** to avoid inadvertently releasing Params to the wrong pool -or worst, causing a deadlock by locking/unlocking the wrong `Tree`. +When working with the `Tree` API, it's important to keep some considerations in mind. Each instance has its +own `sync.Mutex` that can be used to serialize writes. However, unlike the router API, the lower-level `Tree` API +does not automatically lock the tree when writing to it. Therefore, it is the user's responsibility to ensure +all writes are executed serially. + +Moreover, since the router tree may be swapped at any given time, you MUST always copy the pointer locally to +avoid inadvertently causing a deadlock by locking/unlocking the wrong `Tree`. ````go // Good @@ -343,75 +369,148 @@ t := r.Tree() t.Lock() defer t.Unlock() -// Dramatically bad, may cause deadlock: +// Dramatically bad, may cause deadlock r.Tree().Lock() defer r.Tree().Unlock() + +// Dramatically bad, may cause deadlock +func handle(c fox.Context) error { + c.Fox().Tree().Lock() + defer c.Fox().Tree().Unlock() + return nil +} ```` -This principle also applies to the `fox.Lookup` function, which requires releasing the `fox.Params` slice by calling `params.Free(tree)`. -Always ensure that the `Tree` pointer passed as a parameter to `params.Free` is the same as the one passed to the `fox.Lookup` function. +Note that `fox.Context` carries a local copy of the `Tree` that is being used to serve the handler, thereby eliminating +the risk of deadlock when using the `Tree` within the context. +````go +// Ok +func handle(c fox.Context) error { + c.Tree().Lock() + defer c.Tree().Unlock() + return nil +} +```` ## Working with http.Handler Fox itself implements the `http.Handler` interface which make easy to chain any compatible middleware before the router. Moreover, the router -provides convenient `fox.WrapF` and `fox.WrapH` adapter to be use with `http.Handler`. Named and catch all parameters are forwarded via the -request context +provides convenient `fox.WrapF`, `fox.WrapH` and `fox.WrapM` adapter to be use with `http.Handler`. + +Wrapping an `http.Handler` ```go -_ = r.Handler(http.MethodGet, "/users/{id}", fox.WrapF(func(w http.ResponseWriter, r *http.Request) { - params := fox.ParamsFromContext(r.Context()) - _, _ = fmt.Fprintf(w, "user id: %s\n", params.Get("id")) -})) +articles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "get articles") +}) + +r := fox.New(fox.DefaultOptions()) +r.MustHandle(http.MethodGet, "/articles", fox.WrapH(httpRateLimiter.RateLimit(articles))) ``` +Wrapping an `http.Handler` compatible middleware +````go +r := fox.New(fox.DefaultOptions(), fox.WithMiddleware(fox.WrapM(httpRateLimiter.RateLimit))) +r.MustHandle(http.MethodGet, "/articles/{id}", func(c fox.Context) error { + return c.String(http.StatusOK, "Article id: %s\n", c.Param("id")) +}) +```` + +## Middleware +Middlewares can be registered globally using the `fox.WithMiddleware` option. The example below demonstrates how +to create and apply automatically a simple logging middleware to all route. + +````go +package main + +import ( + "fmt" + "github.com/tigerwill90/fox" + "log" + "net/http" + "time" +) + +var logger = fox.MiddlewareFunc(func(next fox.HandlerFunc) fox.HandlerFunc { + return func(c fox.Context) error { + start := time.Now() + err := next(c) + msg := fmt.Sprintf("route: %s, latency: %s, status: %d, size: %d", + c.Path(), + time.Since(start), + c.Writer().Status(), + c.Writer().Size(), + ) + if err != nil { + msg += fmt.Sprintf(", error: %s", err) + } + log.Println(msg) + return err + } +}) + +func main() { + r := fox.New(fox.WithMiddleware(logger)) + + r.MustHandle(http.MethodGet, "/", func(c fox.Context) error { + resp, err := http.Get("https://api.coindesk.com/v1/bpi/currentprice.json") + if err != nil { + return fox.NewHTTPError(http.StatusInternalServerError) + } + defer resp.Body.Close() + return c.Stream(http.StatusOK, fox.MIMEApplicationJSON, resp.Body) + }) + + log.Fatalln(http.ListenAndServe(":8080", r)) +} +```` + ## Benchmark -The primary goal of Fox is to be a lightweight, high performance router which allow routes modification while in operation. -The following benchmarks attempt to compare Fox to various popular alternatives. Some are fully featured web framework, and other -are lightweight request router. This is based on [julienschmidt/go-http-routing-benchmark](https://github.com/julienschmidt/go-http-routing-benchmark) +The primary goal of Fox is to be a lightweight, high performance router which allow routes modification at runtime. +The following benchmarks attempt to compare Fox to various popular alternatives, including both fully-featured web frameworks +and lightweight request routers. These benchmarks are based on the [julienschmidt/go-http-routing-benchmark](https://github.com/julienschmidt/go-http-routing-benchmark) repository. +Please note that these benchmarks should not be taken too seriously, as the comparison may not be entirely fair due to +the differences in feature sets offered by each framework. Performance should be evaluated in the context of your specific +use case and requirements. While Fox aims to excel in performance, it's important to consider the trade-offs and +functionality provided by different web frameworks and routers when making your selection. + ### Config ``` GOOS: Linux GOARCH: amd64 -GO: 1.19 +GO: 1.20 CPU: Intel(R) Core(TM) i9-9900K CPU @ 3.60GHz ``` ### Static Routes It is just a collection of random static paths inspired by the structure of the Go directory. It might not be a realistic URL-structure. -**GOMAXPROCS: 0** +**GOMAXPROCS: 1** ``` -BenchmarkDenco_StaticAll 352584 3350 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_StaticAll 159259 7400 ns/op 0 B/op 0 allocs/op -BenchmarkKocha_StaticAll 154405 7793 ns/op 0 B/op 0 allocs/op -BenchmarkFox_StaticAll 130474 8899 ns/op 0 B/op 0 allocs/op -BenchmarkHttpTreeMux_StaticAll 127754 9065 ns/op 0 B/op 0 allocs/op -BenchmarkGin_StaticAll 96139 12393 ns/op 0 B/op 0 allocs/op -BenchmarkBeego_StaticAll 10000 103464 ns/op 55264 B/op 471 allocs/op -BenchmarkGorillaMux_StaticAll 2307 501554 ns/op 113041 B/op 1099 allocs/op -BenchmarkMartini_StaticAll 1357 886524 ns/op 129210 B/op 2031 allocs/op -BenchmarkTraffic_StaticAll 990 1183413 ns/op 753608 B/op 14601 allocs/op -BenchmarkPat_StaticAll 972 1193521 ns/op 602832 B/op 12559 allocs/op +BenchmarkHttpRouter_StaticAll 161659 7570 ns/op 0 B/op 0 allocs/op +BenchmarkHttpTreeMux_StaticAll 132446 8836 ns/op 0 B/op 0 allocs/op +BenchmarkFox_StaticAll 102577 11348 ns/op 0 B/op 0 allocs/op +BenchmarkStdMux_StaticAll 91304 13382 ns/op 0 B/op 0 allocs/op +BenchmarkGin_StaticAll 78224 15433 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_StaticAll 77923 15739 ns/op 0 B/op 0 allocs/op +BenchmarkBeego_StaticAll 10000 101094 ns/op 55264 B/op 471 allocs/op +BenchmarkGorillaMux_StaticAll 2283 525683 ns/op 113041 B/op 1099 allocs/op +BenchmarkMartini_StaticAll 1330 936928 ns/op 129210 B/op 2031 allocs/op +BenchmarkTraffic_StaticAll 1064 1140959 ns/op 753611 B/op 14601 allocs/op +BenchmarkPat_StaticAll 967 1230424 ns/op 602832 B/op 12559 allocs/op ``` -In this benchmark, Fox performs as well as `Gin`, `HttpTreeMux` and `HttpRouter` which are all Radix Tree based routers. An interesting fact is +In this benchmark, Fox performs as well as `Gin` and `Echo` which are both Radix Tree based routers. An interesting fact is that [HttpTreeMux](https://github.com/dimfeld/httptreemux) also support [adding route while serving request concurrently](https://github.com/dimfeld/httptreemux#concurrency). However, it takes a slightly different approach, by using an optional `RWMutex` that may not scale as well as Fox under heavy load. The next -test compare `HttpTreeMux`, `HttpTreeMux_SafeAddRouteFlag` (concurrent reads and writes), `HttpRouter` and `Fox` in parallel benchmark. +test compare `HttpTreeMux` with and without the `*SafeAddRouteFlag` (concurrent reads and writes) and `Fox` in parallel benchmark. **GOMAXPROCS: 16** ``` -Route: /progs/image_package4.out - -BenchmarkHttpRouter_StaticSingleParallel-16 211819790 5.640 ns/op 0 B/op 0 allocs/op -BenchmarkFox_StaticSingleParallel-16 157547185 7.418 ns/op 0 B/op 0 allocs/op -BenchmarkHttpTreeMux_StaticSingleParallel-16 154222639 7.774 ns/op 0 B/op 0 allocs/op -BenchmarkHttpTreeMux_SafeAddRouteFlag_StaticSingleParallel-16 29904204 38.52 ns/op 0 B/op 0 allocs/op - Route: all -BenchmarkHttpRouter_StaticAllParallel-16 1446759 832.1 ns/op 0 B/op 0 allocs/op -BenchmarkHttpTreeMux_StaticAllParallel-16 997074 1100 ns/op 0 B/op 0 allocs/op -BenchmarkFox_StaticAllParallel-16 1000000 1105 ns/op 0 B/op 0 allocs/op -BenchmarkHttpTreeMux_SafeAddRouteFlag_StaticAllParallel-16 197578 6017 ns/op 0 B/op 0 allocs/op +BenchmarkFox_StaticAll-16 99322 11369 ns/op 0 B/op 0 allocs/op +BenchmarkFox_StaticAllParallel-16 831354 1422 ns/op 0 B/op 0 allocs/op +BenchmarkHttpTreeMux_StaticAll-16 135560 8861 ns/op 0 B/op 0 allocs/op +BenchmarkHttpTreeMux_StaticAllParallel-16* 172714 6916 ns/op 0 B/op 0 allocs/op ``` As you can see, this benchmark highlight the cost of using higher synchronisation primitive like `RWMutex` to be able to register new route while handling requests. @@ -421,47 +520,45 @@ The following benchmarks measure the cost of some very basic operations. In the first benchmark, only a single route, containing a parameter, is loaded into the routers. Then a request for a URL matching this pattern is made and the router has to call the respective registered handler function. End. -**GOMAXPROCS: 0** +**GOMAXPROCS: 1** ``` -BenchmarkFox_Param 29995566 39.04 ns/op 0 B/op 0 allocs/op -BenchmarkGin_Param 30710918 39.08 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_Param 20026911 55.88 ns/op 32 B/op 1 allocs/op -BenchmarkDenco_Param 15964747 70.04 ns/op 32 B/op 1 allocs/op -BenchmarkKocha_Param 8392696 138.5 ns/op 56 B/op 3 allocs/op -BenchmarkHttpTreeMux_Param 4469318 265.6 ns/op 352 B/op 3 allocs/op -BenchmarkBeego_Param 2241368 530.9 ns/op 352 B/op 3 allocs/op -BenchmarkPat_Param 1788819 666.8 ns/op 512 B/op 10 allocs/op -BenchmarkGorillaMux_Param 1208638 995.1 ns/op 1024 B/op 8 allocs/op -BenchmarkTraffic_Param 606530 1700 ns/op 1848 B/op 21 allocs/op -BenchmarkMartini_Param 455662 2419 ns/op 1096 B/op 12 allocs/op +BenchmarkFox_Param 33024534 36.61 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_Param 31472508 38.71 ns/op 0 B/op 0 allocs/op +BenchmarkGin_Param 25826832 52.88 ns/op 0 B/op 0 allocs/op +BenchmarkHttpRouter_Param 21230490 60.83 ns/op 32 B/op 1 allocs/op +BenchmarkHttpTreeMux_Param 3960292 280.4 ns/op 352 B/op 3 allocs/op +BenchmarkBeego_Param 2247776 518.9 ns/op 352 B/op 3 allocs/op +BenchmarkPat_Param 1603902 676.6 ns/op 512 B/op 10 allocs/op +BenchmarkGorillaMux_Param 1000000 1011 ns/op 1024 B/op 8 allocs/op +BenchmarkTraffic_Param 648986 1686 ns/op 1848 B/op 21 allocs/op +BenchmarkMartini_Param 485839 2446 ns/op 1096 B/op 12 allocs/op ``` Same as before, but now with multiple parameters, all in the same single route. The intention is to see how the routers scale with the number of parameters. **GOMAXPROCS: 0** ``` -BenchmarkGin_Param5 16470636 73.09 ns/op 0 B/op 0 allocs/op -BenchmarkFox_Param5 14716213 82.05 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_Param5 7614333 154.7 ns/op 160 B/op 1 allocs/op -BenchmarkDenco_Param5 6513253 179.5 ns/op 160 B/op 1 allocs/op -BenchmarkKocha_Param5 2073741 604.3 ns/op 440 B/op 10 allocs/op -BenchmarkHttpTreeMux_Param5 1801978 659.2 ns/op 576 B/op 6 allocs/op -BenchmarkBeego_Param5 1764513 669.1 ns/op 352 B/op 3 allocs/op -BenchmarkGorillaMux_Param5 657648 1578 ns/op 1088 B/op 8 allocs/op -BenchmarkPat_Param5 633555 1700 ns/op 800 B/op 24 allocs/op -BenchmarkTraffic_Param5 374895 2744 ns/op 2200 B/op 27 allocs/op -BenchmarkMartini_Param5 403650 2835 ns/op 1256 B/op 13 allocs/op - -BenchmarkGin_Param20 6136497 189.9 ns/op 0 B/op 0 allocs/op -BenchmarkFox_Param20 4187372 283.2 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_Param20 2536359 483.4 ns/op 640 B/op 1 allocs/op -BenchmarkDenco_Param20 2110105 567.7 ns/op 640 B/op 1 allocs/op -BenchmarkKocha_Param20 593958 1744 ns/op 1808 B/op 27 allocs/op -BenchmarkBeego_Param20 741110 1747 ns/op 352 B/op 3 allocs/op -BenchmarkHttpTreeMux_Param20 341913 3079 ns/op 3195 B/op 10 allocs/op -BenchmarkGorillaMux_Param20 282345 3671 ns/op 3196 B/op 10 allocs/op -BenchmarkMartini_Param20 210543 5222 ns/op 3619 B/op 15 allocs/op -BenchmarkPat_Param20 151778 7343 ns/op 4096 B/op 73 allocs/op -BenchmarkTraffic_Param20 113230 9989 ns/op 7847 B/op 47 allocs/op +BenchmarkFox_Param5 16608495 72.84 ns/op 0 B/op 0 allocs/op +BenchmarkGin_Param5 13098740 92.22 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_Param5 12025460 96.33 ns/op 0 B/op 0 allocs/op +BenchmarkHttpRouter_Param5 8233530 148.1 ns/op 160 B/op 1 allocs/op +BenchmarkHttpTreeMux_Param5 1986019 616.9 ns/op 576 B/op 6 allocs/op +BenchmarkBeego_Param5 1836229 655.3 ns/op 352 B/op 3 allocs/op +BenchmarkGorillaMux_Param5 757936 1572 ns/op 1088 B/op 8 allocs/op +BenchmarkPat_Param5 645847 1724 ns/op 800 B/op 24 allocs/op +BenchmarkTraffic_Param5 424431 2729 ns/op 2200 B/op 27 allocs/op +BenchmarkMartini_Param5 424806 2772 ns/op 1256 B/op 13 allocs/op + + +BenchmarkGin_Param20 4636416 244.6 ns/op 0 B/op 0 allocs/op +BenchmarkFox_Param20 4667533 250.7 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_Param20 4352486 277.1 ns/op 0 B/op 0 allocs/op +BenchmarkHttpRouter_Param20 2618958 455.2 ns/op 640 B/op 1 allocs/op +BenchmarkBeego_Param20 847029 1688 ns/op 352 B/op 3 allocs/op +BenchmarkHttpTreeMux_Param20 369500 2972 ns/op 3195 B/op 10 allocs/op +BenchmarkGorillaMux_Param20 318134 3561 ns/op 3195 B/op 10 allocs/op +BenchmarkMartini_Param20 223070 5117 ns/op 3619 B/op 15 allocs/op +BenchmarkPat_Param20 157380 7442 ns/op 4094 B/op 73 allocs/op +BenchmarkTraffic_Param20 119677 9864 ns/op 7847 B/op 47 allocs/op ``` Now let's see how expensive it is to access a parameter. The handler function reads the value (by the name of the parameter, e.g. with a map @@ -469,53 +566,48 @@ lookup; depends on the router) and writes it to `/dev/null` **GOMAXPROCS: 0** ``` -BenchmarkFox_ParamWrite 21061758 56.96 ns/op 0 B/op 0 allocs/op -BenchmarkGin_ParamWrite 17973256 66.54 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_ParamWrite 15953065 74.64 ns/op 32 B/op 1 allocs/op -BenchmarkDenco_ParamWrite 12553562 89.93 ns/op 32 B/op 1 allocs/op -BenchmarkKocha_ParamWrite 7356948 156.7 ns/op 56 B/op 3 allocs/op -BenchmarkHttpTreeMux_ParamWrite 4075486 286.4 ns/op 352 B/op 3 allocs/op -BenchmarkBeego_ParamWrite 2126341 567.4 ns/op 360 B/op 4 allocs/op -BenchmarkPat_ParamWrite 1197910 996.5 ns/op 936 B/op 14 allocs/op -BenchmarkGorillaMux_ParamWrite 1139376 1048 ns/op 1024 B/op 8 allocs/op -BenchmarkTraffic_ParamWrite 496440 2057 ns/op 2272 B/op 25 allocs/op -BenchmarkMartini_ParamWrite 398594 2799 ns/op 1168 B/op 16 allocs/op +BenchmarkFox_ParamWrite 16707409 72.53 ns/op 0 B/op 0 allocs/op +BenchmarkHttpRouter_ParamWrite 16478174 73.30 ns/op 32 B/op 1 allocs/op +BenchmarkGin_ParamWrite 15828385 75.73 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_ParamWrite 13187766 95.18 ns/op 8 B/op 1 allocs/op +BenchmarkHttpTreeMux_ParamWrite 4132832 279.9 ns/op 352 B/op 3 allocs/op +BenchmarkBeego_ParamWrite 2172572 554.3 ns/op 360 B/op 4 allocs/op +BenchmarkPat_ParamWrite 1200334 996.8 ns/op 936 B/op 14 allocs/op +BenchmarkGorillaMux_ParamWrite 1000000 1005 ns/op 1024 B/op 8 allocs/op +BenchmarkMartini_ParamWrite 454255 2667 ns/op 1168 B/op 16 allocs/op +BenchmarkTraffic_ParamWrite 511766 2021 ns/op 2272 B/op 25 allocs/op ``` In those micro benchmarks, we can see that `Fox` scale really well, even with long wildcard routes. Like `Gin`, this router reuse the -data structure (e.g. `fox.Params` slice) containing the matching parameters in order to remove completely heap allocation. We can also -notice that there is a very small overhead comparing to `Gin` when the number of parameters scale. This is due to the fact that every tree's node -in Fox are `atomic.Pointer` and that traversing the tree require to load the underlying node pointer atomically. Despite that, even -with 20 parameters, the performance of Fox is still better than most other contender. +data structure (e.g. `fox.Context` slice) containing the matching parameters in order to remove completely heap allocation. ### Github Finally, this benchmark execute a request for each GitHub API route (203 routes). **GOMAXPROCS: 0** ``` -BenchmarkGin_GithubAll 68384 17425 ns/op 0 B/op 0 allocs/op -BenchmarkFox_GithubAll 67162 17631 ns/op 0 B/op 0 allocs/op -BenchmarkHttpRouter_GithubAll 44085 27449 ns/op 13792 B/op 167 allocs/op -BenchmarkDenco_GithubAll 35019 33651 ns/op 20224 B/op 167 allocs/op -BenchmarkKocha_GithubAll 19186 62243 ns/op 23304 B/op 843 allocs/op -BenchmarkHttpTreeMuxSafeAddRoute_GithubAll 14907 79919 ns/op 65856 B/op 671 allocs/op -BenchmarkHttpTreeMux_GithubAll 14952 80280 ns/op 65856 B/op 671 allocs/op -BenchmarkBeego_GithubAll 9712 136414 ns/op 71456 B/op 609 allocs/op -BenchmarkTraffic_GithubAll 637 1824477 ns/op 819052 B/op 14114 allocs/op -BenchmarkMartini_GithubAll 572 2042852 ns/op 231419 B/op 2731 allocs/op -BenchmarkGorillaMux_GithubAll 562 2110880 ns/op 199683 B/op 1588 allocs/op -BenchmarkPat_GithubAll 550 2117715 ns/op 1410624 B/op 22515 allocs/op +BenchmarkFox_GithubAll 63984 18555 ns/op 0 B/op 0 allocs/op +BenchmarkEcho_GithubAll 49312 23353 ns/op 0 B/op 0 allocs/op +BenchmarkGin_GithubAll 48422 24926 ns/op 0 B/op 0 allocs/op +BenchmarkHttpRouter_GithubAll 45706 26818 ns/op 14240 B/op 171 allocs/op +BenchmarkHttpTreeMux_GithubAll 14731 80133 ns/op 67648 B/op 691 allocs/op +BenchmarkBeego_GithubAll 7692 137926 ns/op 72929 B/op 625 allocs/op +BenchmarkTraffic_GithubAll 636 1916586 ns/op 845114 B/op 14634 allocs/op +BenchmarkMartini_GithubAll 530 2205947 ns/op 238546 B/op 2813 allocs/op +BenchmarkGorillaMux_GithubAll 529 2246380 ns/op 203844 B/op 1620 allocs/op +BenchmarkPat_GithubAll 424 2899405 ns/op 1843501 B/op 29064 allocs/op ``` ## Road to v1 - [x] [Update route syntax](https://github.com/tigerwill90/fox/pull/10#issue-1643728309) @v0.6.0 -- [ ] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 -- [ ] Collect feedback and polishing +- [x] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 +- [ ] Improving performance and polishing ## Contributions This project aims to provide a lightweight, high performance and easy to use http router. It purposely has a limited set of features and exposes a relatively low-level api. -The intention behind these choices is that it can serve as a building block for more "batteries included" frameworks. Feature requests and PRs along these lines are welcome. +The intention behind these choices is that it can serve as a building block for implementing your own "batteries included" frameworks. Feature requests and PRs along these lines are welcome. ## Acknowledgements - [npgall/concurrent-trees](https://github.com/npgall/concurrent-trees): Fox design is largely inspired from Niall Gallagher's Concurrent Trees design. -- [julienschmidt/httprouter](https://github.com/julienschmidt/httprouter): a lot of feature that implements Fox are inspired from Julien Schmidt's router. \ No newline at end of file +- [julienschmidt/httprouter](https://github.com/julienschmidt/httprouter): some feature that implements Fox are inspired from Julien Schmidt's router. Most notably, +this package uses the optimized [httprouter.Cleanpath](https://github.com/julienschmidt/httprouter/blob/master/path.go) function. diff --git a/assets/tree-apply-patch.png b/assets/tree-apply-patch.png deleted file mode 100644 index aaf894b..0000000 Binary files a/assets/tree-apply-patch.png and /dev/null differ diff --git a/context.go b/context.go new file mode 100644 index 0000000..6bfca08 --- /dev/null +++ b/context.go @@ -0,0 +1,275 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + +package fox + +import ( + "fmt" + "io" + "net/http" + "net/url" +) + +// ContextCloser extends Context for manually created instances, adding a Close method +// to release resources after use. +type ContextCloser interface { + Context + Close() +} + +// Context represents the context of the current HTTP request. +// It provides methods to access request data and to write a response. +type Context interface { + // Done returns a channel that closes when the request's context is + // cancelled or times out. + Done() <-chan struct{} + // Request returns the current *http.Request. + Request() *http.Request + // SetRequest sets the *http.Request. + SetRequest(r *http.Request) + // Writer returns the ResponseWriter. + Writer() ResponseWriter + // SetWriter sets the ResponseWriter. + SetWriter(w ResponseWriter) + // Path returns the registered path for the handler. + Path() string + // Params returns a Params slice containing the matched + // wildcard parameters. + Params() Params + // Param retrieve a matching wildcard parameter by name. + Param(name string) string + // QueryParams parses the Request RawQuery and returns the corresponding values. + QueryParams() url.Values + // QueryParam returns the first query value associated with the given key. + QueryParam(name string) string + // String sends a formatted string with the specified status code. + String(code int, format string, values ...any) error + // Blob sends a byte slice with the specified status code and content type. + Blob(code int, contentType string, buf []byte) error + // Stream sends data from an io.Reader with the specified status code and content type. + Stream(code int, contentType string, r io.Reader) error + // Redirect sends an HTTP redirect response with the given status code and URL. + Redirect(code int, url string) error + // Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. + Clone() Context + // Tree is a local copy of the Tree in use to serve the request. + Tree() *Tree + // Fox returns the Router in use to serve the request. + Fox() *Router +} + +// context holds request-related information and allows interaction with the ResponseWriter. +type context struct { + w ResponseWriter + req *http.Request + params *Params + skipNds *skippedNodes + + // tree at allocation (read-only, no reset) + tree *Tree + fox *Router + + cachedQuery url.Values + path string + rec recorder +} + +func (c *context) reset(fox *Router, w http.ResponseWriter, r *http.Request) { + c.rec.reset(w) + c.req = r + if r.ProtoMajor == 2 { + c.w = h2Writer{&c.rec} + } else { + c.w = h1Writer{&c.rec} + } + c.fox = fox + c.path = "" + c.cachedQuery = nil + *c.params = (*c.params)[:0] +} + +func (c *context) resetNil() { + c.req = nil + c.w = nil + c.fox = nil + c.path = "" + c.cachedQuery = nil + *c.params = (*c.params)[:0] +} + +// Request returns the *http.Request. +func (c *context) Request() *http.Request { + return c.req +} + +// SetRequest sets the *http.Request. +func (c *context) SetRequest(r *http.Request) { + c.req = r +} + +// Writer returns the ResponseWriter. +func (c *context) Writer() ResponseWriter { + return c.w +} + +// SetWriter sets the ResponseWriter. +func (c *context) SetWriter(w ResponseWriter) { + c.w = w +} + +// Done returns a channel that closes when the request's context is +// cancelled or times out. +func (c *context) Done() <-chan struct{} { + return c.req.Context().Done() +} + +// Params returns a Params slice containing the matched +// wildcard parameters. +func (c *context) Params() Params { + return *c.params +} + +// Param retrieve a matching wildcard segment by name. +// It's a helper for c.Params.Get(name). +func (c *context) Param(name string) string { + for _, p := range c.Params() { + if p.Key == name { + return p.Value + } + } + return "" +} + +// QueryParams parses RawQuery and returns the corresponding values. +// It's a helper for c.Request.URL.Query(). Note that the parsed +// result is cached. +func (c *context) QueryParams() url.Values { + c.req.URL.Query() + return c.getQueries() +} + +// QueryParam returns the first value associated with the given key. +// It's a helper for c.QueryParams().Get(name). +func (c *context) QueryParam(name string) string { + return c.getQueries().Get(name) +} + +// Path returns the registered path for the handler. +func (c *context) Path() string { + return c.path +} + +// String sends a formatted string with the specified status code. +func (c *context) String(code int, format string, values ...any) (err error) { + if c.w.Header().Get(HeaderContentType) == "" { + c.w.Header().Set(HeaderContentType, MIMETextPlainCharsetUTF8) + } + c.w.WriteHeader(code) + _, err = fmt.Fprintf(c.w, format, values...) + return +} + +// Blob sends a byte slice with the specified status code and content type. +func (c *context) Blob(code int, contentType string, buf []byte) (err error) { + c.w.Header().Set(HeaderContentType, contentType) + c.w.WriteHeader(code) + _, err = c.w.Write(buf) + return +} + +// Stream sends data from an io.Reader with the specified status code and content type. +func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { + c.w.Header().Set(HeaderContentType, contentType) + c.w.WriteHeader(code) + _, err = io.Copy(c.w, r) + return +} + +// Redirect sends an HTTP redirect response with the given status code and URL. +func (c *context) Redirect(code int, url string) error { + if code < http.StatusMultipleChoices || code > http.StatusPermanentRedirect { + return ErrInvalidRedirectCode + } + http.Redirect(c.w, c.req, url, code) + return nil +} + +// Tree is a local copy of the Tree in use to serve the request. +func (c *context) Tree() *Tree { + return c.tree +} + +// Fox returns the Router in use to serve the request. +func (c *context) Fox() *Router { + return c.fox +} + +// Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. +func (c *context) Clone() Context { + cp := context{ + rec: c.rec, + req: c.req, + fox: c.fox, + tree: c.tree, + } + cp.rec.ResponseWriter = noopWriter{} + cp.w = &cp.rec + params := make(Params, len(*c.params)) + copy(params, *c.params) + cp.params = ¶ms + cp.cachedQuery = nil + return &cp +} + +// Close releases the context to be reused later. +func (c *context) Close() { + // Put back the context, if not extended more than max params or max depth, allowing + // the slice to naturally grow within the constraint. + if cap(*c.params) > int(c.tree.maxParams.Load()) || cap(*c.skipNds) > int(c.tree.maxDepth.Load()) { + return + } + c.tree.ctx.Put(c) +} + +func (c *context) getQueries() url.Values { + if c.cachedQuery == nil { + if c.req != nil { + c.cachedQuery = c.req.URL.Query() + } else { + c.cachedQuery = url.Values{} + } + } + return c.cachedQuery +} + +// WrapF is an adapter for wrapping http.HandlerFunc and returns a HandlerFunc function. +func WrapF(f http.HandlerFunc) HandlerFunc { + return func(c Context) error { + f.ServeHTTP(c.Writer(), c.Request()) + return nil + } +} + +// WrapH is an adapter for wrapping http.Handler and returns a HandlerFunc function. +func WrapH(h http.Handler) HandlerFunc { + return func(c Context) error { + h.ServeHTTP(c.Writer(), c.Request()) + return nil + } +} + +// WrapM is an adapter for wrapping http.Handler middleware and returns a +// MiddlewareFunc function. +func WrapM(m func(handler http.Handler) http.Handler) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + var err error + adapter := m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err = next(c) + })) + adapter.ServeHTTP(c.Writer(), c.Request()) + return err + } + } +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..667ea7a --- /dev/null +++ b/context_test.go @@ -0,0 +1,147 @@ +package fox + +import ( + "bytes" + netcontext "context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func TestContext_QueryParams(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTestContextTree(New().Tree()) + c.req = req + values := c.QueryParams() + assert.Equal(t, wantValues, values) + assert.Equal(t, wantValues, c.cachedQuery) +} + +func TestContext_QueryParam(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTestContextTree(New().Tree()) + c.req = req + assert.Equal(t, "b", c.QueryParam("a")) + assert.Equal(t, "d", c.QueryParam("c")) + assert.Equal(t, wantValues, c.cachedQuery) +} + +func TestContext_Clone(t *testing.T) { + wantValues := url.Values{ + "a": []string{"b"}, + "c": []string{"d", "e"}, + } + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + req.URL.RawQuery = wantValues.Encode() + + c := newTextContextOnly(New(), httptest.NewRecorder(), req) + + buf := []byte("foo bar") + _, err := c.w.Write(buf) + require.NoError(t, err) + + cc := c.Clone() + assert.Equal(t, http.StatusOK, cc.Writer().Status()) + assert.Equal(t, len(buf), cc.Writer().Size()) + assert.Equal(t, wantValues, c.QueryParams()) + _, err = cc.Writer().Write([]byte("invalid")) + assert.ErrorIs(t, err, ErrDiscardedResponseWriter) +} + +func TestContext_Done(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + ctx, cancel := netcontext.WithCancel(netcontext.Background()) + cancel() + req = req.WithContext(ctx) + _, c := NewTestContext(httptest.NewRecorder(), req) + select { + case <-c.Done(): + require.ErrorIs(t, c.Request().Context().Err(), netcontext.Canceled) + case <-time.After(1): + t.FailNow() + } +} + +func TestContext_Redirect(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + require.NoError(t, c.Redirect(http.StatusTemporaryRedirect, "https://example.com/foo/bar")) + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "https://example.com/foo/bar", w.Header().Get(HeaderLocation)) +} + +func TestContext_Blob(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + require.NoError(t, c.Blob(http.StatusCreated, MIMETextPlain, buf)) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlain, w.Header().Get(HeaderContentType)) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_Stream(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + require.NoError(t, c.Stream(http.StatusCreated, MIMETextPlain, bytes.NewBuffer(buf))) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlain, w.Header().Get(HeaderContentType)) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_String(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + s := "foobar" + require.NoError(t, c.String(http.StatusCreated, s)) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, MIMETextPlainCharsetUTF8, w.Header().Get(HeaderContentType)) + assert.Equal(t, s, w.Body.String()) + assert.Equal(t, len(s), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} + +func TestContext_Writer(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + _, c := NewTestContext(w, r) + buf := []byte("foobar") + c.Writer().WriteHeader(http.StatusCreated) + assert.Equal(t, 0, c.Writer().Size()) + n, err := c.Writer().Write(buf) + require.NoError(t, err) + assert.Equal(t, len(buf), n) + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, http.StatusCreated, c.Writer().Status()) + assert.Equal(t, buf, w.Body.Bytes()) + assert.Equal(t, len(buf), c.Writer().Size()) + assert.True(t, c.Writer().Written()) +} diff --git a/error.go b/error.go index c83df1f..be52ae2 100644 --- a/error.go +++ b/error.go @@ -1,28 +1,36 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( "errors" "fmt" + "net/http" "strings" ) var ( - ErrRouteNotFound = errors.New("route not found") - ErrRouteExist = errors.New("route already registered") - ErrRouteConflict = errors.New("route conflict") - ErrInvalidRoute = errors.New("invalid route") + ErrRouteNotFound = errors.New("route not found") + ErrRouteExist = errors.New("route already registered") + ErrRouteConflict = errors.New("route conflict") + ErrInvalidRoute = errors.New("invalid route") + ErrDiscardedResponseWriter = errors.New("discarded response writer") + ErrInvalidRedirectCode = errors.New("invalid redirect code") ) type RouteConflictError struct { - err error - Method string - Path string - Matched []string + err error + Method string + Path string + Matched []string + isUpdate bool } func newConflictErr(method, path, catchAllKey string, matched []string) *RouteConflictError { if catchAllKey != "" { - path += "*" + catchAllKey + path += "*{" + catchAllKey + "}" } return &RouteConflictError{ Method: method, @@ -33,10 +41,51 @@ func newConflictErr(method, path, catchAllKey string, matched []string) *RouteCo } func (e *RouteConflictError) Error() string { - path := e.Path - return fmt.Sprintf("new route [%s] %s conflicts with %s", e.Method, path, strings.Join(e.Matched, ", ")) + if !e.isUpdate { + return e.insertError() + } + return e.updateError() +} + +func (e *RouteConflictError) insertError() string { + return fmt.Sprintf("%s: new route [%s] %s conflicts with %s", e.err, e.Method, e.Path, strings.Join(e.Matched, ", ")) +} + +func (e *RouteConflictError) updateError() string { + return fmt.Sprintf("wildcard conflict: updated route [%s] %s conflicts with %s", e.Method, e.Path, strings.Join(e.Matched, ", ")) + } func (e *RouteConflictError) Unwrap() error { return e.err } + +// HTTPError represents an HTTP error with a status code (HTTPErrorCode) +// and an optional error message. If no error message is provided, +// the default error message for the status code will be used. +type HTTPError struct { + Err error + Code int +} + +// Error returns the error message associated with the HTTPError, +// or the default error message for the status code if none is provided. +func (e HTTPError) Error() string { + if e.Err == nil { + return http.StatusText(e.Code) + } + return e.Err.Error() +} + +// NewHTTPError creates a new HTTPError with the given status code +// and an optional error message. +func NewHTTPError(code int, err ...error) HTTPError { + var e error + if len(err) > 0 { + e = err[0] + } + return HTTPError{ + Code: code, + Err: e, + } +} diff --git a/router.go b/fox.go similarity index 63% rename from router.go rename to fox.go index 7892641..2cba4a0 100644 --- a/router.go +++ b/fox.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -13,71 +17,60 @@ const verb = 4 var commonVerbs = [verb]string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} -// Handler respond to an HTTP request. +// HandlerFunc is a function type that responds to an HTTP request. +// It enforces the same contract as http.Handler but provides additional feature +// like matched wildcard route segments via the Context type. The Context is freed once +// the HandlerFunc returns and may be reused later to save resources. If you need +// to hold the context longer, you have to copy it (see Clone method). // -// This interface enforce the same contract as http.Handler except that matched wildcard route segment -// are accessible via params. Params slice is freed once ServeHTTP returns and may be reused later to -// save resource. Therefore, if you need to hold params slice longer, you have to copy it (see Clone method). +// The function may return an error that can be propagated through the middleware +// chain and handled by the registered ErrorHandlerFunc, which is set using the +// WithRouteError option. // -// As for http.Handler interface, to abort a handler so the client sees an interrupted response, panic with -// the value http.ErrAbortHandler. -type Handler interface { - ServeHTTP(http.ResponseWriter, *http.Request, Params) -} +// Similar to http.Handler, to abort a HandlerFunc so the client sees an interrupted +// response, panic with the value http.ErrAbortHandler. +// +// HandlerFunc functions should be thread-safe, as they will be called concurrently. +type HandlerFunc func(c Context) error -// HandlerFunc is an adapter to allow the use of ordinary functions as HTTP handlers. If f is a function with the -// appropriate signature, HandlerFunc(f) is a Handler that calls f. -type HandlerFunc func(http.ResponseWriter, *http.Request, Params) +// MiddlewareFunc is a function type for implementing HandlerFunc middleware. +// The returned HandlerFunc usually wraps the input HandlerFunc, allowing you to perform operations +// before and/or after the wrapped HandlerFunc is executed. MiddlewareFunc functions should +// be thread-safe, as they will be called concurrently. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc -// ServerHTTP calls f(w, r, params) -func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, params Params) { - f(w, r, params) -} +// ErrorHandlerFunc is a function type that handles errors returned by a HandlerFunc. +// It receives the Context and the error returned by the HandlerFunc, allowing +// centralized error management and custom error handling. +type ErrorHandlerFunc func(c Context, err error) // Router is a lightweight high performance HTTP request router that support mutation on its routing tree // while handling request concurrently. type Router struct { - // User-configurable http.Handler which is called when no matching route is found. - // By default, http.NotFound is used. - notFound http.Handler - - // User-configurable http.Handler which is called when the request cannot be routed, - // but the same route exist for other methods. The "Allow" header it automatically set - // before calling the handler. Set HandleMethodNotAllowed to true to enable this option. By default, - // http.Error with http.StatusMethodNotAllowed is used. - methodNotAllowed http.Handler - - // Register a function to handle panics recovered from http handlers. - panicHandler func(http.ResponseWriter, *http.Request, interface{}) - - tree atomic.Pointer[Tree] - - // If enabled, fox return a 405 Method Not Allowed instead of 404 Not Found when the route exist for another http verb. + noRoute HandlerFunc + noMethod HandlerFunc + errRoute ErrorHandlerFunc + tree atomic.Pointer[Tree] + mws []MiddlewareFunc handleMethodNotAllowed bool - - // Enable automatic redirection fallback when the current request does not match but another handler is found - // after cleaning up superfluous path elements (see CleanPath). E.g. /../foo/bar request does not match but /foo/bar would. - // The client is redirected with a http status code 301 for GET requests and 308 for all other methods. - redirectFixedPath bool - - // Enable automatic redirection fallback when the current request does not match but another handler is found - // with/without an additional trailing slash. E.g. /foo/bar/ request does not match but /foo/bar would match. - // The client is redirected with a http status code 301 for GET requests and 308 for all other methods. - redirectTrailingSlash bool - - // If enabled, the matched route will be accessible as a Handler parameter. - // Usage: p.Get(fox.RouteKey) - saveMatchedRoute bool + redirectFixedPath bool + redirectTrailingSlash bool } var _ http.Handler = (*Router)(nil) -// New returns a ready to use Router. +// New returns a ready to use instance of Fox router. func New(opts ...Option) *Router { r := new(Router) + + r.noRoute = NotFoundHandler() + r.noMethod = MethodNotAllowedHandler() + r.errRoute = RouteErrorHandler() + for _, opt := range opts { opt.apply(r) } + r.tree.Store(r.NewTree()) return r } @@ -89,41 +82,65 @@ func New(opts ...Option) *Router { // This api is EXPERIMENTAL and is likely to change in future release. func (fox *Router) NewTree() *Tree { tree := new(Tree) - tree.saveRoute = fox.saveMatchedRoute + tree.mws = fox.mws + // Pre instantiate nodes for common http verb nds := make([]*node, len(commonVerbs)) for i := range commonVerbs { nds[i] = new(node) nds[i].key = commonVerbs[i] + nds[i].paramChildIndex = -1 } tree.nodes.Store(&nds) - tree.p = sync.Pool{ + tree.ctx = sync.Pool{ New: func() any { - params := make(Params, 0, tree.maxParams.Load()) - return ¶ms + return tree.allocateContext() }, } return tree } -// Handler registers a new handler for the given method and path. This function return an error if the route +// Tree atomically loads and return the currently in-use routing tree. +// This API is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) Tree() *Tree { + return fox.tree.Load() +} + +// Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. +// This API is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) Swap(new *Tree) (old *Tree) { + return fox.tree.Swap(new) +} + +// Handle registers a new handler for the given method and path. This function return an error if the route // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. This function is safe for concurrent use by multiple goroutine. // To override an existing route, use Update. -func (fox *Router) Handler(method, path string, handler Handler) error { +func (fox *Router) Handle(method, path string, handler HandlerFunc) error { t := fox.Tree() t.Lock() defer t.Unlock() - return t.Handler(method, path, handler) + return t.Handle(method, path, handler) +} + +// MustHandle registers a new handler for the given method and path. This function is a convenience +// wrapper for the Handle function. It will panic if the route is already registered or conflicts +// with another route. It's perfectly safe to add a new handler while the tree is in use for serving +// requests. This function is safe for concurrent use by multiple goroutines. +// To override an existing route, use Update. +func (fox *Router) MustHandle(method, path string, handler HandlerFunc) { + if err := fox.Handle(method, path, handler); err != nil { + panic(err) + } } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. This function is safe for concurrent use by multiple goroutine. -// To add new handler, use Handler method. -func (fox *Router) Update(method, path string, handler Handler) error { +// To add new handler, use Handle method. +func (fox *Router) Update(method, path string, handler HandlerFunc) error { t := fox.Tree() t.Lock() defer t.Unlock() @@ -140,53 +157,20 @@ func (fox *Router) Remove(method, path string) error { return t.Remove(method, path) } -// Tree atomically loads and return the currently in-use routing tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Tree() *Tree { - return fox.tree.Load() -} - -// Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Swap(new *Tree) (old *Tree) { - return fox.tree.Swap(new) -} - -// Use atomically replaces the currently in-use routing tree with the provided new tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Use(new *Tree) { - fox.tree.Store(new) -} - -// Lookup allow to do manual lookup of a route and return the matched handler along with parsed params and -// trailing slash redirect recommendation. Note that you should always free Params if NOT nil by calling -// params.Free(t). If lazy is set to true, route params are not parsed. This function is safe for concurrent use -// by multiple goroutine and while mutation on Tree are ongoing. -func Lookup(t *Tree, method, path string, lazy bool) (handler Handler, params *Params, tsr bool) { - nds := t.load() - index := findRootNode(method, nds) - if index < 0 { - return nil, nil, false - } - - n, ps, tsr := t.lookup(nds[index], path, lazy) - if n != nil { - return n.handler, ps, tsr - } - return nil, nil, tsr -} - // Has allows to check if the given method and path exactly match a registered route. This function is safe for // concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Has(t *Tree, method, path string) bool { - nds := t.load() + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { return false } - n, _, _ := t.lookup(nds[index], path, true) + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() return n != nil && n.path == path } @@ -194,12 +178,16 @@ func Has(t *Tree, method, path string) bool { // This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Reverse(t *Tree, method, path string) string { - nds := t.load() + nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { return "" } - n, _, _ := t.lookup(nds[index], path, true) + + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() if n == nil { return "" } @@ -211,23 +199,23 @@ func Reverse(t *Tree, method, path string) string { var SkipMethod = errors.New("skip method") // WalkFunc is the type of the function called by Walk to visit each registered routes. -type WalkFunc func(method, path string, handler Handler) error +type WalkFunc func(method, path string, handler HandlerFunc) error // Walk allow to walk over all registered route in lexicographical order. If the function // return the special value SkipMethod, Walk skips the current method. This function is // safe for concurrent use by multiple goroutine and while mutation are ongoing. // This api is EXPERIMENTAL and is likely to change in future release. func Walk(tree *Tree, fn WalkFunc) error { - nds := tree.load() -NEXT: + nds := *tree.nodes.Load() +Next: for i := range nds { method := nds[i].key it := newRawIterator(nds[i]) for it.hasNext() { - err := fn(method, it.fullPath(), it.node().handler) + err := fn(method, it.path, it.current.handler) if err != nil { if errors.Is(err, SkipMethod) { - continue NEXT + continue Next } return err } @@ -237,35 +225,72 @@ NEXT: return nil } -func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if fox.panicHandler != nil { - defer fox.recover(w, r) +// NotFoundHandler returns a simple HandlerFunc that replies to each request +// with a “404 page not found” reply. +func NotFoundHandler() HandlerFunc { + return func(c Context) error { + http.Error(c.Writer(), "404 page not found", http.StatusNotFound) + return nil + } +} + +// MethodNotAllowedHandler returns a simple HandlerFunc that replies to each request +// with a “405 Method Not Allowed” reply. +func MethodNotAllowedHandler() HandlerFunc { + return func(c Context) error { + http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return nil } +} + +// RouteErrorHandler returns an ErrorHandlerFunc that handle HandlerFunc error. +func RouteErrorHandler() ErrorHandlerFunc { + return func(c Context, err error) { + if !c.Writer().Written() { + if e, ok := err.(HTTPError); ok { + http.Error(c.Writer(), e.Error(), e.Code) + return + } + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + } +} + +func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { var ( - n *node - params *Params - tsr bool + n *node + tsr bool ) - tree := fox.Tree() - nds := tree.load() + tree := fox.tree.Load() + c := tree.ctx.Get().(*context) + c.reset(fox, w, r) + + nds := *tree.nodes.Load() index := findRootNode(r.Method, nds) if index < 0 { - goto NO_METHOD_FALLBACK + goto NoMethodFallback } - n, params, tsr = tree.lookup(nds[index], r.URL.Path, false) + n, tsr = tree.lookup(nds[index], r.URL.Path, c.params, c.skipNds, false) if n != nil { - if params != nil { - n.handler.ServeHTTP(w, r, *params) - params.Free(tree) - return + c.path = n.path + err := n.handler(c) + if err != nil { + fox.errRoute(c, err) + } + // Put back the context, if not extended more than max params or max depth, allowing + // the slice to naturally grow within the constraint. + if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) { + c.tree.ctx.Put(c) } - n.handler.ServeHTTP(w, r, nil) return } + // Reset params as it may have recorded wildcard segment + *c.params = (*c.params)[:0] + if r.Method != http.MethodConnect && r.URL.Path != "/" { code := http.StatusMovedPermanently @@ -277,32 +302,35 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tsr && fox.redirectTrailingSlash { r.URL.Path = fixTrailingSlash(r.URL.Path) http.Redirect(w, r, r.URL.String(), code) + c.Close() return } if fox.redirectFixedPath { cleanedPath := CleanPath(r.URL.Path) - n, _, tsr := tree.lookup(nds[index], cleanedPath, true) + n, tsr := tree.lookup(nds[index], cleanedPath, c.params, c.skipNds, true) if n != nil { r.URL.Path = cleanedPath http.Redirect(w, r, r.URL.String(), code) + c.Close() return } if tsr && fox.redirectTrailingSlash { r.URL.Path = fixTrailingSlash(cleanedPath) http.Redirect(w, r, r.URL.String(), code) + c.Close() return } } } -NO_METHOD_FALLBACK: +NoMethodFallback: if fox.handleMethodNotAllowed { var sb strings.Builder for i := 0; i < len(nds); i++ { if nds[i].key != r.Method { - if n, _, _ := tree.lookup(nds[i], r.URL.Path, true); n != nil { + if n, _ := tree.lookup(nds[i], r.URL.Path, c.params, c.skipNds, true); n != nil { if sb.Len() > 0 { sb.WriteString(", ") } @@ -313,29 +341,20 @@ NO_METHOD_FALLBACK: allowed := sb.String() if allowed != "" { w.Header().Set("Allow", allowed) - if fox.methodNotAllowed != nil { - fox.methodNotAllowed.ServeHTTP(w, r) - return + err := fox.noMethod(c) + if err != nil { + fox.errRoute(c, err) } - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + c.Close() return } } - if fox.notFound != nil { - fox.notFound.ServeHTTP(w, r) - return - } - http.NotFound(w, r) -} - -func (fox *Router) recover(w http.ResponseWriter, r *http.Request) { - if val := recover(); val != nil { - if abortErr, ok := val.(error); ok && errors.Is(abortErr, http.ErrAbortHandler) { - panic(abortErr) - } - fox.panicHandler(w, r, val) + err := fox.noRoute(c) + if err != nil { + fox.errRoute(c, err) } + c.Close() } type resultType int @@ -386,6 +405,7 @@ type searchResult struct { path string charsMatched int charsMatchedInNodeFound int + depth uint32 } func min(a, b int) int { @@ -436,7 +456,7 @@ const ( func parseRoute(path string) (string, string, int, error) { if !strings.HasPrefix(path, "/") { - return "", "", -1, fmt.Errorf("path must start with '/': %w", ErrInvalidRoute) + return "", "", -1, fmt.Errorf("%w: path must start with '/'", ErrInvalidRoute) } state := stateDefault @@ -511,6 +531,15 @@ func parseRoute(path string) (string, string, int, error) { func getRouteConflict(n *node) []string { routes := make([]string, 0) + + if n.isCatchAll() { + routes = append(routes, n.path) + return routes + } + + if n.paramChildIndex >= 0 { + n = n.children[n.paramChildIndex].Load() + } it := newRawIterator(n) for it.hasNext() { routes = append(routes, it.current.path) @@ -526,3 +555,11 @@ func isRemovable(method string) bool { } return true } + +func applyMiddleware(mws []MiddlewareFunc, h HandlerFunc) HandlerFunc { + m := h + for i := len(mws) - 1; i >= 0; i-- { + m = mws[i](m) + } + return m +} diff --git a/router_test.go b/fox_test.go similarity index 64% rename from router_test.go rename to fox_test.go index 56fe787..ab7013d 100644 --- a/router_test.go +++ b/fox_test.go @@ -1,6 +1,11 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( + "errors" "fmt" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/assert" @@ -17,7 +22,9 @@ import ( "time" ) -var emptyHandler = HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) +var emptyHandler = HandlerFunc(func(c Context) error { return nil }) +var pathHandler = HandlerFunc(func(c Context) error { return c.String(200, c.Request().URL.Path) }) +var routeHandler = HandlerFunc(func(c Context) error { return c.String(200, c.Path()) }) type mockResponseWriter struct{} @@ -40,6 +47,12 @@ type route struct { path string } +var overlappingRoutes = []route{ + {"GET", "/foo/abc/id:{id}/xyz"}, + {"GET", "/foo/{name}/id:{id}/{name}"}, + {"GET", "/foo/{name}/id:{id}/xyz"}, +} + // From https://github.com/julienschmidt/go-http-routing-benchmark var staticRoutes = []route{ {"GET", "/"}, @@ -470,39 +483,36 @@ func benchRouteParallel(b *testing.B, router http.Handler, rte route) { func BenchmarkStaticAll(b *testing.B) { r := New() for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } benchRoutes(b, r, staticRoutes) } -func BenchmarkLookup(b *testing.B) { +func BenchmarkGithubParamsAll(b *testing.B) { r := New() - for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) + for _, route := range githubAPI { + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } + req := httptest.NewRequest("GET", "/repos/sylvain/fox/hooks/1500", nil) + w := new(mockResponseWriter) + b.ReportAllocs() b.ResetTimer() - tree := r.Tree() for i := 0; i < b.N; i++ { - for _, route := range staticRoutes { - _, p, _ := Lookup(tree, route.method, route.path, false) - if p != nil { - p.Free(tree) - } - } + r.ServeHTTP(w, req) } } -func BenchmarkGithubParamsAll(b *testing.B) { +func BenchmarkOverlappingRoute(b *testing.B) { r := New() - for _, route := range githubAPI { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(w http.ResponseWriter, r *http.Request, p Params) {}))) + for _, route := range overlappingRoutes { + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } - req := httptest.NewRequest("GET", "/repos/sylvain/fox/hooks/1500", nil) + req := httptest.NewRequest("GET", "/foo/abc/id:123/xy", nil) w := new(mockResponseWriter) b.ReportAllocs() @@ -516,14 +526,14 @@ func BenchmarkGithubParamsAll(b *testing.B) { func BenchmarkStaticParallel(b *testing.B) { r := New() for _, route := range staticRoutes { - require.NoError(b, r.Tree().Handler(route.method, route.path, HandlerFunc(func(_ http.ResponseWriter, _ *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(route.method, route.path, emptyHandler)) } benchRouteParallel(b, r, route{"GET", "/progs/image_package4.out"}) } func BenchmarkCatchAll(b *testing.B) { r := New() - require.NoError(b, r.Tree().Handler(http.MethodGet, "/something/*{args}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(http.MethodGet, "/something/*{args}", emptyHandler)) w := new(mockResponseWriter) req := httptest.NewRequest("GET", "/something/awesome", nil) @@ -537,7 +547,7 @@ func BenchmarkCatchAll(b *testing.B) { func BenchmarkCatchAllParallel(b *testing.B) { r := New() - require.NoError(b, r.Tree().Handler(http.MethodGet, "/something/*{args}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}))) + require.NoError(b, r.Tree().Handle(http.MethodGet, "/something/*{args}", emptyHandler)) w := new(mockResponseWriter) req := httptest.NewRequest("GET", "/something/awesome", nil) @@ -553,10 +563,9 @@ func BenchmarkCatchAllParallel(b *testing.B) { func TestStaticRoute(t *testing.T) { r := New() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { _, _ = w.Write([]byte(r.URL.Path)) }) for _, route := range staticRoutes { - require.NoError(t, r.Tree().Handler(route.method, route.path, h)) + require.NoError(t, r.Tree().Handle(route.method, route.path, pathHandler)) } for _, route := range staticRoutes { @@ -568,11 +577,26 @@ func TestStaticRoute(t *testing.T) { } } +func TestStaticRouteMalloc(t *testing.T) { + r := New() + + for _, route := range staticRoutes { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + + for _, route := range staticRoutes { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + func TestParamsRoute(t *testing.T) { - rx := regexp.MustCompile("({|\\*{)[A-z_]+[}]") - r := New(WithSaveMatchedRoute(true)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - matches := rx.FindAllString(r.URL.Path, -1) + rx := regexp.MustCompile("({|\\*{)[A-z]+[}]") + r := New() + h := func(c Context) error { + matches := rx.FindAllString(c.Request().URL.Path, -1) for _, match := range matches { var key string if strings.HasPrefix(match, "*") { @@ -581,13 +605,13 @@ func TestParamsRoute(t *testing.T) { key = match[1 : len(match)-1] } value := match - assert.Equal(t, value, params.Get(key)) + assert.Equal(t, value, c.Param(key)) } - assert.Equal(t, r.URL.Path, params.Get(RouteKey)) - _, _ = w.Write([]byte(r.URL.Path)) - }) + assert.Equal(t, c.Request().URL.Path, c.Path()) + return c.String(200, c.Request().URL.Path) + } for _, route := range githubAPI { - require.NoError(t, r.Tree().Handler(route.method, route.path, h)) + require.NoError(t, r.Tree().Handle(route.method, route.path, h)) } for _, route := range githubAPI { req := httptest.NewRequest(route.method, route.path, nil) @@ -598,9 +622,34 @@ func TestParamsRoute(t *testing.T) { } } +func TestParamsRouteMalloc(t *testing.T) { + r := New() + for _, route := range githubAPI { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + for _, route := range githubAPI { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + +func TestOverlappingRouteMalloc(t *testing.T) { + r := New() + for _, route := range overlappingRoutes { + require.NoError(t, r.Tree().Handle(route.method, route.path, emptyHandler)) + } + for _, route := range overlappingRoutes { + req := httptest.NewRequest(route.method, route.path, nil) + w := httptest.NewRecorder() + allocs := testing.AllocsPerRun(100, func() { r.ServeHTTP(w, req) }) + assert.Equal(t, float64(0), allocs) + } +} + func TestRouterWildcard(t *testing.T) { r := New() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { _, _ = w.Write([]byte(r.URL.Path)) }) routes := []struct { path string @@ -613,7 +662,7 @@ func TestRouterWildcard(t *testing.T) { } for _, route := range routes { - require.NoError(t, r.Tree().Handler(http.MethodGet, route.path, h)) + require.NoError(t, r.Tree().Handle(http.MethodGet, route.path, pathHandler)) } for _, route := range routes { @@ -644,12 +693,13 @@ func TestRouteWithParams(t *testing.T) { "/info/{user}/project/{project}", } for _, rte := range routes { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) } - nds := tree.load() + nds := *tree.nodes.Load() for _, rte := range routes { - n, _, _ := tree.lookup(nds[0], rte, false) + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], rte, c.params, c.skipNds, false) require.NotNil(t, n) assert.Equal(t, rte, n.path) } @@ -680,284 +730,381 @@ func TestRoutParamEmptySegment(t *testing.T) { } for _, tc := range cases { - require.NoError(t, tree.Handler(http.MethodGet, tc.route, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, tc.route, emptyHandler)) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - nds := tree.load() - n, ps, tsr := tree.lookup(nds[0], tc.path, false) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + n, tsr := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) assert.Nil(t, n) - assert.Nil(t, ps) + assert.Empty(t, c.Params()) assert.False(t, tsr) }) } } -func TestInsertWildcardConflict(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) +func TestOverlappingRoute(t *testing.T) { + r := New() cases := []struct { - name string - routes []struct { - wantErr error - path string - wantMatch []string - wildcard bool - } + name string + path string + routes []string + wantMatch string + wantParams Params }{ { - name: "key mid edge conflicts", - routes: []struct { - wantErr error - path string - wantMatch []string - wildcard bool - }{ - {path: "/foo/bar", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/baz", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/", wildcard: true, wantErr: ErrRouteConflict, wantMatch: []string{"/foo/bar", "/foo/baz"}}, - {path: "/foo/bar/baz/", wildcard: true, wantErr: nil}, - {path: "/foo/bar/", wildcard: true, wantErr: ErrRouteConflict, wantMatch: []string{"/foo/bar/baz/*{args}"}}, + name: "basic test most specific", + path: "/products/new", + routes: []string{ + "/products/{id}", + "/products/new", }, + wantMatch: "/products/new", }, { - name: "incomplete match to the end of edge conflict", - routes: []struct { - wantErr error - path string - wantMatch []string - wildcard bool - }{ - {path: "/foo/", wildcard: true, wantErr: nil, wantMatch: nil}, - {path: "/foo/bar", wildcard: false, wantErr: ErrRouteConflict, wantMatch: []string{"/foo/*{args}"}}, - {path: "/fuzz/baz/bar/", wildcard: true, wantErr: nil, wantMatch: nil}, - {path: "/fuzz/baz/bar/foo", wildcard: false, wantErr: ErrRouteConflict, wantMatch: []string{"/fuzz/baz/bar/*{args}"}}, + name: "basic test less specific", + path: "/products/123", + routes: []string{ + "/products/{id}", + "/products/new", }, + wantMatch: "/products/{id}", + wantParams: Params{{Key: "id", Value: "123"}}, }, { - name: "exact match conflict", - routes: []struct { - wantErr error - path string - wantMatch []string - wildcard bool - }{ - {path: "/foo/1", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/2", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/", wildcard: true, wantErr: ErrRouteConflict, wantMatch: []string{"/foo/1", "/foo/2"}}, - {path: "/foo/1/baz/1", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/1/baz/2", wildcard: false, wantErr: nil, wantMatch: nil}, - {path: "/foo/1/baz/", wildcard: true, wantErr: ErrRouteConflict, wantMatch: []string{"/foo/1/baz/1", "/foo/1/baz/2"}}, + name: "ieof+backtrack to {id} wildcard while deleting {a}", + path: "/base/val1/123/new/barr", + routes: []string{ + "/{base}/val1/{id}", + "/{base}/val1/123/{a}/bar", + "/{base}/val1/{id}/new/{name}", + "/{base}/val2", }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - tree := New().Tree() - for _, rte := range tc.routes { - var catchAllKey string - if rte.wildcard { - catchAllKey = "args" - } - err := tree.insert(http.MethodGet, rte.path, catchAllKey, 0, h) - assert.ErrorIs(t, err, rte.wantErr) - if cErr, ok := err.(*RouteConflictError); ok { - assert.Equal(t, rte.wantMatch, cErr.Matched) - } - } - }) - } -} - -func TestInsertParamsConflict(t *testing.T) { - cases := []struct { - name string - routes []struct { - path string - wildcard string - wantErr error - wantMatching []string - } - }{ - { - name: "KEY_END_MID_EDGE split right before param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/", wildcard: "", wantErr: nil, wantMatching: nil}, + wantMatch: "/{base}/val1/{id}/new/{name}", + wantParams: Params{ + { + Key: "base", + Value: "base", + }, + { + Key: "id", + Value: "123", + }, + { + Key: "name", + Value: "barr", + }, }, }, { - name: "KEY_END_MID_EDGE split param at the start of the path segment", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/{f}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/{foo}"}}, + name: "kme+backtrack to {id} wildcard while deleting {a}", + path: "/base/val1/123/new/ba", + routes: []string{ + "/{base}/val1/{id}", + "/{base}/val1/123/{a}/bar", + "/{base}/val1/{id}/new/{name}", + "/{base}/val2", }, - }, - { - name: "KEY_END_MID_EDGE split a char before the param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test", wildcard: "", wantErr: nil, wantMatching: nil}, + wantMatch: "/{base}/val1/{id}/new/{name}", + wantParams: Params{ + { + Key: "base", + Value: "base", + }, + { + Key: "id", + Value: "123", + }, + { + Key: "name", + Value: "ba", + }, }, }, { - name: "KEY_END_MID_EDGE split right before inflight param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/abc", wildcard: "", wantErr: nil, wantMatching: nil}, + name: "ime+backtrack to {id} wildcard while deleting {a}", + path: "/base/val1/123/new/bx", + routes: []string{ + "/{base}/val1/{id}", + "/{base}/val1/123/{a}/bar", + "/{base}/val1/{id}/new/{name}", + "/{base}/val2", + }, + wantMatch: "/{base}/val1/{id}/new/{name}", + wantParams: Params{ + { + Key: "base", + Value: "base", + }, + { + Key: "id", + Value: "123", + }, + { + Key: "name", + Value: "bx", + }, }, }, { - name: "KEY_END_MID_EDGE split param in flight", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/abc{f}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/abc{foo}"}}, + name: "backtrack to catch while deleting {a}, {id} and {name}", + path: "/base/val1/123/new/bar/", + routes: []string{ + "/{base}/val1/{id}", + "/{base}/val1/123/{a}/bar", + "/{base}/val1/{id}/new/{name}", + "/{base}/val*{all}", + }, + wantMatch: "/{base}/val*{all}", + wantParams: Params{ + { + Key: "base", + Value: "base", + }, + { + Key: "all", + Value: "1/123/new/bar/", + }, }, }, { - name: "KEY_END_MID_EDGE param with child starting with separator", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}/star", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, + name: "notleaf+backtrack to catch while deleting {a}, {id}", + path: "/base/val1/123/new", + routes: []string{ + "/{base}/val1/123/{a}/baz", + "/{base}/val1/123/{a}/bar", + "/{base}/val1/{id}/new/{name}", + "/{base}/val*{all}", + }, + wantMatch: "/{base}/val*{all}", + wantParams: Params{ + { + Key: "base", + Value: "base", + }, + { + Key: "all", + Value: "1/123/new", + }, }, }, { - name: "KEY_END_MID_EDGE inflight param with child starting with separator", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/abc{foo}/star", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, + name: "multi node most specific", + path: "/foo/1/2/3/bar", + routes: []string{ + "/foo/{ab}", + "/foo/{ab}/{bc}", + "/foo/{ab}/{bc}/{de}", + "/foo/{ab}/{bc}/{de}/bar", + "/foo/{ab}/{bc}/{de}/{fg}", + }, + wantMatch: "/foo/{ab}/{bc}/{de}/bar", + wantParams: Params{ + { + Key: "ab", + Value: "1", + }, + { + Key: "bc", + Value: "2", + }, + { + Key: "de", + Value: "3", + }, }, }, { - name: "INCOMPLETE_MATCH_TO_MIDDLE_OF_EDGE split existing node right before param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/a", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/{foo}"}}, + name: "multi node less specific", + path: "/foo/1/2/3/john", + routes: []string{ + "/foo/{ab}", + "/foo/{ab}/{bc}", + "/foo/{ab}/{bc}/{de}", + "/foo/{ab}/{bc}/{de}/bar", + "/foo/{ab}/{bc}/{de}/{fg}", + }, + wantMatch: "/foo/{ab}/{bc}/{de}/{fg}", + wantParams: Params{ + { + Key: "ab", + Value: "1", + }, + { + Key: "bc", + Value: "2", + }, + { + Key: "de", + Value: "3", + }, + { + Key: "fg", + Value: "john", + }, }, }, { - name: "INCOMPLETE_MATCH_TO_MIDDLE_OF_EDGE split new node right before param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test{foo}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/{foo}"}}, + name: "backtrack on empty mid key parameter", + path: "/foo/abc/bar", + routes: []string{ + "/foo/abc{id}/bar", + "/foo/{name}/bar", + }, + wantMatch: "/foo/{name}/bar", + wantParams: Params{ + { + Key: "name", + Value: "abc", + }, }, }, { - name: "INCOMPLETE_MATCH_TO_MIDDLE_OF_EDGE split existing node after param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/{fx}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/{foo}"}}, + name: "most specific wildcard between catch all", + path: "/foo/123", + routes: []string{ + "/foo/{id}", + "/foo/a*{args}", + "/foo*{args}", + }, + wantMatch: "/foo/{id}", + wantParams: Params{ + { + Key: "id", + Value: "123", + }, }, }, { - name: "INCOMPLETE_MATCH_TO_MIDDLE_OF_EDGE split existing node right before inflight param", - routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string - }{ - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/abcd", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/abc{foo}"}}, + name: "most specific catch all with param", + path: "/foo/abc", + routes: []string{ + "/foo/{id}", + "/foo/a*{args}", + "/foo*{args}", + }, + wantMatch: "/foo/a*{args}", + wantParams: Params{ + { + Key: "args", + Value: "bc", + }, }, }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tree := r.NewTree() + for _, rte := range tc.routes { + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + } + nds := *tree.nodes.Load() + + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) + require.NotNil(t, n) + require.NotNil(t, n.handler) + + assert.Equal(t, tc.wantMatch, n.path) + if len(tc.wantParams) == 0 { + assert.Empty(t, c.Params()) + } else { + assert.Equal(t, tc.wantParams, c.Params()) + } + + // Test with lazy + c = newTestContextTree(tree) + n, _ = tree.lookup(nds[0], tc.path, c.params, c.skipNds, true) + require.NotNil(t, n) + require.NotNil(t, n.handler) + assert.Empty(t, c.Params()) + assert.Equal(t, tc.wantMatch, n.path) + }) + } +} + +func TestInsertConflict(t *testing.T) { + cases := []struct { + name string + routes []struct { + wantErr error + path string + wantMatch []string + } + }{ { - name: "INCOMPLETE_MATCH_TO_MIDDLE_OF_EDGE split new node right before inflight param", + name: "exact match conflict", routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string + wantErr error + path string + wantMatch []string }{ - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/ab{foo}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/abc{foo}"}}, + {path: "/john/*{x}", wantErr: nil, wantMatch: nil}, + {path: "/john/*{y}", wantErr: ErrRouteConflict, wantMatch: []string{"/john/*{x}"}}, + {path: "/john/", wantErr: ErrRouteExist, wantMatch: nil}, + {path: "/foo/baz", wantErr: nil, wantMatch: nil}, + {path: "/foo/bar", wantErr: nil, wantMatch: nil}, + {path: "/foo/{id}", wantErr: nil, wantMatch: nil}, + {path: "/foo/*{args}", wantErr: ErrRouteConflict, wantMatch: []string{"/foo/{id}"}}, + {path: "/avengers/ironman/{power}", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/bar", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/foo", wantErr: nil, wantMatch: nil}, + {path: "/avengers/*{args}", wantErr: ErrRouteConflict, wantMatch: []string{"/avengers/{id}/bar", "/avengers/{id}/foo"}}, + {path: "/fox/", wantErr: nil, wantMatch: nil}, + {path: "/fox/*{args}", wantErr: ErrRouteExist, wantMatch: nil}, }, }, { - name: "INCOMPLETE_MATCH_TO_END_OF_EDGE add new node right after param without slash", + name: "incomplete match to end of edge conflicts", routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string + wantErr error + path string + wantMatch []string }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/{foox}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/{foo}"}}, + {path: "/foo/bar", wantErr: nil, wantMatch: nil}, + {path: "/foo/baz", wantErr: nil, wantMatch: nil}, + {path: "/foo/*{args}", wantErr: nil, wantMatch: nil}, + {path: "/foo/{id}", wantErr: ErrRouteConflict, wantMatch: []string{"/foo/*{args}"}}, }, }, { - name: "INCOMPLETE_MATCH_TO_END_OF_EDGE add new node right after inflight param without slash", + name: "key match mid-edge conflict", routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string + wantErr error + path string + wantMatch []string }{ - {path: "/test/abc{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/abc{foox}", wildcard: "", wantErr: ErrRouteConflict, wantMatching: []string{"/test/abc{foo}"}}, + {path: "/foo/{id}", wantErr: nil, wantMatch: nil}, + {path: "/foo/*{args}", wantErr: ErrRouteConflict, wantMatch: []string{"/foo/{id}"}}, + {path: "/foo/a*{args}", wantErr: nil, wantMatch: nil}, + {path: "/foo*{args}", wantErr: nil, wantMatch: nil}, + {path: "/john{doe}", wantErr: nil, wantMatch: nil}, + {path: "/john*{doe}", wantErr: ErrRouteConflict, wantMatch: []string{"/john{doe}"}}, + {path: "/john/{doe}", wantErr: nil, wantMatch: nil}, + {path: "/joh{doe}", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/foo", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/bar", wantErr: nil, wantMatch: nil}, + {path: "/avengers/*{args}", wantErr: ErrRouteConflict, wantMatch: []string{"/avengers/{id}/bar", "/avengers/{id}/foo"}}, }, }, { - name: "INCOMPLETE_MATCH_TO_END_OF_EDGE add new static node right after param", + name: "incomplete match to middle of edge", routes: []struct { - path string - wildcard string - wantErr error - wantMatching []string + wantErr error + path string + wantMatch []string }{ - {path: "/test/{foo}", wildcard: "", wantErr: nil, wantMatching: nil}, - {path: "/test/{foo}/ba", wildcard: "", wantErr: nil, wantMatching: nil}, + {path: "/foo/{id}", wantErr: nil, wantMatch: nil}, + {path: "/foo/{abc}", wantErr: ErrRouteConflict, wantMatch: []string{"/foo/{id}"}}, + {path: "/foo{id}", wantErr: nil, wantMatch: nil}, + {path: "/foo/a{id}", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/bar", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}/baz", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{id}", wantErr: nil, wantMatch: nil}, + {path: "/avengers/{abc}", wantErr: ErrRouteConflict, wantMatch: []string{"/avengers/{id}", "/avengers/{id}/bar", "/avengers/{id}/baz"}}, }, }, } @@ -966,66 +1113,62 @@ func TestInsertParamsConflict(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, rte := range tc.routes { - err := tree.insert(http.MethodGet, rte.path, rte.wildcard, 0, emptyHandler) - if rte.wantErr != nil { - assert.ErrorIs(t, err, rte.wantErr) - if cErr, ok := err.(*RouteConflictError); ok { - assert.Equal(t, rte.wantMatching, cErr.Matched) - } + err := tree.Handle(http.MethodGet, rte.path, emptyHandler) + assert.ErrorIs(t, err, rte.wantErr) + if cErr, ok := err.(*RouteConflictError); ok { + assert.Equal(t, rte.wantMatch, cErr.Matched) } } }) } } -func TestSwapWildcardConflict(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) +func TestUpdateConflict(t *testing.T) { cases := []struct { - wantErr error - name string - path string - routes []struct { - path string - wildcard bool - } + name string + routes []string + update string + wantErr error wantMatch []string - wildcard string }{ { - name: "replace existing static node with wildcard", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/bar", wildcard: false}, - {path: "/foo/baz", wildcard: false}, - {path: "/foo/", wildcard: false}, - }, - path: "/foo/", - wildcard: "args", + name: "wildcard parameter route not registered", + routes: []string{"/foo/{bar}"}, + update: "/foo/{baz}", + wantErr: ErrRouteNotFound, + }, + { + name: "wildcard catch all route not registered", + routes: []string{"/foo/{bar}"}, + update: "/foo/*{baz}", + wantErr: ErrRouteNotFound, + }, + { + name: "route match but not a leaf", + routes: []string{"/foo/bar/baz"}, + update: "/foo/bar", + wantErr: ErrRouteNotFound, + }, + { + name: "wildcard have different name", + routes: []string{"/foo/bar", "/foo/*{args}"}, + update: "/foo/*{all}", wantErr: ErrRouteConflict, - wantMatch: []string{"/foo/bar", "/foo/baz"}, + wantMatch: []string{"/foo/*{args}"}, }, { - name: "replace existing wildcard node with static", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/", wildcard: true}, - }, - path: "/foo/", + name: "replacing non wildcard by wildcard", + routes: []string{"/foo/bar", "/foo/"}, + update: "/foo/*{all}", + wantErr: ErrRouteConflict, + wantMatch: []string{"/foo/"}, }, { - name: "replace existing wildcard node with another wildcard", - routes: []struct { - path string - wildcard bool - }{ - {path: "/foo/", wildcard: true}, - }, - path: "/foo/", - wildcard: "new", + name: "replacing wildcard by non wildcard", + routes: []string{"/foo/bar", "/foo/*{args}"}, + update: "/foo/", + wantErr: ErrRouteConflict, + wantMatch: []string{"/foo/*{args}"}, }, } @@ -1033,13 +1176,9 @@ func TestSwapWildcardConflict(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, rte := range tc.routes { - var catchAllKey string - if rte.wildcard { - catchAllKey = "args" - } - require.NoError(t, tree.insert(http.MethodGet, rte.path, catchAllKey, 0, h)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) } - err := tree.update(http.MethodGet, tc.path, tc.wildcard, h) + err := tree.Update(http.MethodGet, tc.update, emptyHandler) assert.ErrorIs(t, err, tc.wantErr) if cErr, ok := err.(*RouteConflictError); ok { assert.Equal(t, tc.wantMatch, cErr.Matched) @@ -1049,63 +1188,55 @@ func TestSwapWildcardConflict(t *testing.T) { } func TestUpdateRoute(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }) - cases := []struct { - newHandler Handler - name string - path string - newPath string - newWildcardKey string + name string + routes []string + update string }{ { - name: "update wildcard with another wildcard", - path: "/foo/bar/*{args}", - newPath: "/foo/bar/", - newWildcardKey: "*{new}", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(params.Get(RouteKey))) - }), + name: "replacing ending static node", + routes: []string{"/foo/", "/foo/bar", "/foo/baz"}, + update: "/foo/bar", }, { - name: "update wildcard with non wildcard", - path: "/foo/bar/*{args}", - newPath: "/foo/bar/", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }), + name: "replacing middle static node", + routes: []string{"/foo/", "/foo/bar", "/foo/baz"}, + update: "/foo/", }, { - name: "update non wildcard with wildcard", - path: "/foo/bar/", - newPath: "/foo/bar/", - newWildcardKey: "*{foo}", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(params.Get(RouteKey))) - }), + name: "replacing ending wildcard node", + routes: []string{"/foo/", "/foo/bar", "/foo/{baz}"}, + update: "/foo/{baz}", }, { - name: "update non wildcard with non wildcard", - path: "/foo/bar", - newPath: "/foo/bar", - newHandler: HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - w.Write([]byte(r.URL.Path)) - }), + name: "replacing ending inflight wildcard node", + routes: []string{"/foo/", "/foo/bar_xyz", "/foo/bar_{baz}"}, + update: "/foo/bar_{baz}", + }, + { + name: "replacing middle wildcard node", + routes: []string{"/foo/{bar}", "/foo/{bar}/baz", "/foo/{bar}/xyz"}, + update: "/foo/{bar}", + }, + { + name: "replacing middle inflight wildcard node", + routes: []string{"/foo/id:{bar}", "/foo/id:{bar}/baz", "/foo/id:{bar}/xyz"}, + update: "/foo/id:{bar}", + }, + { + name: "replacing catch all node", + routes: []string{"/foo/*{bar}", "/foo", "/foo/bar"}, + update: "/foo/*{bar}", }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - r := New(WithSaveMatchedRoute(true)) - require.NoError(t, r.Tree().Handler(http.MethodGet, tc.path, h)) - require.NoError(t, r.Tree().Update(http.MethodGet, tc.newPath+tc.newWildcardKey, tc.newHandler)) - req := httptest.NewRequest(http.MethodGet, tc.newPath, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - require.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, tc.newPath+tc.newWildcardKey, w.Body.String()) + tree := New().Tree() + for _, rte := range tc.routes { + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + } + assert.NoError(t, tree.Update(http.MethodGet, tc.update, emptyHandler)) }) } } @@ -1250,8 +1381,6 @@ func TestParseRoute(t *testing.T) { } func TestTree_LookupTsr(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) - cases := []struct { name string path string @@ -1295,16 +1424,16 @@ func TestTree_LookupTsr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() - require.NoError(t, tree.insert(http.MethodGet, tc.path, "", 0, h)) - nds := tree.load() - _, _, got := tree.lookup(nds[0], tc.key, true) + require.NoError(t, tree.insert(http.MethodGet, tc.path, "", 0, emptyHandler)) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + _, got := tree.lookup(nds[0], tc.key, c.params, c.skipNds, true) assert.Equal(t, tc.want, got) }) } } func TestRedirectTrailingSlash(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string @@ -1374,7 +1503,7 @@ func TestRedirectTrailingSlash(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := New(WithRedirectTrailingSlash(true)) - require.NoError(t, r.Tree().Handler(tc.method, tc.path, h)) + require.NoError(t, r.Tree().Handle(tc.method, tc.path, emptyHandler)) req := httptest.NewRequest(tc.method, tc.key, nil) w := httptest.NewRecorder() @@ -1389,7 +1518,6 @@ func TestRedirectTrailingSlash(t *testing.T) { } func TestRedirectFixedPath(t *testing.T) { - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string path string @@ -1427,7 +1555,7 @@ func TestRedirectFixedPath(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := New(WithRedirectFixedPath(true), WithRedirectTrailingSlash(tc.tsr)) - require.NoError(t, r.Tree().Handler(http.MethodGet, tc.path, h)) + require.NoError(t, r.Tree().Handle(http.MethodGet, tc.path, emptyHandler)) req, _ := http.NewRequest(http.MethodGet, tc.key, nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -1441,12 +1569,11 @@ func TestRedirectFixedPath(t *testing.T) { func TestTree_Remove(t *testing.T) { tree := New().Tree() - routes := make([]route, len(githubAPI)) copy(routes, githubAPI) for _, rte := range routes { - require.NoError(t, tree.Handler(rte.method, rte.path, emptyHandler)) + require.NoError(t, tree.Handle(rte.method, rte.path, emptyHandler)) } rand.Shuffle(len(routes), func(i, j int) { routes[i], routes[j] = routes[j], routes[i] }) @@ -1456,17 +1583,24 @@ func TestTree_Remove(t *testing.T) { } cnt := 0 - _ = Walk(tree, func(method, path string, handler Handler) error { + _ = Walk(tree, func(method, path string, handler HandlerFunc) error { cnt++ return nil }) assert.Equal(t, 0, cnt) + assert.Equal(t, 4, len(*tree.nodes.Load())) +} + +func TestTree_RemoveRoot(t *testing.T) { + tree := New().Tree() + require.NoError(t, tree.Handle(http.MethodOptions, "/foo/bar", emptyHandler)) + require.NoError(t, tree.Remove(http.MethodOptions, "/foo/bar")) + assert.Equal(t, 4, len(*tree.nodes.Load())) } func TestRouterWithAllowedMethod(t *testing.T) { r := New(WithHandleMethodNotAllowed(true)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) cases := []struct { name string @@ -1501,7 +1635,7 @@ func TestRouterWithAllowedMethod(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { - require.NoError(t, r.Tree().Handler(method, tc.path, h)) + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) } req := httptest.NewRequest(tc.target, tc.path, nil) w := httptest.NewRecorder() @@ -1512,19 +1646,21 @@ func TestRouterWithAllowedMethod(t *testing.T) { } } -func TestPanicHandler(t *testing.T) { - r := New(WithPanicHandler(func(w http.ResponseWriter, r *http.Request, i interface{}) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(i.(string))) - })) +func TestRecoveryMiddleware(t *testing.T) { + m := Recovery(func(c Context, err any) { + c.Writer().WriteHeader(http.StatusInternalServerError) + _, _ = c.Writer().Write([]byte(err.(string))) + }) + + r := New(WithMiddleware(m)) const errMsg = "unexpected error" - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { + h := func(c Context) error { func() { panic(errMsg) }() - w.Write([]byte("foo")) - }) + return c.String(200, "foo") + } - require.NoError(t, r.Tree().Handler(http.MethodPost, "/", h)) + require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) req := httptest.NewRequest(http.MethodPost, "/", nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -1541,7 +1677,7 @@ func TestHas(t *testing.T) { r := New() for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) } cases := []struct { @@ -1585,6 +1721,38 @@ func TestHas(t *testing.T) { } } +func TestErrorHandling(t *testing.T) { + r := New() + + req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) + w := httptest.NewRecorder() + + r.MustHandle(http.MethodGet, "/foo/bar", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, errors.New("oups")) + }) + + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equal(t, "oups\n", w.Body.String()) +} + +func TestCustomErrorHandling(t *testing.T) { + r := New(WithRouteError(func(c Context, err error) { + http.Error(c.Writer(), err.Error(), http.StatusInternalServerError) + })) + + req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) + w := httptest.NewRecorder() + + r.MustHandle(http.MethodGet, "/foo/bar", func(c Context) error { + return errors.New("something went wrong") + }) + + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "something went wrong\n", w.Body.String()) +} + func TestReverse(t *testing.T) { routes := []string{ "/foo/bar", @@ -1594,7 +1762,7 @@ func TestReverse(t *testing.T) { r := New() for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) } cases := []struct { @@ -1630,103 +1798,20 @@ func TestReverse(t *testing.T) { } } -func TestLookup(t *testing.T) { - routes := []string{ - "/foo/bar", - "/welcome/{name}", - "/users/uid_{id}", - "/john/doe/", - } - - r := New() - for _, rte := range routes { - require.NoError(t, r.Handler(http.MethodGet, rte, emptyHandler)) - } - - cases := []struct { - name string - path string - paramKey string - wantHandler bool - wantParamValue string - wantTsr bool - }{ - { - name: "matching static route", - path: "/foo/bar", - wantHandler: true, - }, - { - name: "tsr remove slash for static route", - path: "/foo/bar/", - wantTsr: true, - }, - { - name: "tsr add slash for static route", - path: "/john/doe", - wantTsr: true, - }, - { - name: "tsr for static route", - path: "/foo/bar/", - wantTsr: true, - }, - { - name: "matching params route", - path: "/welcome/fox", - wantHandler: true, - paramKey: "name", - wantParamValue: "fox", - }, - { - name: "tsr for params route", - path: "/welcome/fox/", - wantTsr: true, - }, - { - name: "matching mid route params", - path: "/users/uid_123", - wantHandler: true, - paramKey: "id", - wantParamValue: "123", - }, - { - name: "matching mid route params", - path: "/users/uid_123/", - wantTsr: true, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - handler, params, tsr := Lookup(r.Tree(), http.MethodGet, tc.path, false) - if params != nil { - defer params.Free(r.Tree()) - } - if tc.wantHandler { - assert.NotNil(t, handler) - } - assert.Equal(t, tc.wantTsr, tsr) - if tc.paramKey != "" { - require.NotNil(t, params) - assert.Equal(t, tc.wantParamValue, params.Get(tc.paramKey)) - } - }) - } -} - func TestAbortHandler(t *testing.T) { - r := New(WithPanicHandler(func(w http.ResponseWriter, r *http.Request, i interface{}) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(i.(error).Error())) - })) + m := Recovery(func(c Context, err any) { + c.Writer().WriteHeader(http.StatusInternalServerError) + _, _ = c.Writer().Write([]byte(err.(error).Error())) + }) + + r := New(WithMiddleware(m)) - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) { + h := func(c Context) error { func() { panic(http.ErrAbortHandler) }() - w.Write([]byte("foo")) - }) + return c.String(200, "foo") + } - require.NoError(t, r.Tree().Handler(http.MethodPost, "/", h)) + require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) req := httptest.NewRequest(http.MethodPost, "/", nil) w := httptest.NewRecorder() @@ -1751,7 +1836,6 @@ func TestFuzzInsertLookupParam(t *testing.T) { } tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) f := fuzz.New().NilChance(0).Funcs(unicodeRanges.CustomStringFuzzFunc()) routeFormat := "/%s/{%s}/%s/{%s}/{%s}" reqFormat := "/%s/%s/%s/%s/%s" @@ -1765,15 +1849,16 @@ func TestFuzzInsertLookupParam(t *testing.T) { if s1 == "" || s2 == "" || e1 == "" || e2 == "" || e3 == "" { continue } - if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, h); err == nil { - nds := tree.load() + if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, emptyHandler); err == nil { + nds := *tree.nodes.Load() - n, params, _ := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), false) + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c.params, c.skipNds, false) require.NotNil(t, n) assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path) - assert.Equal(t, "xxxx", params.Get(e1)) - assert.Equal(t, "xxxx", params.Get(e2)) - assert.Equal(t, "xxxx", params.Get(e3)) + assert.Equal(t, "xxxx", c.Param(e1)) + assert.Equal(t, "xxxx", c.Param(e2)) + assert.Equal(t, "xxxx", c.Param(e3)) } } } @@ -1781,7 +1866,6 @@ func TestFuzzInsertLookupParam(t *testing.T) { func TestFuzzInsertNoPanics(t *testing.T) { f := fuzz.New().NilChance(0).NumElements(5000, 10000) tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) routes := make(map[string]struct{}) f.Fuzz(&routes) @@ -1789,11 +1873,11 @@ func TestFuzzInsertNoPanics(t *testing.T) { for rte := range routes { var catchAllKey string f.Fuzz(&catchAllKey) - if rte == "" && catchAllKey == "" { + if rte == "" { continue } require.NotPanicsf(t, func() { - _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, h) + _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, emptyHandler) }, fmt.Sprintf("rte: %s, catch all: %s", rte, catchAllKey)) } } @@ -1809,30 +1893,30 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { f := fuzz.New().NilChance(0).NumElements(1000, 2000).Funcs(unicodeRanges.CustomStringFuzzFunc()) tree := New().Tree() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, _ Params) {}) routes := make(map[string]struct{}) f.Fuzz(&routes) for rte := range routes { - err := tree.insert(http.MethodGet, "/"+rte, "", 0, h) + err := tree.insert(http.MethodGet, "/"+rte, "", 0, emptyHandler) require.NoError(t, err) } countPath := 0 - require.NoError(t, Walk(tree, func(method, path string, handler Handler) error { + require.NoError(t, Walk(tree, func(method, path string, handler HandlerFunc) error { countPath++ return nil })) assert.Equal(t, len(routes), countPath) for rte := range routes { - nds := tree.load() - n, _, _ := tree.lookup(nds[0], "/"+rte, true) + nds := *tree.nodes.Load() + c := newTestContextTree(tree) + n, _ := tree.lookup(nds[0], "/"+rte, c.params, c.skipNds, true) require.NotNilf(t, n, "route /%s", rte) require.Truef(t, n.isLeaf(), "route /%s", rte) require.Equal(t, "/"+rte, n.path) - require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", h)) + require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler)) } for rte := range routes { @@ -1841,7 +1925,7 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { } countPath = 0 - require.NoError(t, Walk(tree, func(method, path string, handler Handler) error { + require.NoError(t, Walk(tree, func(method, path string, handler HandlerFunc) error { countPath++ return nil })) @@ -1852,8 +1936,8 @@ func TestDataRace(t *testing.T) { var wg sync.WaitGroup start, wait := atomicSync() - h := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) - newH := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) {}) + h := HandlerFunc(func(c Context) error { return nil }) + newH := HandlerFunc(func(c Context) error { return nil }) r := New() @@ -1863,15 +1947,29 @@ func TestDataRace(t *testing.T) { for _, rte := range githubAPI { go func(method, route string) { wait() - assert.NoError(t, r.Handler(method, route, h)) - // assert.NoError(t, r.Handler("PING", route, h)) - wg.Done() + defer wg.Done() + tree := r.Tree() + tree.Lock() + defer tree.Unlock() + if Has(tree, method, route) { + assert.NoError(t, tree.Update(method, route, h)) + return + } + assert.NoError(t, tree.Handle(method, route, h)) + // assert.NoError(t, r.Handle("PING", route, h)) }(rte.method, rte.path) go func(method, route string) { wait() - r.Update(method, route, newH) - wg.Done() + defer wg.Done() + tree := r.Tree() + tree.Lock() + defer tree.Unlock() + if Has(tree, method, route) { + assert.NoError(t, tree.Remove(method, route)) + return + } + assert.NoError(t, tree.Handle(method, route, newH)) }(rte.method, rte.path) go func(method, route string) { @@ -1888,32 +1986,32 @@ func TestDataRace(t *testing.T) { } func TestConcurrentRequestHandling(t *testing.T) { - r := New(WithSaveMatchedRoute(true)) + r := New() // /repos/{owner}/{repo}/keys - h1 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "john", params.Get("owner")) - assert.Equal(t, "fox", params.Get("repo")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h1 := HandlerFunc(func(c Context) error { + assert.Equal(t, "john", c.Param("owner")) + assert.Equal(t, "fox", c.Param("repo")) + return c.String(200, c.Path()) }) // /repos/{owner}/{repo}/contents/*{path} - h2 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "alex", params.Get("owner")) - assert.Equal(t, "vault", params.Get("repo")) - assert.Equal(t, "file.txt", params.Get("path")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h2 := HandlerFunc(func(c Context) error { + assert.Equal(t, "alex", c.Param("owner")) + assert.Equal(t, "vault", c.Param("repo")) + assert.Equal(t, "file.txt", c.Param("path")) + return c.String(200, c.Path()) }) // /users/{user}/received_events/public - h3 := HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - assert.Equal(t, "go", params.Get("user")) - _, _ = fmt.Fprint(w, params.Get(RouteKey)) + h3 := HandlerFunc(func(c Context) error { + assert.Equal(t, "go", c.Param("user")) + return c.String(200, c.Path()) }) - require.NoError(t, r.Handler(http.MethodGet, "/repos/{owner}/{repo}/keys", h1)) - require.NoError(t, r.Handler(http.MethodGet, "/repos/{owner}/{repo}/contents/*{path}", h2)) - require.NoError(t, r.Handler(http.MethodGet, "/users/{user}/received_events/public", h3)) + require.NoError(t, r.Handle(http.MethodGet, "/repos/{owner}/{repo}/keys", h1)) + require.NoError(t, r.Handle(http.MethodGet, "/repos/{owner}/{repo}/contents/*{path}", h2)) + require.NoError(t, r.Handle(http.MethodGet, "/users/{user}/received_events/public", h3)) r1 := httptest.NewRequest(http.MethodGet, "/repos/john/fox/keys", nil) r2 := httptest.NewRequest(http.MethodGet, "/repos/alex/vault/contents/file.txt", nil) @@ -1968,51 +2066,56 @@ func atomicSync() (start func(), wait func()) { return } -// When WithSaveMatchedRoute is enabled, the route matching the current request will be available in parameters. +// This example demonstrates how to create a simple router using the default options, +// which include the Recovery middleware. A basic route is defined, along with a +// custom middleware to log the request metrics. func ExampleNew() { - r := New(WithSaveMatchedRoute(true)) - metrics := func(next HandlerFunc) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { + // Create a new router with default options, which include the Recovery middleware + r := New(DefaultOptions()) + + // Define a custom middleware to measure the time taken for request processing and + // log the URL, route, time elapsed, and status code + metrics := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { start := time.Now() - next.ServeHTTP(w, r, params) - log.Printf("url=%s; route=%s; time=%d", r.URL, params.Get(RouteKey), time.Since(start)) - }) + err := next(c) + log.Printf("url=%s; route=%s; time=%d; status=%d", c.Request().URL, c.Path(), time.Since(start), c.Writer().Status()) + return err + } } - _ = r.Handler(http.MethodGet, "/hello/{name}", metrics(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintf(w, "Hello %s\n", params.Get("name")) + // Define a route with the path "/hello/{name}", apply the custom "metrics" middleware, + // and set a simple handler that greets the user by their name + r.MustHandle(http.MethodGet, "/hello/{name}", metrics(func(c Context) error { + return c.String(200, "Hello %s\n", c.Param("name")) })) + + // Start the HTTP server using the router as the handler and listen on port 8080 + log.Fatalln(http.ListenAndServe(":8080", r)) } -// This example demonstrates some important considerations when using the Lookup function. -func ExampleLookup() { - r := New() - _ = r.Handler(http.MethodGet, "/hello/{name}", HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintf(w, "Hello, %s\n", params.Get("name")) - })) +// This example demonstrates how to register a global middleware that will be +// applied to all routes. - req := httptest.NewRequest(http.MethodGet, "/hello/fox", nil) +func ExampleWithMiddleware() { - // Each tree as its own sync.Pool that is used to reuse Params slice. Since the router tree may be swapped at - // any given time, it's recommended to copy the pointer locally so when the params is released, - // it returns to the correct pool. - tree := r.Tree() - handler, params, _ := Lookup(tree, http.MethodGet, req.URL.Path, false) - // If not nit, Params should be freed to reduce memory allocation. - if params != nil { - defer params.Free(tree) + // Define a custom middleware to measure the time taken for request processing and + // log the URL, route, time elapsed, and status code + metrics := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + start := time.Now() + err := next(c) + log.Printf("url=%s; route=%s; time=%d; status=%d", c.Request().URL, c.Path(), time.Since(start), c.Writer().Status()) + return err + } } - // Bad, instead make a local copy of the tree! - handler, params, _ = Lookup(r.Tree(), http.MethodGet, req.URL.Path, false) - if params != nil { - defer params.Free(r.Tree()) - } + r := New(WithMiddleware(metrics)) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req, nil) - fmt.Print(w.Body.String()) + r.MustHandle(http.MethodGet, "/hello/{name}", func(c Context) error { + return c.String(200, "Hello %s\n", c.Param("name")) + }) } // This example demonstrates some important considerations when using the Tree API. @@ -2023,26 +2126,29 @@ func ExampleRouter_Tree() { // any given time, you MUST always copy the pointer locally, This ensures that you do not inadvertently cause a // deadlock by locking/unlocking the wrong tree. tree := r.Tree() - upsert := func(method, path string, handler Handler) error { + upsert := func(method, path string, handler HandlerFunc) error { tree.Lock() defer tree.Unlock() if Has(tree, method, path) { return tree.Update(method, path, handler) } - return tree.Handler(method, path, handler) + return tree.Handle(method, path, handler) } - _ = upsert(http.MethodGet, "/foo/bar", HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - _, _ = fmt.Fprintln(w, "foo bar") - })) + _ = upsert(http.MethodGet, "/foo/bar", func(c Context) error { + // Note the tree accessible from fox.Context is already a local copy so the golden rule above does not apply. + c.Tree().Lock() + defer c.Tree().Unlock() + return c.String(200, "foo bar") + }) // Bad, instead make a local copy of the tree! - upsert = func(method, path string, handler Handler) error { + upsert = func(method, path string, handler HandlerFunc) error { r.Tree().Lock() defer r.Tree().Unlock() if Has(r.Tree(), method, path) { return r.Tree().Update(method, path, handler) } - return r.Tree().Handler(method, path, handler) + return r.Tree().Handle(method, path, handler) } } diff --git a/go.mod b/go.mod index 2fd25e8..5e7e85a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/google/gofuzz v1.2.0 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.8.2 ) require ( diff --git a/go.sum b/go.sum index a26abda..66c3718 100644 --- a/go.sum +++ b/go.sum @@ -20,9 +20,11 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/http_consts.go b/http_consts.go new file mode 100644 index 0000000..5ba68bb --- /dev/null +++ b/http_consts.go @@ -0,0 +1,87 @@ +// Copyright (c) 2021 LabStack, see https://github.com/labstack/echo/blob/master/LICENSE. +// Portions of this code were derived from the Echo project (https://github.com/labstack/echo) +// under the MIT License. + +package fox + +// MIME types +const ( + charsetUTF8 = "charset=UTF-8" + MIMEApplicationJSON = "application/json" + MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 + MIMEApplicationJavaScript = "application/javascript" + MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 + MIMEApplicationXML = "application/xml" + MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8 + MIMETextXML = "text/xml" + MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8 + MIMEApplicationForm = "application/x-www-form-urlencoded" + MIMEApplicationProtobuf = "application/protobuf" + MIMEApplicationMsgpack = "application/msgpack" + MIMETextHTML = "text/html" + MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8 + MIMETextPlain = "text/plain" + MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8 + MIMEMultipartForm = "multipart/form-data" + MIMEOctetStream = "application/octet-stream" +) + +// Headers +const ( + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + // HeaderAllow is the name of the "Allow" header field used to list the set of methods + // advertised as supported by the target resource. Returning an Allow header is mandatory + // for status 405 (method not found) and useful for the OPTIONS method in responses. + // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + HeaderAllow = "Allow" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderCookie = "Cookie" + HeaderSetCookie = "Set-Cookie" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXUrlScheme = "X-Url-Scheme" + HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" + HeaderXRequestedWith = "X-Requested-With" + HeaderServer = "Server" + HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" + + // Access control + HeaderAccessControlRequestMethod = "Access-Control-Request-Method" + HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" + HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" + HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" + HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" + HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" + HeaderAccessControlMaxAge = "Access-Control-Max-Age" + + // Security + HeaderStrictTransportSecurity = "Strict-Transport-Security" + HeaderXContentTypeOptions = "X-Content-Type-Options" + HeaderXXSSProtection = "X-XSS-Protection" + HeaderXFrameOptions = "X-Frame-Options" + HeaderContentSecurityPolicy = "Content-Security-Policy" + HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" + // nolint:gosec + HeaderXCSRFToken = "X-CSRF-Token" + HeaderReferrerPolicy = "Referrer-Policy" +) diff --git a/iter.go b/iter.go index 00e3819..9344deb 100644 --- a/iter.go +++ b/iter.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -24,7 +28,7 @@ func NewIterator(t *Tree) *Iterator { } func (it *Iterator) methods() map[string]*node { - nds := it.tree.load() + nds := *it.tree.nodes.Load() m := make(map[string]*node, len(nds)) for i := range nds { if len(nds[i].children) > 0 { @@ -174,13 +178,18 @@ func (it *Iterator) Method() string { } // Handler return the registered handler for the current route. -func (it *Iterator) Handler() Handler { +func (it *Iterator) Handler() HandlerFunc { if it.current != nil { return it.current.handler } return nil } +type stack struct { + method string + edges []*node +} + func newRawIterator(n *node) *rawIterator { return &rawIterator{ stack: []stack{{edges: []*node{n}}}, @@ -193,19 +202,6 @@ type rawIterator struct { stack []stack } -type stack struct { - method string - edges []*node -} - -func (it *rawIterator) fullPath() string { - return it.path -} - -func (it *rawIterator) node() *node { - return it.current -} - func (it *rawIterator) hasNext() bool { for len(it.stack) > 0 { n := len(it.stack) diff --git a/iter_test.go b/iter_test.go index 67479b5..8e79e59 100644 --- a/iter_test.go +++ b/iter_test.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -13,9 +17,9 @@ var routesCases = []string{"/fox/router", "/foo/bar/{baz}", "/foo/bar/{baz}/{nam func TestIterator_Rewind(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } results := make(map[string][]string) @@ -34,9 +38,9 @@ func TestIterator_Rewind(t *testing.T) { func TestIterator_SeekMethod(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } results := make(map[string][]string) @@ -54,9 +58,9 @@ func TestIterator_SeekMethod(t *testing.T) { func TestIterator_SeekPrefix(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } want := []string{"/foo/bar/{baz}", "/foo/bar/{baz}/{name}"} @@ -76,9 +80,9 @@ func TestIterator_SeekPrefix(t *testing.T) { func TestIterator_SeekMethodPrefix(t *testing.T) { tree := New().Tree() for _, rte := range routesCases { - require.NoError(t, tree.Handler(http.MethodGet, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodPost, rte, emptyHandler)) - require.NoError(t, tree.Handler(http.MethodHead, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodGet, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodPost, rte, emptyHandler)) + require.NoError(t, tree.Handle(http.MethodHead, rte, emptyHandler)) } want := []string{"/foo/bar/{baz}", "/foo/bar/{baz}/{name}"} diff --git a/node.go b/node.go index 2e29ff9..01ebe06 100644 --- a/node.go +++ b/node.go @@ -1,7 +1,12 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( "sort" + "strconv" "strings" "sync/atomic" ) @@ -9,7 +14,7 @@ import ( type node struct { // The registered handler matching the full path. Nil if the node is not a leaf. // Once assigned, handler is immutable. - handler Handler + handler HandlerFunc // key represent a segment of a route which share a common prefix with it parent. key string @@ -31,43 +36,40 @@ type node struct { // each pointer reference to a new child node starting with the same character. children []atomic.Pointer[node] - // Indicate whether its child node is a param node type. If true, len(children) == 1. - // Once assigned, paramChild is immutable. - paramChild bool + // The index of a paramChild if any, -1 if none (per rules, only one paramChildren is allowed). + paramChildIndex int } -func newNode(key string, handler Handler, children []*node, catchAllKey string, paramChild bool, path string) *node { +func newNode(key string, handler HandlerFunc, children []*node, catchAllKey string, path string) *node { sort.Slice(children, func(i, j int) bool { return children[i].key < children[j].key }) nds := make([]atomic.Pointer[node], len(children)) childKeys := make([]byte, len(children)) + childIndex := -1 for i := range children { assertNotNil(children[i]) childKeys[i] = children[i].key[0] nds[i].Store(children[i]) + if strings.HasPrefix(children[i].key, "{") { + childIndex = i + } } - return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, paramChild, path) + return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, childIndex, path) } -func newNodeFromRef(key string, handler Handler, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, paramChild bool, path string) *node { +func newNodeFromRef(key string, handler HandlerFunc, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node { n := &node{ - key: key, - childKeys: childKeys, - children: children, - handler: handler, - catchAllKey: catchAllKey, - path: path, - paramChild: paramChild, - } - // TODO find a better way - if catchAllKey != "" { - suffix := "*{" + catchAllKey + "}" - if !strings.HasSuffix(path, suffix) { - n.path += suffix - } + key: key, + childKeys: childKeys, + children: children, + handler: handler, + catchAllKey: catchAllKey, + path: appendCatchAll(path, catchAllKey), + paramChildIndex: childIndex, } + return n } @@ -178,8 +180,11 @@ func (n *node) string(space int) string { sb.WriteString(strings.Repeat(" ", space)) sb.WriteString("path: ") sb.WriteString(n.key) - if n.paramChild { - sb.WriteString(" [paramChild]") + + if n.paramChildIndex >= 0 { + sb.WriteString(" [paramIdx=") + sb.WriteString(strconv.Itoa(n.paramChildIndex)) + sb.WriteString("]") } if n.isCatchAll() { @@ -195,7 +200,30 @@ func (n *node) string(space int) string { children := n.getEdgesShallowCopy() for _, child := range children { sb.WriteString(" ") - sb.WriteString(child.string(space + 2)) + sb.WriteString(child.string(space + 4)) } return sb.String() } + +type skippedNodes []skippedNode + +func (n *skippedNodes) pop() skippedNode { + skipped := (*n)[len(*n)-1] + *n = (*n)[:len(*n)-1] + return skipped +} + +type skippedNode struct { + node *node + pathIndex int +} + +func appendCatchAll(path, catchAllKey string) string { + if catchAllKey != "" { + suffix := "*{" + catchAllKey + "}" + if !strings.HasSuffix(path, suffix) { + return path + suffix + } + } + return path +} diff --git a/options.go b/options.go index 2fd2f2e..438ea08 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,8 @@ -package fox +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. -import "net/http" +package fox type Option interface { apply(*Router) @@ -12,37 +14,46 @@ func (o optionFunc) apply(r *Router) { o(r) } -// WithNotFoundHandler register a http.Handler which is called when no matching route is found. -// By default, http.NotFound is used. -func WithNotFoundHandler(handler http.Handler) Option { +// WithRouteNotFound register an HandlerFunc which is called when no matching route is found. +// By default, the NotFoundHandler is used. +func WithRouteNotFound(handler HandlerFunc, m ...MiddlewareFunc) Option { return optionFunc(func(r *Router) { if handler != nil { - r.notFound = handler + r.noRoute = applyMiddleware(m, handler) } }) } -// WithNotAllowedHandler register a http.Handler which is called when the request cannot be routed, +// WithMethodNotAllowed register an HandlerFunc which is called when the request cannot be routed, // but the same route exist for other methods. The "Allow" header it automatically set -// before calling the handler. Mount WithHandleMethodNotAllowed to enable this option. By default, -// http.Error with http.StatusMethodNotAllowed is used. -func WithNotAllowedHandler(handler http.Handler) Option { +// before calling the handler. Set WithHandleMethodNotAllowed to enable this option. By default, +// the MethodNotAllowedHandler is used. +func WithMethodNotAllowed(handler HandlerFunc, m ...MiddlewareFunc) Option { return optionFunc(func(r *Router) { if handler != nil { - r.methodNotAllowed = handler + r.noMethod = applyMiddleware(m, handler) } }) } -// WithPanicHandler register a function to handle panics recovered from http handlers. -func WithPanicHandler(fn func(http.ResponseWriter, *http.Request, interface{})) Option { +// WithRouteError register an ErrorHandlerFunc which is called when an HandlerFunc returns an error. +// By default, the RouteErrorHandler is used. +func WithRouteError(handler ErrorHandlerFunc) Option { return optionFunc(func(r *Router) { - if fn != nil { - r.panicHandler = fn + if handler != nil { + r.errRoute = handler } }) } +// WithMiddleware attaches a global middleware to the router. Middlewares provided will be chained +// in the order they were added. Note that it does NOT apply the middlewares to the NotFound and MethodNotAllowed handlers. +func WithMiddleware(m ...MiddlewareFunc) Option { + return optionFunc(func(r *Router) { + r.mws = append(r.mws, m...) + }) +} + // WithHandleMethodNotAllowed enable to returns 405 Method Not Allowed instead of 404 Not Found // when the route exist for another http verb. func WithHandleMethodNotAllowed(enable bool) Option { @@ -71,10 +82,9 @@ func WithRedirectTrailingSlash(enable bool) Option { }) } -// WithSaveMatchedRoute configure the router to make the matched route accessible as a Handler parameter. -// Usage: p.Get(fox.RouteKey) -func WithSaveMatchedRoute(enable bool) Option { +// DefaultOptions configure the router to use the Recovery middleware. +func DefaultOptions() Option { return optionFunc(func(r *Router) { - r.saveMatchedRoute = enable + r.mws = append(r.mws, Recovery(defaultHandleRecovery)) }) } diff --git a/params.go b/params.go index a7dacb9..95830f5 100644 --- a/params.go +++ b/params.go @@ -1,15 +1,8 @@ -package fox - -import ( - "context" - "net/http" -) - -const RouteKey = "$k/fox" - -var ParamsKey = key{} +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. -type key struct{} +package fox type Param struct { Key string @@ -19,57 +12,18 @@ type Param struct { type Params []Param // Get the matching wildcard segment by name. -func (p *Params) Get(name string) string { - for i := range *p { - if (*p)[i].Key == name { - return (*p)[i].Value +func (p Params) Get(name string) string { + for i := range p { + if p[i].Key == name { + return p[i].Value } } return "" } // Clone make a copy of Params. -func (p *Params) Clone() Params { - cloned := make(Params, len(*p)) - copy(cloned, *p) +func (p Params) Clone() Params { + cloned := make(Params, len(p)) + copy(cloned, p) return cloned } - -// Free release the params to be reused later. -func (p *Params) Free(t *Tree) { - if cap(*p) < int(t.maxParams.Load()) { - return - } - *p = (*p)[:0] - t.p.Put(p) -} - -// ParamsFromContext is a helper function to retrieve parameters from the request context. -func ParamsFromContext(ctx context.Context) Params { - p, _ := ctx.Value(ParamsKey).(Params) - return p -} - -// WrapF is a helper function for wrapping http.HandlerFunc and returns a Fox Handler. -// Params are forwarded via the request context. See ParamsFromContext to retrieve parameters. -func WrapF(f http.HandlerFunc) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - if len(params) > 0 { - f.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ParamsKey, params))) - return - } - f.ServeHTTP(w, r) - }) -} - -// WrapH is a helper function for wrapping http.Handler and returns a Fox Handler. -// Params are forwarded via the request context. See ParamsFromContext to retrieve parameters. -func WrapH(h http.Handler) Handler { - return HandlerFunc(func(w http.ResponseWriter, r *http.Request, params Params) { - if len(params) > 0 { - h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ParamsKey, params))) - return - } - h.ServeHTTP(w, r) - }) -} diff --git a/params_test.go b/params_test.go index e79d3e9..06aa842 100644 --- a/params_test.go +++ b/params_test.go @@ -1,91 +1,33 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" "testing" ) -func TestWrapHandler(t *testing.T) { - tree := New().Tree() - - cases := []struct { - name string - h Handler - params Params - }{ - { - name: "wrapf with params", - h: WrapF(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Equal(t, "bar", params.Get("foo")) - assert.Equal(t, "doe", params.Get("john")) - }), - params: Params{ - Param{ - Key: "foo", - Value: "bar", - }, - Param{ - Key: "john", - Value: "doe", - }, - }, - }, - { - name: "wraph with params", - h: WrapH(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Equal(t, "bar", params.Get("foo")) - assert.Equal(t, "doe", params.Get("john")) - })), - params: Params{ - Param{ - Key: "foo", - Value: "bar", - }, - Param{ - Key: "john", - Value: "doe", - }, - }, - }, - { - name: "wrapf no params", - h: WrapF(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Nil(t, params) - }), - params: nil, +func TestParams_Get(t *testing.T) { + params := make(Params, 0, 2) + params = append(params, + Param{ + Key: "foo", + Value: "bar", }, - { - name: "wraph no params", - h: WrapH(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := ParamsFromContext(r.Context()) - assert.Nil(t, params) - })), - params: nil, + Param{ + Key: "john", + Value: "doe", }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - params := tree.newParams() - defer params.Free(tree) - *params = append(*params, tc.params...) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - tc.h.ServeHTTP(nil, req, *params) - }) - } + ) + assert.Equal(t, "bar", params.Get("foo")) + assert.Equal(t, "doe", params.Get("john")) } -func TestParamsClone(t *testing.T) { - tree := New().Tree() - params := tree.newParams() - defer params.Free(tree) - *params = append(*params, +func TestParams_Clone(t *testing.T) { + params := make(Params, 0, 2) + params = append(params, Param{ Key: "foo", Value: "bar", @@ -95,5 +37,5 @@ func TestParamsClone(t *testing.T) { Value: "doe", }, ) - assert.Equal(t, *params, params.Clone()) + assert.Equal(t, params, params.Clone()) } diff --git a/path_test.go b/path_test.go new file mode 100644 index 0000000..748ceb9 --- /dev/null +++ b/path_test.go @@ -0,0 +1,149 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Based on the path package, Copyright 2009 The Go Authors. +// Use of this source code is governed by a BSD-style license that can be found +// in the LICENSE file. + +package fox + +import ( + "strings" + "testing" +) + +type cleanPathTest struct { + path, result string +} + +var cleanTests = []cleanPathTest{ + // Already clean + {"/", "/"}, + {"/abc", "/abc"}, + {"/a/b/c", "/a/b/c"}, + {"/abc/", "/abc/"}, + {"/a/b/c/", "/a/b/c/"}, + + // missing root + {"", "/"}, + {"a/", "/a/"}, + {"abc", "/abc"}, + {"abc/def", "/abc/def"}, + {"a/b/c", "/a/b/c"}, + + // Remove doubled slash + {"//", "/"}, + {"/abc//", "/abc/"}, + {"/abc/def//", "/abc/def/"}, + {"/a/b/c//", "/a/b/c/"}, + {"/abc//def//ghi", "/abc/def/ghi"}, + {"//abc", "/abc"}, + {"///abc", "/abc"}, + {"//abc//", "/abc/"}, + + // Remove . elements + {".", "/"}, + {"./", "/"}, + {"/abc/./def", "/abc/def"}, + {"/./abc/def", "/abc/def"}, + {"/abc/.", "/abc/"}, + + // Remove .. elements + {"..", "/"}, + {"../", "/"}, + {"../../", "/"}, + {"../..", "/"}, + {"../../abc", "/abc"}, + {"/abc/def/ghi/../jkl", "/abc/def/jkl"}, + {"/abc/def/../ghi/../jkl", "/abc/jkl"}, + {"/abc/def/..", "/abc"}, + {"/abc/def/../..", "/"}, + {"/abc/def/../../..", "/"}, + {"/abc/def/../../..", "/"}, + {"/abc/def/../../../ghi/jkl/../../../mno", "/mno"}, + + // Combinations + {"abc/./../def", "/def"}, + {"abc//./../def", "/def"}, + {"abc/../../././../def", "/def"}, +} + +func TestPathClean(t *testing.T) { + for _, test := range cleanTests { + if s := CleanPath(test.path); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) + } + if s := CleanPath(test.result); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) + } + } +} + +func TestPathCleanMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + + for _, test := range cleanTests { + test := test + allocs := testing.AllocsPerRun(100, func() { CleanPath(test.result) }) + if allocs > 0 { + t.Errorf("CleanPath(%q): %v allocs, want zero", test.result, allocs) + } + } +} + +func BenchmarkPathClean(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range cleanTests { + CleanPath(test.path) + } + } +} + +func genLongPaths() (testPaths []cleanPathTest) { + for i := 1; i <= 1234; i++ { + ss := strings.Repeat("a", i) + + correctPath := "/" + ss + testPaths = append(testPaths, cleanPathTest{ + path: correctPath, + result: correctPath, + }, cleanPathTest{ + path: ss, + result: correctPath, + }, cleanPathTest{ + path: "//" + ss, + result: correctPath, + }, cleanPathTest{ + path: "/" + ss + "/b/..", + result: correctPath, + }) + } + return testPaths +} + +func TestPathCleanLong(t *testing.T) { + cleanTests := genLongPaths() + + for _, test := range cleanTests { + if s := CleanPath(test.path); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.path, s, test.result) + } + if s := CleanPath(test.result); s != test.result { + t.Errorf("CleanPath(%q) = %q, want %q", test.result, s, test.result) + } + } +} + +func BenchmarkPathCleanLong(b *testing.B) { + cleanTests := genLongPaths() + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range cleanTests { + CleanPath(test.path) + } + } +} diff --git a/recovery.go b/recovery.go new file mode 100644 index 0000000..d535100 --- /dev/null +++ b/recovery.go @@ -0,0 +1,61 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + +package fox + +import ( + "errors" + "log" + "net" + "net/http" + "os" + "runtime/debug" + "strings" +) + +var stdErr = log.New(os.Stderr, "", log.LstdFlags) + +// RecoveryFunc is a function type that defines how to handle panics that occur during the +// handling of an HTTP request. +type RecoveryFunc func(c Context, err any) + +// Recovery is a middleware that captures panics and recovers from them. It takes a custom handle function +// that will be called with the Context and the value recovered from the panic. +// Note that the middleware check if the panic is caused by http.ErrAbortHandler and re-panic if true +// allowing the http server to handle it as an abort. +func Recovery(handle RecoveryFunc) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + defer recovery(c, handle) + return next(c) + } + } +} + +func recovery(c Context, handle RecoveryFunc) { + if err := recover(); err != nil { + if abortErr, ok := err.(error); ok && errors.Is(abortErr, http.ErrAbortHandler) { + panic(abortErr) + } + handle(c, err) + } +} + +func defaultHandleRecovery(c Context, err any) { + stdErr.Printf("[PANIC] %q panic recovered\n%s", err, debug.Stack()) + if !c.Writer().Written() && !connIsBroken(err) { + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } +} + +func connIsBroken(err any) bool { + if ne, ok := err.(*net.OpError); ok { + var se *os.SyscallError + if errors.As(ne, &se) { + seStr := strings.ToLower(se.Error()) + return strings.Contains(seStr, "broken pipe") || strings.Contains(seStr, "connection reset by peer") + } + } + return false +} diff --git a/response_writer.go b/response_writer.go new file mode 100644 index 0000000..98cce9d --- /dev/null +++ b/response_writer.go @@ -0,0 +1,192 @@ +// ResponseRecorder is influenced by the work done by goji and chi libraries, +// with additional optimizations to avoid unnecessary memory allocations. +// See their respective licenses for more information: +// https://github.com/zenazn/goji/blob/master/LICENSE +// https://github.com/go-chi/chi/blob/master/LICENSE + +package fox + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" +) + +var _ http.Flusher = (*h1Writer)(nil) +var _ http.Hijacker = (*h1Writer)(nil) +var _ io.ReaderFrom = (*h1Writer)(nil) +var _ http.Pusher = (*h2Writer)(nil) +var _ http.Flusher = (*h2Writer)(nil) + +// ResponseWriter extends http.ResponseWriter and provides +// methods to retrieve the recorded status code, written state, and response size. +type ResponseWriter interface { + http.ResponseWriter + // Status recorded after Write and WriteHeader. + Status() int + // Written returns true if the response has been written. + Written() bool + // Size returns the size of the written response. + Size() int + // Unwrap returns the underlying http.ResponseWriter. + Unwrap() http.ResponseWriter +} + +const notWritten = -1 + +type recorder struct { + http.ResponseWriter + size int + status int +} + +func (r *recorder) reset(w http.ResponseWriter) { + r.ResponseWriter = w + r.size = notWritten + r.status = http.StatusOK +} + +func (r *recorder) Status() int { + return r.status +} + +func (r *recorder) Written() bool { + return r.size != notWritten +} + +func (r *recorder) Size() int { + return r.size +} + +func (r *recorder) Unwrap() http.ResponseWriter { + return r.ResponseWriter +} + +func (r *recorder) WriteHeader(code int) { + if !r.Written() { + r.size = 0 + r.status = code + } + r.ResponseWriter.WriteHeader(code) +} + +func (r *recorder) Write(buf []byte) (n int, err error) { + if !r.Written() { + r.size = 0 + r.ResponseWriter.WriteHeader(r.status) + } + n, err = r.ResponseWriter.Write(buf) + r.size += n + return +} + +func (r *recorder) WriteString(s string) (n int, err error) { + if !r.Written() { + r.size = 0 + r.ResponseWriter.WriteHeader(r.status) + } + n, err = io.WriteString(r.ResponseWriter, s) + r.size += n + return +} + +//nolint:unused +type hijackWriter struct { + *recorder +} + +//nolint:unused +func (w hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +//nolint:unused +type flushHijackWriter struct { + *recorder +} + +//nolint:unused +func (w flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +//nolint:unused +func (w flushHijackWriter) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type flushWriter struct { + *recorder +} + +func (w flushWriter) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type h1Writer struct { + *recorder +} + +func (w h1Writer) ReadFrom(r io.Reader) (n int64, err error) { + rf := w.recorder.ResponseWriter.(io.ReaderFrom) + // If not written, status is OK + w.recorder.WriteHeader(w.recorder.status) + n, err = rf.ReadFrom(r) + w.recorder.size += int(n) + return +} + +func (w h1Writer) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !w.recorder.Written() { + w.recorder.size = 0 + } + return w.recorder.ResponseWriter.(http.Hijacker).Hijack() +} + +func (w h1Writer) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type h2Writer struct { + *recorder +} + +func (w h2Writer) Push(target string, opts *http.PushOptions) error { + return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (w h2Writer) Flush() { + if !w.recorder.Written() { + w.recorder.size = 0 + } + w.recorder.ResponseWriter.(http.Flusher).Flush() +} + +type noopWriter struct{} + +func (n noopWriter) Header() http.Header { + return make(http.Header) +} + +func (n noopWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("%w: writing on a clone", ErrDiscardedResponseWriter) +} + +func (n noopWriter) WriteHeader(int) {} diff --git a/test_helpers.go b/test_helpers.go new file mode 100644 index 0000000..00588d3 --- /dev/null +++ b/test_helpers.go @@ -0,0 +1,32 @@ +package fox + +import ( + "net/http" +) + +// NewTestContext returns a new Router and its associated Context, designed only for testing purpose. +func NewTestContext(w http.ResponseWriter, r *http.Request) (*Router, Context) { + fox := New() + c := NewTestContextOnly(fox, w, r) + return fox, c +} + +func NewTestContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) Context { + return newTextContextOnly(fox, w, r) +} + +func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *context { + c := fox.Tree().allocateContext() + c.resetNil() + c.rec.reset(w) + c.w = flushWriter{&c.rec} + c.fox = fox + c.req = r + return c +} + +func newTestContextTree(t *Tree) *context { + c := t.allocateContext() + c.resetNil() + return c +} diff --git a/tree.go b/tree.go index 1529f3f..de9a7c1 100644 --- a/tree.go +++ b/tree.go @@ -1,3 +1,7 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + package fox import ( @@ -11,10 +15,9 @@ import ( // The caller is responsible for ensuring that all writes are run serially. // // IMPORTANT: -// Each tree as its own sync.Mutex and sync.Pool that may be used to serialize write and reduce memory allocation. -// Since the router tree may be swapped at any given time, you MUST always copy the pointer locally -// to avoid inadvertently releasing Params to the wrong pool or worst, causing a deadlock by locking/unlocking the -// wrong Tree. +// Each tree as its own sync.Mutex that may be used to serialize write. Since the router tree may be swapped at any +// given time, you MUST always copy the pointer locally to avoid inadvertently causing a deadlock by locking/unlocking +// the wrong Tree. // // Good: // t := r.Tree() @@ -24,46 +27,39 @@ import ( // Dramatically bad, may cause deadlock // r.Tree().Lock() // defer r.Tree().Unlock() -// -// This principle also applies to the Lookup function, which requires releasing the Params slice, if not nil, by -// calling params.Free(tree). Always ensure that the Tree pointer passed as a parameter to params.Free is the same -// as the one passed to the Lookup function. type Tree struct { - p sync.Pool + ctx sync.Pool nodes atomic.Pointer[[]*node] + mws []MiddlewareFunc sync.Mutex maxParams atomic.Uint32 - saveRoute bool + maxDepth atomic.Uint32 } -// Handler registers a new handler for the given method and path. This function return an error if the route +// Handle registers a new handler for the given method and path. This function return an error if the route // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. However, this function is NOT thread safe and should be run serially, along with // all other Tree's APIs. To override an existing route, use Update. -func (t *Tree) Handler(method, path string, handler Handler) error { +func (t *Tree) Handle(method, path string, handler HandlerFunc) error { p, catchAllKey, n, err := parseRoute(path) if err != nil { return err } - if t.saveRoute { - n += 1 - } - - return t.insert(method, p, catchAllKey, uint32(n), handler) + return t.insert(method, p, catchAllKey, uint32(n), applyMiddleware(t.mws, handler)) } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. However, this function is NOT thread safe and should be run serially, along with -// all other Tree's APIs. To add new handler, use Handler method. -func (t *Tree) Update(method, path string, handler Handler) error { +// all other Tree's APIs. To add new handler, use Handle method. +func (t *Tree) Update(method, path string, handler HandlerFunc) error { p, catchAllKey, _, err := parseRoute(path) if err != nil { return err } - return t.update(method, p, catchAllKey, handler) + return t.update(method, p, catchAllKey, applyMiddleware(t.mws, handler)) } // Remove delete an existing handler for the given method and path. If the route does not exist, the function @@ -76,17 +72,18 @@ func (t *Tree) Remove(method, path string) error { } if !t.remove(method, path) { - return ErrRouteNotFound + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } return nil } -// insert is not safe for concurrent use. -func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler Handler) error { +// Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use +// parseRoute before. +func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. if method == "" { - return fmt.Errorf("http method is missing: %w", ErrInvalidRoute) + return fmt.Errorf("%w: http method is missing", ErrInvalidRoute) } var rootNode *node @@ -110,16 +107,19 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // └── am // Create a new node from "st" reference and update the "te" (parent) reference to "st" node. if result.matched.isLeaf() { - return fmt.Errorf("route [%s] %s conflict: %w", method, path, ErrRouteExist) + if result.matched.isCatchAll() && isCatchAll { + return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) + } + return fmt.Errorf("%w: new route [%s] %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) } // The matched node can only be the result of a previous split and therefore has children. - if isCatchAll { + if isCatchAll && result.matched.paramChildIndex >= 0 { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } // We are updating an existing node. We only need to create a new node from // the matched one with the updated/added value (handler and wildcard). - n := newNodeFromRef(result.matched.key, handler, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChild, path) + n := newNodeFromRef(result.matched.key, handler, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex, path) t.updateMaxParams(paramsN) result.p.updateEdge(n) @@ -140,21 +140,21 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // 3. Update the "te" (parent) reference to the new "s" node (we are swapping old "st" to new "s" node, first // char remain the same). - if isCatchAll { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - } - keyCharsFromStartOfNodeFound := path[result.charsMatched-result.charsMatchedInNodeFound:] cPrefix := commonPrefix(keyCharsFromStartOfNodeFound, result.matched.key) suffixFromExistingEdge := strings.TrimPrefix(result.matched.key, cPrefix) + if strings.HasPrefix(suffixFromExistingEdge, "{") && isCatchAll { + return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) + } + child := newNodeFromRef( suffixFromExistingEdge, result.matched.handler, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, - result.matched.paramChild, + result.matched.paramChildIndex, result.matched.path, ) @@ -163,18 +163,11 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler handler, []*node{child}, catchAllKey, - // e.g. tree encode /tes/{t} and insert /tes/ - // /tes/ (paramChild) - // ├── {t} - // since /tes/xyz will match until /tes/ and when looking for next child, 'x' will match nothing - // if paramChild == true { - // next = current.get(0) - // } - strings.HasPrefix(suffixFromExistingEdge, "{"), path, ) t.updateMaxParams(paramsN) + t.updateMaxDepth(result.depth + 1) result.p.updateEdge(parent) case incompleteMatchToEndOfEdge: // e.g. matched until "st" for "st" node but still have remaining char (ify) when inserting "testify" key. @@ -191,14 +184,14 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler // 2. Recreate the "st" node and link it to it's existing children and the new "ify" node. // 3. Update the "te" (parent) node to the new "st" node. - if result.matched.isCatchAll() { + keySuffix := path[result.charsMatched:] + + if strings.HasPrefix(keySuffix, "{") && result.matched.isCatchAll() { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } - keySuffix := path[result.charsMatched:] - // No children, so no paramChild - child := newNode(keySuffix, handler, nil, catchAllKey, false, path) + child := newNode(keySuffix, handler, nil, catchAllKey, path) edges := result.matched.getEdgesShallowCopy() edges = append(edges, child) n := newNode( @@ -206,20 +199,15 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler result.matched.handler, edges, result.matched.catchAllKey, - // e.g. tree encode /tes/ and insert /tes/{t} - // /tes/ (paramChild) - // ├── {t} - // since /tes/xyz will match until /tes/ and when looking for next child, 'x' will match nothing - // if paramChild == true { - // next = current.get(0) - // } - strings.HasPrefix(keySuffix, "{"), result.matched.path, ) + t.updateMaxDepth(result.depth + 1) t.updateMaxParams(paramsN) + if result.matched == rootNode { n.key = method + n.paramChildIndex = -1 t.updateRoot(n) break } @@ -258,32 +246,24 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } suffixFromExistingEdge := strings.TrimPrefix(result.matched.key, cPrefix) - // Rule: parent's of a node with {param} have only one node or are prefixed by a char (e.g /{param}) - if strings.HasPrefix(suffixFromExistingEdge, "{") { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - } - keySuffix := path[result.charsMatched:] - // Rule: parent's of a node with {param} have only one node or are prefixed by a char (e.g /{param}) - if strings.HasPrefix(keySuffix, "{") { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) - } // No children, so no paramChild - n1 := newNodeFromRef(keySuffix, handler, nil, nil, catchAllKey, false, path) // inserted node + n1 := newNodeFromRef(keySuffix, handler, nil, nil, catchAllKey, -1, path) // inserted node n2 := newNodeFromRef( suffixFromExistingEdge, result.matched.handler, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, - result.matched.paramChild, + result.matched.paramChildIndex, result.matched.path, ) // previous matched node // n3 children never start with a param - n3 := newNode(cPrefix, nil, []*node{n1, n2}, "", false, "") // intermediary node + n3 := newNode(cPrefix, nil, []*node{n1, n2}, "", "") // intermediary node + t.updateMaxDepth(result.depth + 1) t.updateMaxParams(paramsN) result.p.updateEdge(n3) default: @@ -294,21 +274,23 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } // update is not safe for concurrent use. -func (t *Tree) update(method string, path, catchAllKey string, handler Handler) error { +func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFunc) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return fmt.Errorf("route [%s] %s is not registered: %w", method, path, ErrRouteNotFound) + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } result := t.search(nds[index], path) if !result.isExactMatch() || !result.matched.isLeaf() { - return fmt.Errorf("route [%s] %s is not registered: %w", method, path, ErrRouteNotFound) + return fmt.Errorf("%w: route [%s] %s is not registered", ErrRouteNotFound, method, path) } - if catchAllKey != "" && len(result.matched.children) > 0 { - return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)[1:]) + if catchAllKey != result.matched.catchAllKey { + err := newConflictErr(method, path, catchAllKey, []string{result.matched.path}) + err.isUpdate = true + return err } // We are updating an existing node (could be a leaf or not). We only need to create a new node from @@ -319,7 +301,7 @@ func (t *Tree) update(method string, path, catchAllKey string, handler Handler) result.matched.children, result.matched.childKeys, catchAllKey, - result.matched.paramChild, + result.matched.paramChildIndex, path, ) result.p.updateEdge(n) @@ -352,7 +334,7 @@ func (t *Tree) remove(method, path string) bool { result.matched.children, result.matched.childKeys, "", - result.matched.paramChild, + result.matched.paramChildIndex, "", ) result.p.updateEdge(n) @@ -368,7 +350,7 @@ func (t *Tree) remove(method, path string) bool { child.children, child.childKeys, child.catchAllKey, - child.paramChild, + child.paramChildIndex, child.path, ) result.p.updateEdge(n) @@ -397,7 +379,7 @@ func (t *Tree) remove(method, path string) bool { child.children, child.childKeys, child.catchAllKey, - child.paramChild, + child.paramChildIndex, child.path, ) } else { @@ -406,7 +388,6 @@ func (t *Tree) remove(method, path string) bool { result.p.handler, parentEdges, result.p.catchAllKey, - result.p.paramChild, result.p.path, ) } @@ -416,6 +397,7 @@ func (t *Tree) remove(method, path string) bool { return t.removeRoot(method) } parent.key = method + parent.paramChildIndex = -1 t.updateRoot(parent) return true } @@ -424,34 +406,37 @@ func (t *Tree) remove(method, path string) bool { return true } -func (t *Tree) lookup(rootNode *node, path string, lazy bool) (n *node, params *Params, tsr bool) { +const ( + slashDelim = '/' + bracketDelim = '{' +) + +func (t *Tree) lookup(rootNode *node, path string, params *Params, skipNds *skippedNodes, lazy bool) (n *node, tsr bool) { + if len(rootNode.children) == 0 { + return nil, false + } + var ( charsMatched int charsMatchedInNodeFound int + paramCnt uint32 ) - current := rootNode -STOP: - for charsMatched < len(path) { - idx := linearSearch(current.childKeys, path[charsMatched]) - if idx < 0 { - if !current.paramChild { - break - } - idx = 0 - } + current := rootNode.children[0].Load() + *skipNds = (*skipNds)[:0] - current = current.children[idx].Load() +Walk: + for charsMatched < len(path) { charsMatchedInNodeFound = 0 for i := 0; charsMatched < len(path); i++ { if i >= len(current.key) { break } - if current.key[i] != path[charsMatched] || path[charsMatched] == '{' { + if current.key[i] != path[charsMatched] || path[charsMatched] == bracketDelim { if current.key[i] == '{' { startPath := charsMatched - idx := strings.IndexByte(path[charsMatched:], '/') + idx := strings.IndexByte(path[charsMatched:], slashDelim) if idx > 0 { // There is another path segment (e.g. /foo/{bar}/baz) charsMatched += idx @@ -460,10 +445,11 @@ STOP: charsMatched += len(path[charsMatched:]) } else { // segment is empty - break STOP + break Walk } + startKey := charsMatchedInNodeFound - idx = strings.IndexByte(current.key[startKey:], '/') + idx = strings.IndexByte(current.key[startKey:], slashDelim) if idx >= 0 { // -1 since on the next incrementation, if any, 'i' are going to be incremented i += idx - 1 @@ -473,51 +459,84 @@ STOP: i += len(current.key[charsMatchedInNodeFound:]) - 1 charsMatchedInNodeFound += len(current.key[charsMatchedInNodeFound:]) } + if !lazy { - if params == nil { - params = t.newParams() - } - // :n where n > 0 + paramCnt++ *params = append(*params, Param{Key: current.key[startKey+1 : charsMatchedInNodeFound-1], Value: path[startPath:charsMatched]}) } + continue } - break STOP + + break Walk } charsMatched++ charsMatchedInNodeFound++ } + + if charsMatched < len(path) { + // linear search + idx := -1 + for i := 0; i < len(current.childKeys); i++ { + if current.childKeys[i] == path[charsMatched] { + idx = i + break + } + } + + if idx < 0 { + if current.paramChildIndex < 0 { + break + } + // child param: go deeper and since the child param is evaluated + // now, no need to backtrack later. + idx = current.paramChildIndex + current = current.children[idx].Load() + continue + } + + if current.paramChildIndex >= 0 || current.isCatchAll() { + *skipNds = append(*skipNds, skippedNode{current, charsMatched}) + paramCnt = 0 + } + current = current.children[idx].Load() + } } + hasSkpNds := len(*skipNds) > 0 + if !current.isLeaf() { - return nil, params, false + if hasSkpNds { + goto Backtrack + } + + return nil, false } if charsMatched == len(path) { if charsMatchedInNodeFound == len(current.key) { - // Exact match, note that if we match a wildcard node, the param value is always '/' - if !lazy && (t.saveRoute || current.isCatchAll()) { - if params == nil { - params = t.newParams() - } - - if t.saveRoute { - *params = append(*params, Param{Key: RouteKey, Value: current.path}) - } - - if current.isCatchAll() { - *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) - } - - return current, params, false + // Exact match, note that if we match a catch all node + if !lazy && current.isCatchAll() { + *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) + return current, false } - return current, params, false - } else if charsMatchedInNodeFound < len(current.key) { + + return current, false + } + if charsMatchedInNodeFound < len(current.key) { // Key end mid-edge // Tsr recommendation: add an extra trailing slash (got an exact match) - remainingSuffix := current.key[charsMatchedInNodeFound:] - return nil, nil, len(remainingSuffix) == 1 && remainingSuffix[0] == '/' + if !tsr { + remainingSuffix := current.key[charsMatchedInNodeFound:] + tsr = len(remainingSuffix) == 1 && remainingSuffix[0] == slashDelim + } + + if hasSkpNds { + goto Backtrack + } + + return nil, tsr } } @@ -525,24 +544,53 @@ STOP: if charsMatched < len(path) && charsMatchedInNodeFound == len(current.key) { if current.isCatchAll() { if !lazy { - if params == nil { - params = t.newParams() - } *params = append(*params, Param{Key: current.catchAllKey, Value: path[charsMatched:]}) - if t.saveRoute { - *params = append(*params, Param{Key: RouteKey, Value: current.path}) - } - return current, params, false + return current, false } // Same as exact match, no tsr recommendation - return current, params, false + return current, false } + // Tsr recommendation: remove the extra trailing slash (got an exact match) - remainingKeySuffix := path[charsMatched:] - return nil, nil, len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == '/' + if !tsr { + remainingKeySuffix := path[charsMatched:] + tsr = len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == slashDelim + } + + if hasSkpNds { + goto Backtrack + } + + return nil, tsr } - return nil, nil, false + // Finally incomplete match to middle of ege +Backtrack: + if hasSkpNds { + skipped := skipNds.pop() + if skipped.node.paramChildIndex < 0 { + // skipped is catch all + current = skipped.node + *params = (*params)[:len(*params)-int(paramCnt)] + + if !lazy { + *params = append(*params, Param{Key: current.catchAllKey, Value: path[skipped.pathIndex:]}) + + return current, false + } + + return current, false + } + + current = skipped.node.children[skipped.node.paramChildIndex].Load() + + *params = (*params)[:len(*params)-int(paramCnt)] + charsMatched = skipped.pathIndex + paramCnt = 0 + goto Walk + } + + return nil, tsr } func (t *Tree) search(rootNode *node, path string) searchResult { @@ -553,6 +601,7 @@ func (t *Tree) search(rootNode *node, path string) searchResult { p *node charsMatched int charsMatchedInNodeFound int + depth uint32 ) STOP: @@ -562,6 +611,7 @@ STOP: break STOP } + depth++ pp = p p = current current = next @@ -587,10 +637,24 @@ STOP: charsMatchedInNodeFound: charsMatchedInNodeFound, p: p, pp: pp, + depth: depth, + } +} + +func (t *Tree) allocateContext() *context { + params := make(Params, 0, t.maxParams.Load()) + skipNds := make(skippedNodes, 0, t.maxDepth.Load()) + return &context{ + params: ¶ms, + skipNds: &skipNds, + // This is a read only value, no reset, it's always the + // owner of the pool. + tree: t, } } -// addRoot is not safe for concurrent use. +// addRoot append a new root node to the tree. +// Note: This function should be guarded by mutex. func (t *Tree) addRoot(n *node) { nds := *t.nodes.Load() newNds := make([]*node, 0, len(nds)+1) @@ -599,9 +663,15 @@ func (t *Tree) addRoot(n *node) { t.nodes.Store(&newNds) } -// updateRoot is not safe for concurrent use. +// updateRoot replaces a root node in the tree. +// Due to performance optimization, the tree uses atomic.Pointer[[]*node] instead of +// atomic.Pointer[atomic.Pointer[*node]]. As a result, the root node cannot be replaced +// directly by swapping the pointer. Instead, a new list of nodes is created with the +// updated root node, and the entire list is swapped afterwards. +// Note: This function should be guarded by mutex. func (t *Tree) updateRoot(n *node) bool { nds := *t.nodes.Load() + // for root node, the key contains the HTTP verb. index := findRootNode(n.key, nds) if index < 0 { return false @@ -614,7 +684,8 @@ func (t *Tree) updateRoot(n *node) bool { return true } -// removeRoot is not safe for concurrent use. +// removeRoot remove a root nod from the tree. +// Note: This function should be guarded by mutex. func (t *Tree) removeRoot(method string) bool { nds := *t.nodes.Load() index := findRootNode(method, nds) @@ -636,10 +707,10 @@ func (t *Tree) updateMaxParams(max uint32) { } } -func (t *Tree) load() []*node { - return *t.nodes.Load() -} - -func (t *Tree) newParams() *Params { - return t.p.Get().(*Params) +// updateMaxDepth perform an update only if max is greater than the current +// max depth. This function should be guarded my mutex. +func (t *Tree) updateMaxDepth(max uint32) { + if max > t.maxDepth.Load() { + t.maxDepth.Store(max) + } }