Skip to content

Commit

Permalink
feat: wip per route options
Browse files Browse the repository at this point in the history
  • Loading branch information
tigerwill90 committed Oct 5, 2024
1 parent 74dab72 commit 36e451a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 73 deletions.
22 changes: 14 additions & 8 deletions fox.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,21 @@ func (f ClientIPStrategyFunc) ClientIP(c Context) (*net.IPAddr, error) {
// Route represent a registered route in the route tree.
type Route struct {
ipStrategy ClientIPStrategy
base HandlerFunc
handler HandlerFunc
path string
mws []middleware
redirectTrailingSlash bool
ignoreTrailingSlash bool
}

// Handle call the registered handler with the provided Context.
// Handle calls the base handler with the provided Context.
func (r *Route) Handle(c Context) {
r.base(c)

Check warning on line 83 in fox.go

View check run for this annotation

Codecov / codecov/patch

fox.go#L82-L83

Added lines #L82 - L83 were not covered by tests
}

// HandleWithMiddleware calls the handler with applied middleware using the provided Context.
func (r *Route) HandleWithMiddleware(c Context) {
r.handler(c)

Check warning on line 88 in fox.go

View check run for this annotation

Codecov / codecov/patch

fox.go#L87-L88

Added lines #L87 - L88 were not covered by tests
}

Expand Down Expand Up @@ -298,7 +304,7 @@ Next:
method := nds[i].key
it := newRawIterator(nds[i])
for it.hasNext() {
err := fn(method, it.path, it.current.handler)
err := fn(method, it.path, it.current.route.handler)
if err != nil {
if errors.Is(err, SkipMethod) {
continue Next
Expand Down Expand Up @@ -384,9 +390,9 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {

n, tsr = tree.lookup(nds[index], target, c, false)
if !tsr && n != nil {
c.path = n.path
c.path = n.route.path
c.tsr = tsr
n.handler(c)
n.route.handler(c)
// Put back the context, if not extended more than max params or max depth, allowing
// the slice to naturally grow within the constraint.
if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) {
Expand All @@ -397,9 +403,9 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {

if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr {
if fox.ignoreTrailingSlash {
c.path = n.path
c.path = n.route.path
c.tsr = tsr
n.handler(c)
n.route.handler(c)
c.Close()
return
}
Expand Down Expand Up @@ -645,7 +651,7 @@ func getRouteConflict(n *node) []string {
routes := make([]string, 0)

if n.isCatchAll() {
routes = append(routes, n.path)
routes = append(routes, n.route.path)
return routes
}

Expand All @@ -654,7 +660,7 @@ func getRouteConflict(n *node) []string {
}
it := newRawIterator(n)
for it.hasNext() {
routes = append(routes, it.current.path)
routes = append(routes, it.current.route.path)
}
return routes
}
Expand Down
34 changes: 20 additions & 14 deletions fox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,9 @@ func TestRouteWithParams(t *testing.T) {
c := newTestContextTree(tree)
n, tsr := tree.lookup(nds[0], rte, c, false)
require.NotNil(t, n)
require.NotNil(t, n.route)
assert.False(t, tsr)
assert.Equal(t, rte, n.path)
assert.Equal(t, rte, n.route.path)
}
}

Expand Down Expand Up @@ -1196,9 +1197,9 @@ func TestOverlappingRoute(t *testing.T) {
c := newTestContextTree(tree)
n, tsr := tree.lookup(nds[0], tc.path, c, false)
require.NotNil(t, n)
require.NotNil(t, n.handler)
require.NotNil(t, n.route)
assert.False(t, tsr)
assert.Equal(t, tc.wantMatch, n.path)
assert.Equal(t, tc.wantMatch, n.route.path)
if len(tc.wantParams) == 0 {
assert.Empty(t, c.Params())
} else {
Expand All @@ -1209,10 +1210,10 @@ func TestOverlappingRoute(t *testing.T) {
c = newTestContextTree(tree)
n, tsr = tree.lookup(nds[0], tc.path, c, true)
require.NotNil(t, n)
require.NotNil(t, n.handler)
require.NotNil(t, n.route)
assert.False(t, tsr)
assert.Empty(t, c.Params())
assert.Equal(t, tc.wantMatch, n.path)
assert.Equal(t, tc.wantMatch, n.route.path)
})
}
}
Expand Down Expand Up @@ -1643,15 +1644,16 @@ func TestTree_LookupTsr(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
tree := New().Tree()
for _, path := range tc.paths {
require.NoError(t, tree.insert(http.MethodGet, path, "", 0, emptyHandler))
require.NoError(t, tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler)))
}
nds := *tree.nodes.Load()
c := newTestContextTree(tree)
n, got := tree.lookup(nds[0], tc.key, c, true)
assert.Equal(t, tc.want, got)
if tc.want {
require.NotNil(t, n)
assert.Equal(t, tc.wantPath, n.path)
require.NotNil(t, n.route)
assert.Equal(t, tc.wantPath, n.route.path)
}
})
}
Expand Down Expand Up @@ -2060,7 +2062,6 @@ func TestRouterWithTsrParams(t *testing.T) {
f := New(WithIgnoreTrailingSlash(true))
for _, rte := range tc.routes {
require.NoError(t, f.Handle(http.MethodGet, rte, func(c Context) {
fmt.Println(c.Path(), c.Params())
assert.Equal(t, tc.wantPath, c.Path())
assert.Equal(t, tc.wantParams, c.Params())
assert.Equal(t, tc.wantTsr, unwrapContext(t, c).tsr)
Expand Down Expand Up @@ -2804,14 +2805,16 @@ func TestFuzzInsertLookupParam(t *testing.T) {
if s1 == "" || s2 == "" || e1 == "" || e2 == "" || e3 == "" {
continue
}
if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, emptyHandler); err == nil {
path := fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3)
if err := tree.insert(http.MethodGet, path, "", 3, tree.newRoute(path, emptyHandler)); err == nil {
nds := *tree.nodes.Load()

c := newTestContextTree(tree)
n, tsr := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c, false)
require.NotNil(t, n)
require.NotNil(t, n.route)
assert.False(t, tsr)
assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path)
assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.route.path)
assert.Equal(t, "xxxx", c.Param(e1))
assert.Equal(t, "xxxx", c.Param(e2))
assert.Equal(t, "xxxx", c.Param(e3))
Expand All @@ -2833,7 +2836,7 @@ func TestFuzzInsertNoPanics(t *testing.T) {
continue
}
require.NotPanicsf(t, func() {
_ = tree.insert(http.MethodGet, rte, catchAllKey, 0, emptyHandler)
_ = tree.insert(http.MethodGet, rte, catchAllKey, 0, tree.newRoute(appendCatchAll(rte, catchAllKey), emptyHandler))
}, fmt.Sprintf("rte: %s, catch all: %s", rte, catchAllKey))
}
}
Expand All @@ -2854,7 +2857,8 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) {
f.Fuzz(&routes)

for rte := range routes {
err := tree.insert(http.MethodGet, "/"+rte, "", 0, emptyHandler)
path := "/" + rte
err := tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler))
require.NoError(t, err)
}

Expand All @@ -2870,10 +2874,12 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) {
c := newTestContextTree(tree)
n, tsr := tree.lookup(nds[0], "/"+rte, c, true)
require.NotNilf(t, n, "route /%s", rte)
require.NotNilf(t, n.route, "route /%s", rte)
require.Falsef(t, tsr, "tsr: %t", tsr)
require.Truef(t, n.isLeaf(), "route /%s", rte)
require.Equal(t, "/"+rte, n.path)
require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler))
require.Equal(t, "/"+rte, n.route.path)
path := "/" + rte
require.NoError(t, tree.update(http.MethodGet, path, "", tree.newRoute(path, emptyHandler)))
}

