diff --git a/atreugo.go b/atreugo.go index 4c24e04..6e61009 100644 --- a/atreugo.go +++ b/atreugo.go @@ -195,6 +195,8 @@ func (s *Atreugo) HandleOPTIONS(v bool) { // // ServeConn closes c before returning. func (s *Atreugo) ServeConn(c net.Conn) error { + s.init() + return s.server.ServeConn(c) } diff --git a/router.go b/router.go index 2ef9447..290bb84 100644 --- a/router.go +++ b/router.go @@ -110,6 +110,8 @@ func (r *Router) init() { optionsURLsHandled = append(optionsURLsHandled, p.url) } } + + r.paths = nil } func (r *Router) buildMiddlewaresChain(skip ...Middleware) Middlewares { diff --git a/router_test.go b/router_test.go index c7cc115..6a6a607 100644 --- a/router_test.go +++ b/router_test.go @@ -193,8 +193,27 @@ func TestRouter_init(t *testing.T) { // nolint:funlen r := newRouter(testLog, nil) r.appendPath(path) r.appendPath(pathOptions) + + totalRegisteredViews := len(r.paths) + 1 // Add +1 for auto OPTIONS handle + r.init() + // Check if a re-execution raise a panic + func() { + defer func() { + err := recover() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }() + + r.init() + }() + + if len(registeredViews) != totalRegisteredViews { + t.Fatalf("Registered views == %d, want %d", len(registeredViews), totalRegisteredViews) + } + ctx := new(fasthttp.RequestCtx) h, _ := r.router.Lookup(path.method, path.url, ctx) @@ -218,11 +237,6 @@ func TestRouter_init(t *testing.T) { // nolint:funlen } } - totalRegisteredViews := len(r.paths) + 1 // Add +1 for auto OPTIONS handle - if len(registeredViews) != totalRegisteredViews { - t.Fatalf("Registered views == %d, want %d", len(registeredViews), totalRegisteredViews) - } - if reflect.ValueOf(registeredViews[0]).Pointer() != reflect.ValueOf(path.view).Pointer() { t.Errorf("Registered view == %p, want %p", registeredViews[0], path.view) } @@ -889,15 +903,15 @@ func TestRouter_Path_Shortcuts(t *testing.T) { //nolint:funlen } for _, test := range tests { - tt := test + test.args.fn(path, viewFn) + } - t.Run(tt.name, func(t *testing.T) { - r.paths = r.paths[:0] - r.router = fastrouter.New() + r.init() - tt.args.fn(path, viewFn) - r.init() + for _, test := range tests { + tt := test + t.Run(tt.name, func(t *testing.T) { reqMethod := tt.args.method for reqMethod == fastrouter.MethodWild { reqMethod = randomHTTPMethod()