From 66f5f520a3d1d83329d59c4e5c88a811771bdf3c Mon Sep 17 00:00:00 2001 From: tigerwill90 Date: Sun, 17 Nov 2024 22:47:00 +0100 Subject: [PATCH] feat(txn): delete Swap from the API. Use transaction instead --- fox.go | 134 +++++++++++++++++++++++----------------------------- fox_test.go | 68 +++++++++----------------- iter.go | 8 ++-- tree.go | 76 ++++++++++++++++++----------- txn.go | 7 ++- txn_test.go | 54 ++++++++++----------- 6 files changed, 164 insertions(+), 183 deletions(-) diff --git a/fox.go b/fox.go index eee9da5..be1d220 100644 --- a/fox.go +++ b/fox.go @@ -14,7 +14,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "unicode/utf8" ) @@ -93,7 +92,7 @@ type Router struct { noMethod HandlerFunc tsrRedirect HandlerFunc autoOptions HandlerFunc - tree atomic.Pointer[Tree] + tree *Tree ipStrategy ClientIPStrategy mws []middleware handleMethodNotAllowed bool @@ -128,7 +127,7 @@ func New(opts ...GlobalOption) *Router { r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlashHandler) r.autoOptions = applyMiddleware(OptionsHandler, r.mws, r.autoOptions) - r.tree.Store(r.NewTree()) + r.tree = r.newTree() return r } @@ -167,61 +166,6 @@ func (fox *Router) ClientIPStrategyEnabled() bool { return !ok } -// NewTree returns a fresh routing [Tree] that inherits all registered router options. It's safe to create multiple [Tree] -// concurrently. However, a Tree itself is not thread-safe and all its APIs that perform write operations should be run -// serially. Note that a [Tree] give direct access to the underlying [sync.Mutex]. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) NewTree() *Tree { - tree := new(Tree) - tree.fox = fox - - // Pre instantiate nodes for common http verb - nds := make([]*node, len(commonVerbs)) - for i := range commonVerbs { - nds[i] = new(node) - nds[i].key = commonVerbs[i] - nds[i].paramChildIndex = -1 - nds[i].wildcardChildIndex = -1 - } - tree.nodes.Store(&nds) - - tree.ctx = sync.Pool{ - New: func() any { - return tree.allocateContext() - }, - } - - return tree -} - -func (fox *Router) Txn() *Txn { - tree := fox.Tree() - tree.Lock() - root := tree.snapshot() - root.txn = true - return &Txn{ - snap: root, - main: tree, - } -} - -// Tree atomically loads and return the currently in-use routing tree. -// This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Tree() *Tree { - return fox.tree.Load() -} - -// Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. -// Note that the swap will panic if the current tree belongs to a different instance of the router, preventing accidental -// replacement of trees from different routers. -func (fox *Router) Swap(new *Tree) (old *Tree) { - current := fox.tree.Load() - if current.fox != new.fox { - panic("swap failed: current and new routing trees belong to different router instances") - } - return fox.tree.Swap(new) -} - // Handle registers a new handler for the given method and route pattern. On success, it returns the newly registered [Route]. // If an error occurs, it returns one of the following: // - [ErrRouteExist]: If the route is already registered. @@ -231,10 +175,9 @@ func (fox *Router) Swap(new *Tree) (old *Tree) { // It's safe to add a new handler while the tree is in use for serving requests. This function is safe for concurrent // use by multiple goroutine. To override an existing route, use [Router.Update]. func (fox *Router) Handle(method, pattern string, handler HandlerFunc, opts ...RouteOption) (*Route, error) { - t := fox.Tree() - t.Lock() - defer t.Unlock() - return t.Handle(method, pattern, handler, opts...) + fox.tree.Lock() + defer fox.tree.Unlock() + return fox.tree.Handle(method, pattern, handler, opts...) } // MustHandle registers a new handler for the given method and route pattern. On success, it returns the newly registered [Route] @@ -257,10 +200,9 @@ func (fox *Router) MustHandle(method, pattern string, handler HandlerFunc, opts // It's safe to update a handler while the tree is in use for serving requests. This function is safe for concurrent // use by multiple goroutine. To add new handler, use [Router.Handle] method. func (fox *Router) Update(method, pattern string, handler HandlerFunc, opts ...RouteOption) (*Route, error) { - t := fox.Tree() - t.Lock() - defer t.Unlock() - return t.Update(method, pattern, handler, opts...) + fox.tree.Lock() + defer fox.tree.Unlock() + return fox.tree.Update(method, pattern, handler, opts...) } // Delete deletes an existing handler for the given method and route pattern. If an error occurs, it returns one of the following: @@ -270,10 +212,9 @@ func (fox *Router) Update(method, pattern string, handler HandlerFunc, opts ...R // It's safe to delete a handler while the tree is in use for serving requests. This function is safe for concurrent // use by multiple goroutine. func (fox *Router) Delete(method, pattern string) error { - t := fox.Tree() - t.Lock() - defer t.Unlock() - return t.Delete(method, pattern) + fox.tree.Lock() + defer fox.tree.Unlock() + return fox.tree.Delete(method, pattern) } // Lookup performs a manual route lookup for a given [http.Request], returning the matched [Route] along with a @@ -283,8 +224,7 @@ func (fox *Router) Delete(method, pattern string) error { // concurrent use by multiple goroutine and while mutation on [Tree] are ongoing. See also [Tree.Reverse] as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. func (fox *Router) Lookup(w ResponseWriter, r *http.Request) (route *Route, cc ContextCloser, tsr bool) { - tree := fox.Tree() - return tree.Lookup(w, r) + return fox.tree.Lookup(w, r) } // Updates executes a function within the context of a read-write managed transaction. If no error is returned from the @@ -312,8 +252,50 @@ func (fox *Router) Updates(fn func(txn *Txn) error) error { // This function is safe for concurrent use by multiple goroutines and can operate while the [Tree] is being modified. // This API is EXPERIMENTAL and may change in future releases. func (fox *Router) Iter() Iter { - tree := fox.Tree() - return tree.Iter() + return fox.tree.Iter() +} + +// Txn create a new read-write transaction. +// It's safe to create and execute a transaction while the tree is in use for serving requests. +// This function is safe for concurrent use by multiple goroutine. For managed transaction, use [Router.Updates]. +func (fox *Router) Txn() *Txn { + fox.tree.Lock() + root := fox.tree.snapshot() + root.txn = true + return &Txn{ + snap: root, + main: fox.tree, + } +} + +// Tree atomically loads and return the currently in-use routing tree. +// This API is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) Tree() *Tree { + return fox.tree +} + +// newTree returns a fresh routing Tree that inherits all registered router options. +func (fox *Router) newTree() *Tree { + tree := new(Tree) + tree.fox = fox + + // Pre instantiate nodes for common http verb + nds := make([]*node, len(commonVerbs)) + for i := range commonVerbs { + nds[i] = new(node) + nds[i].key = commonVerbs[i] + nds[i].paramChildIndex = -1 + nds[i].wildcardChildIndex = -1 + } + tree.root.Store(&nds) + + tree.ctx = sync.Pool{ + New: func() any { + return tree.allocateContext() + }, + } + + return tree } // DefaultNotFoundHandler is a simple [HandlerFunc] that replies to each request @@ -371,11 +353,11 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { path = r.URL.RawPath } - tree := fox.tree.Load() + tree := fox.tree c := tree.ctx.Get().(*cTx) c.reset(w, r) - nds := *tree.nodes.Load() + nds := *tree.root.Load() index := findRootNode(r.Method, nds) if index < 0 || len(nds[index].children) == 0 { goto NoMethodFallback diff --git a/fox_test.go b/fox_test.go index 23c8af1..17f4ab1 100644 --- a/fox_test.go +++ b/fox_test.go @@ -950,7 +950,7 @@ func TestEmptyCatchAll(t *testing.T) { for _, rte := range tc.routes { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), tc.path, c, false) require.False(t, tsr) @@ -982,7 +982,7 @@ func TestRouteWithParams(t *testing.T) { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() for _, rte := range routes { c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), rte, c, false) @@ -1023,7 +1023,7 @@ func TestRouteParamEmptySegment(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), tc.path, c, false) assert.Nil(t, n) @@ -1490,11 +1490,11 @@ func TestOverlappingRoute(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - tree := r.NewTree() + tree := r.newTree() for _, rte := range tc.routes { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), tc.path, c, false) @@ -2244,7 +2244,7 @@ func TestInfixWildcard(t *testing.T) { for _, rte := range tc.routes { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), tc.path, c, false) require.NotNil(t, n) @@ -2496,7 +2496,7 @@ func TestDomainLookup(t *testing.T) { for _, rte := range tc.routes { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], tc.host, tc.path, c, false) require.NotNil(t, n) @@ -2817,7 +2817,7 @@ func TestInfixWildcardTsr(t *testing.T) { for _, rte := range tc.routes { require.NoError(t, onlyError(tree.Handle(http.MethodGet, rte, emptyHandler))) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), tc.path, c, false) require.NotNil(t, n) @@ -2970,7 +2970,7 @@ func TestInsertUpdateAndDeleteWithHostname(t *testing.T) { assert.Equal(t, []Annotation{updateAnnot2}, slices.Collect(r.Annotations())) } } - nds := *tree.nodes.Load() + nds := *tree.root.Load() assert.Equal(t, http.MethodGet, nds[0].key) assert.Len(t, nds[0].children, 0) @@ -2992,7 +2992,7 @@ func TestInsertUpdateAndDeleteWithHostname(t *testing.T) { } } - nds = *tree.nodes.Load() + nds = *tree.root.Load() assert.Equal(t, http.MethodGet, nds[0].key) assert.Len(t, nds[0].children, 0) }) @@ -3783,7 +3783,7 @@ func TestTree_LookupTsr(t *testing.T) { for _, path := range tc.paths { require.NoError(t, tree.insert(http.MethodGet, tree.newRoute(path, emptyHandler), 0)) } - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, got := tree.lookupByPath(nds[0].children[0].Load(), tc.key, c, true) assert.Equal(t, tc.want, got) @@ -4262,17 +4262,17 @@ func TestTree_Delete(t *testing.T) { cnt := len(slices.Collect(iterutil.Right(it.All()))) assert.Equal(t, 0, cnt) - assert.Equal(t, 4, len(*tree.nodes.Load())) + assert.Equal(t, 4, len(*tree.root.Load())) } func TestTree_DeleteRoot(t *testing.T) { tree := New().Tree() require.NoError(t, onlyError(tree.Handle(http.MethodOptions, "/foo/bar", emptyHandler))) require.NoError(t, tree.Delete(http.MethodOptions, "/foo/bar")) - assert.Equal(t, 4, len(*tree.nodes.Load())) + assert.Equal(t, 4, len(*tree.root.Load())) require.NoError(t, onlyError(tree.Handle(http.MethodOptions, "exemple.com/foo/bar", emptyHandler))) require.NoError(t, tree.Delete(http.MethodOptions, "exemple.com/foo/bar")) - assert.Equal(t, 4, len(*tree.nodes.Load())) + assert.Equal(t, 4, len(*tree.root.Load())) } func TestTree_DeleteWildcard(t *testing.T) { @@ -4467,7 +4467,6 @@ func TestRouterWithMethodNotAllowedHandler(t *testing.T) { } func TestRouterWithAutomaticOptions(t *testing.T) { - f := New(WithAutoOptions(true)) cases := []struct { name string @@ -4523,9 +4522,10 @@ func TestRouterWithAutomaticOptions(t *testing.T) { }, } - require.True(t, f.AutoOptionsEnabled()) for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + f := New(WithAutoOptions(true)) + require.True(t, f.AutoOptionsEnabled()) for _, method := range tc.methods { require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) { c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", ")) @@ -4537,15 +4537,11 @@ func TestRouterWithAutomaticOptions(t *testing.T) { f.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) assert.Equal(t, tc.want, w.Header().Get("Allow")) - // Reset - f.Swap(f.NewTree()) }) } } func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { - f := New(WithAutoOptions(true), WithIgnoreTrailingSlash(true)) - cases := []struct { name string target string @@ -4602,6 +4598,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + f := New(WithAutoOptions(true), WithIgnoreTrailingSlash(true)) for _, method := range tc.methods { require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) { c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", ")) @@ -4613,15 +4610,11 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { f.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) assert.Equal(t, tc.want, w.Header().Get("Allow")) - // Reset - f.Swap(f.NewTree()) }) } } func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { - f := New(WithAutoOptions(true)) - cases := []struct { name string target string @@ -4647,6 +4640,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + f := New(WithAutoOptions(true)) for _, method := range tc.methods { require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) { c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", ")) @@ -4657,8 +4651,6 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { w := httptest.NewRecorder() f.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) - // Reset - f.Swap(f.NewTree()) }) } } @@ -5094,20 +5086,6 @@ func TestEncodedPath(t *testing.T) { assert.Equal(t, encodedPath, w.Body.String()) } -func TestTreeSwap(t *testing.T) { - f := New() - tree := f.NewTree() - assert.NotPanics(t, func() { - f.Swap(tree) - }) - assert.Equal(t, tree, f.Tree()) - - f2 := New() - assert.Panics(t, func() { - f2.Swap(tree) - }) -} - func TestFuzzInsertLookupParam(t *testing.T) { // no '*', '{}' and '/' and invalid escape char unicodeRanges := fuzz.UnicodeRanges{ @@ -5134,7 +5112,7 @@ func TestFuzzInsertLookupParam(t *testing.T) { } path := fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3) if err := tree.insert(http.MethodGet, tree.newRoute(path, emptyHandler), 3); err == nil { - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c, false) @@ -5192,7 +5170,7 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { assert.Equal(t, len(routes), countPath) for rte := range routes { - nds := *tree.nodes.Load() + nds := *tree.root.Load() c := newTestContextTree(tree) n, tsr := tree.lookupByPath(nds[0].children[0].Load(), "/"+rte, c, true) require.NotNilf(t, n, "route /%s", rte) @@ -5266,7 +5244,7 @@ func TestRaceHostnamePathSwitch(t *testing.T) { start() wg.Wait() - nds := *tree.nodes.Load() + nds := *tree.root.Load() require.Len(t, nds[0].children, 1) } @@ -5460,7 +5438,7 @@ func TestNode_String(t *testing.T) { f := New() require.NoError(t, onlyError(f.Handle(http.MethodGet, "/foo/{bar}/*{baz}", emptyHandler))) tree := f.Tree() - nds := *tree.nodes.Load() + nds := *tree.root.Load() want := `path: GET path: /foo/{bar}/*{baz} [leaf=/foo/{bar}/*{baz}] [bar (10), baz (-1)]` @@ -5471,7 +5449,7 @@ func TestNode_Debug(t *testing.T) { f := New() require.NoError(t, onlyError(f.Handle(http.MethodGet, "/foo/*{any}/bar", emptyHandler))) tree := f.Tree() - nds := *tree.nodes.Load() + nds := *tree.root.Load() want := `path: GET path: /foo/*{any}/bar [leaf=/foo/*{any}/bar] [any (11)] diff --git a/iter.go b/iter.go index a21bae5..68b1838 100644 --- a/iter.go +++ b/iter.go @@ -64,7 +64,7 @@ type Iter struct { // This API is EXPERIMENTAL and is likely to change in future release. func (it Iter) Methods() iter.Seq[string] { return func(yield func(string) bool) { - nds := *it.t.nodes.Load() + nds := *it.t.root.Load() for i := range nds { if len(nds[i].children) > 0 { if !yield(nds[i].key) { @@ -85,7 +85,7 @@ func (it Iter) Methods() iter.Seq[string] { // This API is EXPERIMENTAL and is likely to change in future release. func (it Iter) Routes(methods iter.Seq[string], pattern string) iter.Seq2[string, *Route] { return func(yield func(string, *Route) bool) { - nds := *it.t.nodes.Load() + nds := *it.t.root.Load() c := it.t.ctx.Get().(*cTx) defer c.Close() for method := range methods { @@ -119,7 +119,7 @@ func (it Iter) Routes(methods iter.Seq[string], pattern string) iter.Seq2[string // This API is EXPERIMENTAL and may change in future releases. func (it Iter) Reverse(methods iter.Seq[string], host, path string) iter.Seq2[string, *Route] { return func(yield func(string, *Route) bool) { - nds := *it.t.nodes.Load() + nds := *it.t.root.Load() c := it.t.ctx.Get().(*cTx) defer c.Close() for method := range methods { @@ -149,7 +149,7 @@ func (it Iter) Reverse(methods iter.Seq[string], host, path string) iter.Seq2[st // This API is EXPERIMENTAL and may change in future releases. func (it Iter) Prefix(methods iter.Seq[string], prefix string) iter.Seq2[string, *Route] { return func(yield func(string, *Route) bool) { - nds := *it.t.nodes.Load() + nds := *it.t.root.Load() maxDepth := it.t.maxDepth.Load() var stacks []stack if maxDepth < stackSizeThreshold { diff --git a/tree.go b/tree.go index 886fac0..eab0099 100644 --- a/tree.go +++ b/tree.go @@ -17,22 +17,9 @@ import ( // Tree implements a Concurrent Radix Tree that supports lock-free reads while allowing concurrent writes. // The caller is responsible for ensuring that all writes are run serially. -// -// IMPORTANT: -// Each tree as its own [sync.Mutex] that may be used to serialize write. Since the router tree may be swapped at any -// given time (see [Router.Swap]), you MUST always copy the pointer locally to avoid inadvertently causing a deadlock -// by locking/unlocking the wrong Tree. -// -// - Good -// t := fox.Tree() -// t.Lock() -// defer t.Unlock() -// - Dramatically bad, may cause deadlock -// fox.Tree().Lock() -// defer fox.Tree().Unlock() type Tree struct { ctx sync.Pool - nodes atomic.Pointer[[]*node] + root atomic.Pointer[[]*node] fox *Router writable *simplelru.LRU[*node, any] sync.Mutex @@ -136,7 +123,7 @@ func (t *Tree) Has(method, pattern string) bool { // mutation on [Tree] are ongoing. See also [Tree.Has] as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Route(method, pattern string) *Route { - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 || len(nds[index].children) == 0 { return nil @@ -160,7 +147,7 @@ func (t *Tree) Route(method, pattern string) *Route { // mutation on [Tree] are ongoing. See also [Tree.Lookup] as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Reverse(method, host, path string) (route *Route, tsr bool) { - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 || len(nds[index].children) == 0 { return nil, false @@ -183,7 +170,7 @@ func (t *Tree) Reverse(method, host, path string) (route *Route, tsr bool) { // concurrent use by multiple goroutine and while mutation on [Tree] are ongoing. See also [Tree.Reverse] as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (route *Route, cc ContextCloser, tsr bool) { - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(r.Method, nds) if index < 0 || len(nds[index].children) == 0 { return @@ -233,7 +220,7 @@ func (t *Tree) insert(method string, route *Route, paramsN uint32) error { } var rootNode *node - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 { rootNode = &node{ @@ -490,7 +477,7 @@ func (t *Tree) update(method string, route *Route) error { path := route.pattern - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 { return fmt.Errorf("%w: route %s %s is not registered", ErrRouteNotFound, method, path) @@ -537,7 +524,7 @@ func (t *Tree) remove(method, path string) bool { t.writable = lru } - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 { return false @@ -1036,7 +1023,7 @@ Walk: // We can record params here because it may be either an ending catch-all node (leaf=/foo/*{args}) with // children, or we may have a tsr opportunity (leaf=/foo/*{args}/ with /foo/x/y/z path). Note that if - // there is no tsr opportunity, and skipped nodes > 0, we will truncate the params anyway. + // there is no tsr opportunity, and skipped nodes > 0, we will truncateRoot the params anyway. if !lazy { *c.params = append(*c.params, Param{Key: current.params[paramKeyCnt].key, Value: path[startPath:]}) } @@ -1294,7 +1281,7 @@ func (t *Tree) allocateContext() *cTx { func (t *Tree) snapshot() *Tree { tree := new(Tree) tree.fox = t.fox - tree.nodes.Store(t.nodes.Load()) + tree.root.Store(t.root.Load()) tree.ctx = sync.Pool{ New: func() any { return tree.allocateContext() @@ -1308,11 +1295,11 @@ func (t *Tree) snapshot() *Tree { // addRoot append a new root node to the tree. // Note: This function should be guarded by mutex. func (t *Tree) addRoot(n *node) { - nds := *t.nodes.Load() + nds := *t.root.Load() newNds := make([]*node, 0, len(nds)+1) newNds = append(newNds, nds...) newNds = append(newNds, n) - t.nodes.Store(&newNds) + t.root.Store(&newNds) } // updateRoot replaces a root node in the tree. @@ -1322,7 +1309,7 @@ func (t *Tree) addRoot(n *node) { // updated root node, and the entire list is swapped afterwards. // Note: This function should be guarded by mutex. func (t *Tree) updateRoot(n *node) bool { - nds := *t.nodes.Load() + nds := *t.root.Load() // for root node, the key contains the HTTP verb. index := findRootNode(n.key, nds) if index < 0 { @@ -1332,14 +1319,14 @@ func (t *Tree) updateRoot(n *node) bool { newNds = append(newNds, nds[:index]...) newNds = append(newNds, n) newNds = append(newNds, nds[index+1:]...) - t.nodes.Store(&newNds) + t.root.Store(&newNds) return true } // removeRoot remove a root nod from the tree. // Note: This function should be guarded by mutex. func (t *Tree) removeRoot(method string) bool { - nds := *t.nodes.Load() + nds := *t.root.Load() index := findRootNode(method, nds) if index < 0 { return false @@ -1347,10 +1334,43 @@ func (t *Tree) removeRoot(method string) bool { newNds := make([]*node, 0, len(nds)-1) newNds = append(newNds, nds[:index]...) newNds = append(newNds, nds[index+1:]...) - t.nodes.Store(&newNds) + t.root.Store(&newNds) return true } +// truncateRoot truncate the tree for the provided methods. +// Note: This function should be guarded by mutex. +func (t *Tree) truncateRoot(methods []string) { + nds := *t.root.Load() + oldlen := len(nds) + newNds := make([]*node, len(nds)) + copy(newNds, nds) + + for _, method := range methods { + idx := findRootNode(method, newNds) + if idx < 0 { + continue + } + if !isRemovable(method) { + newNds[idx] = new(node) + newNds[idx].key = commonVerbs[idx] + newNds[idx].paramChildIndex = -1 + newNds[idx].wildcardChildIndex = -1 + continue + } + + newNds = append(newNds[:idx], newNds[idx+1:]...) + } + + clear(newNds[len(newNds):oldlen]) // zero/nil out the obsolete elements, for GC + + // Update the tree's nodes with the new slice. + t.root.Store(&newNds) +} + +// updateToRoot propagate update to the root by cloning any visited node that have not been cloned previously. +// This effectively allow to create a fully isolated snapshot of the tree. +// Note: This function should be guarded by mutex. func (t *Tree) updateToRoot(p, pp, ppp *node, visited []*node, n *node) { nn := n if p != nil { diff --git a/txn.go b/txn.go index 193b0cf..acbe9f7 100644 --- a/txn.go +++ b/txn.go @@ -7,6 +7,7 @@ import ( const defaultModifiedCache = 8192 +// Txn is a read-write transaction against the [Router]. type Txn struct { snap *Tree main *Tree @@ -41,6 +42,10 @@ func (txn *Txn) Lookup(w ResponseWriter, r *http.Request) (route *Route, cc Cont return txn.snap.Lookup(w, r) } +func (txn *Txn) Truncate(methods ...string) { + txn.snap.truncateRoot(methods) +} + func (txn *Txn) Iter() Iter { return Iter{t: txn.snap} } @@ -51,7 +56,7 @@ func (txn *Txn) Commit() { txn.snap.writable = nil txn.main.maxParams.Store(txn.snap.maxParams.Load()) txn.main.maxDepth.Store(txn.snap.maxDepth.Load()) - txn.main.nodes.Store(txn.snap.nodes.Load()) + txn.main.root.Store(txn.snap.root.Load()) txn.main.Unlock() }) } diff --git a/txn_test.go b/txn_test.go index 622782e..bb9a51f 100644 --- a/txn_test.go +++ b/txn_test.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/stretchr/testify/require" "net/http" + "slices" "testing" ) @@ -30,12 +31,33 @@ func TestXyz(t *testing.T) { require.NoError(t, txn.Delete(http.MethodGet, "a.b.c/a")) fmt.Println(txn.Has(http.MethodGet, "/a/b/c/d")) - fmt.Println("current", (*tree.nodes.Load())[0]) + fmt.Println("current", (*tree.root.Load())[0]) - fmt.Println("isolated", (*txn.snap.nodes.Load())[0]) + fmt.Println("isolated", (*txn.snap.root.Load())[0]) txn.Commit() - fmt.Println("committed", (*tree.nodes.Load())[0]) + fmt.Println("committed", (*tree.root.Load())[0]) + +} + +func TestZ(t *testing.T) { + f := New() + f.MustHandle(http.MethodGet, "a.b.c/a", emptyHandler) + tree := f.Tree() + txn := f.Txn() + defer txn.Rollback() + + require.NoError(t, onlyError(txn.Handle(http.MethodGet, "/a/b", emptyHandler))) + require.NoError(t, onlyError(txn.Handle(http.MethodPost, "/a/b/c", emptyHandler))) + require.NoError(t, onlyError(txn.Handle(http.MethodConnect, "/a/b/c/d", emptyHandler))) + require.NoError(t, onlyError(txn.Handle(http.MethodOptions, "/a/b/c/d/e", emptyHandler))) + require.NoError(t, onlyError(txn.Handle(http.MethodTrace, "/a/b/c/d/e/f", emptyHandler))) + txn.Truncate(slices.Collect(txn.Iter().Methods())...) + for _, n := range *txn.snap.root.Load() { + fmt.Println("isolated", n) + } + + fmt.Println("current", (*tree.root.Load())[0]) } @@ -59,29 +81,3 @@ func BenchmarkTx(b *testing.B) { txn.Commit() } } - -func BenchmarkNonTx(b *testing.B) { - f := New() - - for _, route := range staticRoutes { - f.MustHandle(http.MethodGet, route.path, emptyHandler) - } - - b.ResetTimer() - b.ReportAllocs() - for range b.N { - old := f.Tree() - new := f.NewTree() - for method, route := range old.Iter().All() { - new.Handle(method, route.Pattern(), emptyHandler) - } - - new.Delete(http.MethodGet, "/go1compat.html") - new.Delete(http.MethodGet, "/articles/wiki/part1-noerror.go") - new.Delete(http.MethodGet, "/gopher/gophercolor16x16.png") - new.Handle(http.MethodGet, "/go1compat.html", emptyHandler) - new.Handle(http.MethodGet, "/articles/wiki/part1-noerror.go", emptyHandler) - new.Handle(http.MethodGet, "/gopher/gophercolor16x16.png", emptyHandler) - f.Swap(new) - } -}