for rte := range routes {
Expand Down
6 changes: 3 additions & 3 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (it *Iterator) Next() {
// Path returns the registered path for the current route.
func (it *Iterator) Path() string {
if it.current != nil {
return it.current.path
return it.current.route.path
}
return ""
}
Expand All @@ -180,7 +180,7 @@ func (it *Iterator) Method() string {
// Handler return the registered handler for the current route.
func (it *Iterator) Handler() HandlerFunc {
if it.current != nil {
return it.current.handler
return it.current.route.handler
}
return nil
}
Expand Down Expand Up @@ -221,7 +221,7 @@ func (it *rawIterator) hasNext() bool {
it.current = elem

if it.current.isLeaf() {
it.path = elem.path
it.path = elem.route.Path()
return true
}
}
Expand Down
26 changes: 13 additions & 13 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
)

type node struct {
// The registered handler matching the full path. Nil if the node is not a leaf.
// Once assigned, handler is immutable.
handler HandlerFunc
// The registered route matching the full path. Nil if the node is not a leaf.
// Once assigned, route is immutable.
route *Route

// key represent a segment of a route which share a common prefix with it parent.
key string
Expand All @@ -23,9 +23,6 @@ type node struct {
// Once assigned, catchAllKey is immutable.
catchAllKey string

// The full path when it's a leaf node
path string

// First char of each outgoing edges from this node sorted in ascending order.
// Once assigned, this is a read only slice. It allows to lazily traverse the
// tree without the extra cost of atomic load operation.
Expand All @@ -42,7 +39,11 @@ type node struct {
paramChildIndex int
}

func newNode(key string, handler HandlerFunc, children []*node, catchAllKey string, path string) *node {
func newNode(key string, route *Route, children []*node, catchAllKey string) *node {
// TODO use this instead of old sort.Slice
/* slices.SortFunc(children, func(a, b *node) int {
return cmp.Compare(a.key, b.key)
})*/
sort.Slice(children, func(i, j int) bool {
return children[i].key < children[j].key
})
Expand All @@ -58,24 +59,23 @@ func newNode(key string, handler HandlerFunc, children []*node, catchAllKey stri
}
}

return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, childIndex, path)
return newNodeFromRef(key, route, nds, childKeys, catchAllKey, childIndex)
}

func newNodeFromRef(key string, handler HandlerFunc, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node {
func newNodeFromRef(key string, route *Route, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int) *node {
return &node{
key: key,
childKeys: childKeys,
children: children,
handler: handler,
route: route,
catchAllKey: catchAllKey,
path: appendCatchAll(path, catchAllKey),
paramChildIndex: childIndex,
params: parseWildcard(key),
}
}

func (n *node) isLeaf() bool {
return n.handler != nil
return n.route != nil
}

func (n *node) isCatchAll() bool {
Expand Down Expand Up @@ -211,7 +211,7 @@ func (n *node) string(space int) string {
}
if n.isLeaf() {
sb.WriteString(" [leaf=")
sb.WriteString(n.path)
sb.WriteString(n.route.path)
sb.WriteString("]")
}
if n.hasWildcard() {
Expand Down
Loading

0 comments on commit 36e451a

Please sign in to comment.