From 36e451a74f80da9d9d609dc4fa67dc4e3f2bd42d Mon Sep 17 00:00:00 2001 From: tigerwill90 Date: Sat, 5 Oct 2024 18:50:50 +0200 Subject: [PATCH] feat: wip per route options --- fox.go | 22 +++++++++------ fox_test.go | 34 +++++++++++++--------- iter.go | 6 ++-- node.go | 26 ++++++++--------- tree.go | 81 ++++++++++++++++++++++++++++++----------------------- 5 files changed, 96 insertions(+), 73 deletions(-) diff --git a/fox.go b/fox.go index 383b57d..b115604 100644 --- a/fox.go +++ b/fox.go @@ -70,6 +70,7 @@ func (f ClientIPStrategyFunc) ClientIP(c Context) (*net.IPAddr, error) { // Route represent a registered route in the route tree. type Route struct { ipStrategy ClientIPStrategy + base HandlerFunc handler HandlerFunc path string mws []middleware @@ -77,8 +78,13 @@ type Route struct { ignoreTrailingSlash bool } -// Handle call the registered handler with the provided Context. +// Handle calls the base handler with the provided Context. func (r *Route) Handle(c Context) { + r.base(c) +} + +// HandleWithMiddleware calls the handler with applied middleware using the provided Context. +func (r *Route) HandleWithMiddleware(c Context) { r.handler(c) } @@ -298,7 +304,7 @@ Next: method := nds[i].key it := newRawIterator(nds[i]) for it.hasNext() { - err := fn(method, it.path, it.current.handler) + err := fn(method, it.path, it.current.route.handler) if err != nil { if errors.Is(err, SkipMethod) { continue Next @@ -384,9 +390,9 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { n, tsr = tree.lookup(nds[index], target, c, false) if !tsr && n != nil { - c.path = n.path + c.path = n.route.path c.tsr = tsr - n.handler(c) + n.route.handler(c) // Put back the context, if not extended more than max params or max depth, allowing // the slice to naturally grow within the constraint. if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) { @@ -397,9 +403,9 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr { if fox.ignoreTrailingSlash { - c.path = n.path + c.path = n.route.path c.tsr = tsr - n.handler(c) + n.route.handler(c) c.Close() return } @@ -645,7 +651,7 @@ func getRouteConflict(n *node) []string { routes := make([]string, 0) if n.isCatchAll() { - routes = append(routes, n.path) + routes = append(routes, n.route.path) return routes } @@ -654,7 +660,7 @@ func getRouteConflict(n *node) []string { } it := newRawIterator(n) for it.hasNext() { - routes = append(routes, it.current.path) + routes = append(routes, it.current.route.path) } return routes } diff --git a/fox_test.go b/fox_test.go index f129171..c990462 100644 --- a/fox_test.go +++ b/fox_test.go @@ -750,8 +750,9 @@ func TestRouteWithParams(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], rte, c, false) require.NotNil(t, n) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, rte, n.path) + assert.Equal(t, rte, n.route.path) } } @@ -1196,9 +1197,9 @@ func TestOverlappingRoute(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], tc.path, c, false) require.NotNil(t, n) - require.NotNil(t, n.handler) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, tc.wantMatch, n.path) + assert.Equal(t, tc.wantMatch, n.route.path) if len(tc.wantParams) == 0 { assert.Empty(t, c.Params()) } else { @@ -1209,10 +1210,10 @@ func TestOverlappingRoute(t *testing.T) { c = newTestContextTree(tree) n, tsr = tree.lookup(nds[0], tc.path, c, true) require.NotNil(t, n) - require.NotNil(t, n.handler) + require.NotNil(t, n.route) assert.False(t, tsr) assert.Empty(t, c.Params()) - assert.Equal(t, tc.wantMatch, n.path) + assert.Equal(t, tc.wantMatch, n.route.path) }) } } @@ -1643,7 +1644,7 @@ func TestTree_LookupTsr(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, path := range tc.paths { - require.NoError(t, tree.insert(http.MethodGet, path, "", 0, emptyHandler)) + require.NoError(t, tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler))) } nds := *tree.nodes.Load() c := newTestContextTree(tree) @@ -1651,7 +1652,8 @@ func TestTree_LookupTsr(t *testing.T) { assert.Equal(t, tc.want, got) if tc.want { require.NotNil(t, n) - assert.Equal(t, tc.wantPath, n.path) + require.NotNil(t, n.route) + assert.Equal(t, tc.wantPath, n.route.path) } }) } @@ -2060,7 +2062,6 @@ func TestRouterWithTsrParams(t *testing.T) { f := New(WithIgnoreTrailingSlash(true)) for _, rte := range tc.routes { require.NoError(t, f.Handle(http.MethodGet, rte, func(c Context) { - fmt.Println(c.Path(), c.Params()) assert.Equal(t, tc.wantPath, c.Path()) assert.Equal(t, tc.wantParams, c.Params()) assert.Equal(t, tc.wantTsr, unwrapContext(t, c).tsr) @@ -2804,14 +2805,16 @@ func TestFuzzInsertLookupParam(t *testing.T) { if s1 == "" || s2 == "" || e1 == "" || e2 == "" || e3 == "" { continue } - if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, emptyHandler); err == nil { + path := fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3) + if err := tree.insert(http.MethodGet, path, "", 3, tree.newRoute(path, emptyHandler)); err == nil { nds := *tree.nodes.Load() c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c, false) require.NotNil(t, n) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path) + assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.route.path) assert.Equal(t, "xxxx", c.Param(e1)) assert.Equal(t, "xxxx", c.Param(e2)) assert.Equal(t, "xxxx", c.Param(e3)) @@ -2833,7 +2836,7 @@ func TestFuzzInsertNoPanics(t *testing.T) { continue } require.NotPanicsf(t, func() { - _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, emptyHandler) + _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, tree.newRoute(appendCatchAll(rte, catchAllKey), emptyHandler)) }, fmt.Sprintf("rte: %s, catch all: %s", rte, catchAllKey)) } } @@ -2854,7 +2857,8 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { f.Fuzz(&routes) for rte := range routes { - err := tree.insert(http.MethodGet, "/"+rte, "", 0, emptyHandler) + path := "/" + rte + err := tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler)) require.NoError(t, err) } @@ -2870,10 +2874,12 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], "/"+rte, c, true) require.NotNilf(t, n, "route /%s", rte) + require.NotNilf(t, n.route, "route /%s", rte) require.Falsef(t, tsr, "tsr: %t", tsr) require.Truef(t, n.isLeaf(), "route /%s", rte) - require.Equal(t, "/"+rte, n.path) - require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler)) + require.Equal(t, "/"+rte, n.route.path) + path := "/" + rte + require.NoError(t, tree.update(http.MethodGet, path, "", tree.newRoute(path, emptyHandler))) } for rte := range routes { diff --git a/iter.go b/iter.go index 9344deb..bfe092b 100644 --- a/iter.go +++ b/iter.go @@ -167,7 +167,7 @@ func (it *Iterator) Next() { // Path returns the registered path for the current route. func (it *Iterator) Path() string { if it.current != nil { - return it.current.path + return it.current.route.path } return "" } @@ -180,7 +180,7 @@ func (it *Iterator) Method() string { // Handler return the registered handler for the current route. func (it *Iterator) Handler() HandlerFunc { if it.current != nil { - return it.current.handler + return it.current.route.handler } return nil } @@ -221,7 +221,7 @@ func (it *rawIterator) hasNext() bool { it.current = elem if it.current.isLeaf() { - it.path = elem.path + it.path = elem.route.Path() return true } } diff --git a/node.go b/node.go index da8b570..f2ebce8 100644 --- a/node.go +++ b/node.go @@ -12,9 +12,9 @@ import ( ) type node struct { - // The registered handler matching the full path. Nil if the node is not a leaf. - // Once assigned, handler is immutable. - handler HandlerFunc + // The registered route matching the full path. Nil if the node is not a leaf. + // Once assigned, route is immutable. + route *Route // key represent a segment of a route which share a common prefix with it parent. key string @@ -23,9 +23,6 @@ type node struct { // Once assigned, catchAllKey is immutable. catchAllKey string - // The full path when it's a leaf node - path string - // First char of each outgoing edges from this node sorted in ascending order. // Once assigned, this is a read only slice. It allows to lazily traverse the // tree without the extra cost of atomic load operation. @@ -42,7 +39,11 @@ type node struct { paramChildIndex int } -func newNode(key string, handler HandlerFunc, children []*node, catchAllKey string, path string) *node { +func newNode(key string, route *Route, children []*node, catchAllKey string) *node { + // TODO use this instead of old sort.Slice + /* slices.SortFunc(children, func(a, b *node) int { + return cmp.Compare(a.key, b.key) + })*/ sort.Slice(children, func(i, j int) bool { return children[i].key < children[j].key }) @@ -58,24 +59,23 @@ func newNode(key string, handler HandlerFunc, children []*node, catchAllKey stri } } - return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, childIndex, path) + return newNodeFromRef(key, route, nds, childKeys, catchAllKey, childIndex) } -func newNodeFromRef(key string, handler HandlerFunc, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node { +func newNodeFromRef(key string, route *Route, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int) *node { return &node{ key: key, childKeys: childKeys, children: children, - handler: handler, + route: route, catchAllKey: catchAllKey, - path: appendCatchAll(path, catchAllKey), paramChildIndex: childIndex, params: parseWildcard(key), } } func (n *node) isLeaf() bool { - return n.handler != nil + return n.route != nil } func (n *node) isCatchAll() bool { @@ -211,7 +211,7 @@ func (n *node) string(space int) string { } if n.isLeaf() { sb.WriteString(" [leaf=") - sb.WriteString(n.path) + sb.WriteString(n.route.path) sb.WriteString("]") } if n.hasWildcard() { diff --git a/tree.go b/tree.go index 8c7b4aa..384cec6 100644 --- a/tree.go +++ b/tree.go @@ -33,7 +33,7 @@ import ( type Tree struct { ctx sync.Pool nodes atomic.Pointer[[]*node] - fox *Router + fox *Router // TODO tree should be agnostic to the router mws []middleware sync.Mutex maxParams atomic.Uint32 @@ -44,7 +44,7 @@ type Tree struct { // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To override an existing route, use Update. -func (t *Tree) Handle(method, path string, handler HandlerFunc) error { +func (t *Tree) Handle(method, path string, handler HandlerFunc, opts ...PathOption) error { if matched := regEnLetter.MatchString(method); !matched { return fmt.Errorf("%w: missing or invalid http method", ErrInvalidRoute) } @@ -54,14 +54,14 @@ func (t *Tree) Handle(method, path string, handler HandlerFunc) error { return err } - return t.insert(method, p, catchAllKey, uint32(n), applyMiddleware(RouteHandlers, t.mws, handler)) + return t.insert(method, p, catchAllKey, uint32(n), t.newRoute(path, handler, opts...)) } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To add a new handler, use Handle method. -func (t *Tree) Update(method, path string, handler HandlerFunc) error { +func (t *Tree) Update(method, path string, handler HandlerFunc, opts ...PathOption) error { if method == "" { return fmt.Errorf("%w: missing http method", ErrInvalidRoute) } @@ -71,7 +71,7 @@ func (t *Tree) Update(method, path string, handler HandlerFunc) error { return err } - return t.update(method, p, catchAllKey, applyMiddleware(RouteHandlers, t.mws, handler)) + return t.update(method, p, catchAllKey, t.newRoute(path, handler, opts...)) } // Remove delete an existing handler for the given method and path. If the route does not exist, the function @@ -110,7 +110,7 @@ func (t *Tree) Has(method, path string) bool { n, tsr := t.lookup(nds[index], path, c, true) c.Close() if n != nil && !tsr { - return n.path == path + return n.route.path == path } return false } @@ -131,8 +131,9 @@ func (t *Tree) Match(method, path string) string { c.resetNil() n, tsr := t.lookup(nds[index], path, c, true) c.Close() - if n != nil && (!tsr || t.fox.redirectTrailingSlash || t.fox.ignoreTrailingSlash) { - return n.path + // TODO maybe returns tsr ??? + if n != nil && (!tsr || n.route.redirectTrailingSlash || n.route.ignoreTrailingSlash) { + return n.route.path } return "" } @@ -201,9 +202,9 @@ func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, c n, tsr := t.lookup(nds[index], target, c, false) if n != nil { - c.path = n.path + c.path = n.route.path c.tsr = tsr - return n.handler, c, tsr + return n.route.base, c, tsr } c.Close() return nil, nil, tsr @@ -211,7 +212,7 @@ func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, c // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use // parseRoute before. -func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { +func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. var rootNode *node nds := *t.nodes.Load() @@ -237,12 +238,12 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler if result.matched.isCatchAll() && isCatchAll { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } - return fmt.Errorf("%w: new route %s %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) + return fmt.Errorf("%w: new route %s %s conflict with %s", ErrRouteExist, method, route.path, result.matched.route.path) } // We are updating an existing node. We only need to create a new node from // the matched one with the updated/added value (handler and wildcard). - n := newNodeFromRef(result.matched.key, handler, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex, path) + n := newNodeFromRef(result.matched.key, route, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex) t.updateMaxParams(paramsN) result.p.updateEdge(n) @@ -269,20 +270,18 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler child := newNodeFromRef( suffixFromExistingEdge, - result.matched.handler, + result.matched.route, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, result.matched.paramChildIndex, - result.matched.path, ) parent := newNode( cPrefix, - handler, + route, []*node{child}, catchAllKey, - path, ) t.updateMaxParams(paramsN) @@ -306,15 +305,14 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler keySuffix := path[result.charsMatched:] // No children, so no paramChild - child := newNode(keySuffix, handler, nil, catchAllKey, path) + child := newNode(keySuffix, route, nil, catchAllKey) edges := result.matched.getEdgesShallowCopy() edges = append(edges, child) n := newNode( result.matched.key, - result.matched.handler, + result.matched.route, edges, result.matched.catchAllKey, - result.matched.path, ) t.updateMaxDepth(result.depth + 1) @@ -364,19 +362,18 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler keySuffix := path[result.charsMatched:] // No children, so no paramChild - n1 := newNodeFromRef(keySuffix, handler, nil, nil, catchAllKey, -1, path) // inserted node + n1 := newNodeFromRef(keySuffix, route, nil, nil, catchAllKey, -1) // inserted node n2 := newNodeFromRef( suffixFromExistingEdge, - result.matched.handler, + result.matched.route, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, result.matched.paramChildIndex, - result.matched.path, ) // previous matched node // n3 children never start with a param - n3 := newNode(cPrefix, nil, []*node{n1, n2}, "", "") // intermediary node + n3 := newNode(cPrefix, nil, []*node{n1, n2}, "") // intermediary node t.updateMaxDepth(result.depth + 1) t.updateMaxParams(paramsN) @@ -389,7 +386,7 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } // update is not safe for concurrent use. -func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFunc) error { +func (t *Tree) update(method string, path, catchAllKey string, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. nds := *t.nodes.Load() index := findRootNode(method, nds) @@ -403,7 +400,7 @@ func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFu } if catchAllKey != result.matched.catchAllKey { - err := newConflictErr(method, path, catchAllKey, []string{result.matched.path}) + err := newConflictErr(method, path, catchAllKey, []string{result.matched.route.path}) err.isUpdate = true return err } @@ -412,12 +409,11 @@ func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFu // the matched one with the updated/added value (handler and wildcard). n := newNodeFromRef( result.matched.key, - handler, + route, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex, - path, ) result.p.updateEdge(n) return nil @@ -450,7 +446,6 @@ func (t *Tree) remove(method, path string) bool { result.matched.childKeys, "", result.matched.paramChildIndex, - "", ) result.p.updateEdge(n) return true @@ -461,12 +456,11 @@ func (t *Tree) remove(method, path string) bool { mergedPath := fmt.Sprintf("%s%s", result.matched.key, child.key) n := newNodeFromRef( mergedPath, - child.handler, + child.route, child.children, child.childKeys, child.catchAllKey, child.paramChildIndex, - child.path, ) result.p.updateEdge(n) return true @@ -490,20 +484,18 @@ func (t *Tree) remove(method, path string) bool { mergedPath := fmt.Sprintf("%s%s", result.p.key, child.key) parent = newNodeFromRef( mergedPath, - child.handler, + child.route, child.children, child.childKeys, child.catchAllKey, child.paramChildIndex, - child.path, ) } else { parent = newNode( result.p.key, - result.p.handler, + result.p.route, parentEdges, result.p.catchAllKey, - result.p.path, ) } @@ -932,3 +924,22 @@ func (t *Tree) updateMaxDepth(max uint32) { t.maxDepth.Store(max) } } + +// newRoute create a new route, apply path options and apply middleware on the handler. +func (t *Tree) newRoute(path string, handler HandlerFunc, opts ...PathOption) *Route { + rte := &Route{ + ipStrategy: t.fox.ipStrategy, + base: handler, + path: path, + mws: t.mws, + redirectTrailingSlash: t.fox.redirectTrailingSlash, + ignoreTrailingSlash: t.fox.ignoreTrailingSlash, + } + + for _, opt := range opts { + opt.applyPath(rte) + } + rte.handler = applyMiddleware(RouteHandlers, rte.mws, handler) + + return rte +}