From 20e20bd58aefbee94f5f46c8c511d6459800ab05 Mon Sep 17 00:00:00 2001 From: tigerwill90 Date: Sat, 5 Oct 2024 23:29:56 +0200 Subject: [PATCH] feat: improve test coverage --- fox.go | 30 +++++++----------- fox_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++- response_writer.go | 7 ++-- tree.go | 20 ++++++++---- 4 files changed, 108 insertions(+), 28 deletions(-) diff --git a/fox.go b/fox.go index 5734b61..1821d8f 100644 --- a/fox.go +++ b/fox.go @@ -141,9 +141,9 @@ var _ http.Handler = (*Router)(nil) func New(opts ...GlobalOption) *Router { r := new(Router) - r.noRoute = DefaultNotFoundHandler() - r.noMethod = DefaultMethodNotAllowedHandler() - r.autoOptions = DefaultOptionsHandler() + r.noRoute = DefaultNotFoundHandler + r.noMethod = DefaultMethodNotAllowedHandler + r.autoOptions = DefaultOptionsHandler r.ipStrategy = noClientIPStrategy{} for _, opt := range opts { @@ -322,27 +322,21 @@ Next: return nil } -// DefaultNotFoundHandler returns a simple HandlerFunc that replies to each request +// DefaultNotFoundHandler is a simple HandlerFunc that replies to each request // with a “404 page not found” reply. -func DefaultNotFoundHandler() HandlerFunc { - return func(c Context) { - http.Error(c.Writer(), "404 page not found", http.StatusNotFound) - } +func DefaultNotFoundHandler(c Context) { + http.Error(c.Writer(), "404 page not found", http.StatusNotFound) } -// DefaultMethodNotAllowedHandler returns a simple HandlerFunc that replies to each request +// DefaultMethodNotAllowedHandler is a simple HandlerFunc that replies to each request // with a “405 Method Not Allowed” reply. -func DefaultMethodNotAllowedHandler() HandlerFunc { - return func(c Context) { - http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } +func DefaultMethodNotAllowedHandler(c Context) { + http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) } -// DefaultOptionsHandler returns a simple HandlerFunc that replies to each request with a "200 OK" reply. -func DefaultOptionsHandler() HandlerFunc { - return func(c Context) { - c.Writer().WriteHeader(http.StatusOK) - } +// DefaultOptionsHandler is a simple HandlerFunc that replies to each request with a "200 OK" reply. +func DefaultOptionsHandler(c Context) { + c.Writer().WriteHeader(http.StatusOK) } func defaultRedirectTrailingSlashHandler(c Context) { diff --git a/fox_test.go b/fox_test.go index d9acc1f..779613e 100644 --- a/fox_test.go +++ b/fox_test.go @@ -1767,6 +1767,9 @@ func TestRouterWithIgnoreTrailingSlash(t *testing.T) { require.NoError(t, r.Tree().Handle(tc.method, path, func(c Context) { _ = c.String(http.StatusOK, c.Path()) })) + rte := r.Tree().Route(tc.method, path) + require.NotNil(t, rte) + assert.True(t, rte.IgnoreTrailingSlashEnabled()) } req := httptest.NewRequest(tc.method, tc.req, nil) @@ -1784,7 +1787,11 @@ func TestRouterWithClientIPStrategy(t *testing.T) { f := New(WithClientIPStrategy(ClientIPStrategyFunc(func(c Context) (*net.IPAddr, error) { return c.RemoteIP(), nil }))) - require.True(t, f.ClientIPStrategyEnabled()) + f.MustHandle(http.MethodGet, "/foo", emptyHandler) + assert.True(t, f.ClientIPStrategyEnabled()) + rte := f.Tree().Route(http.MethodGet, "/foo") + require.NotNil(t, rte) + assert.True(t, rte.ClientIPStrategyEnabled()) } func TestRedirectTrailingSlash(t *testing.T) { @@ -1918,6 +1925,9 @@ func TestRedirectTrailingSlash(t *testing.T) { require.True(t, r.RedirectTrailingSlashEnabled()) for _, path := range tc.paths { require.NoError(t, r.Tree().Handle(tc.method, path, emptyHandler)) + rte := r.Tree().Route(tc.method, path) + require.NotNil(t, rte) + assert.True(t, rte.RedirectTrailingSlashEnabled()) } req := httptest.NewRequest(tc.method, tc.req, nil) @@ -2537,6 +2547,73 @@ func TestWithScopedMiddleware(t *testing.T) { assert.True(t, called) } +func TestUpdateWithMiddleware(t *testing.T) { + called := false + m := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + called = true + next(c) + } + }) + f := New() + f.MustHandle(http.MethodGet, "/foo", emptyHandler) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + w := httptest.NewRecorder() + + // Add middleware + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler, WithMiddleware(m))) + f.ServeHTTP(w, req) + assert.True(t, called) + called = false + + // Remove middleware + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler)) + f.ServeHTTP(w, req) + assert.False(t, called) +} + +func TestRouteMiddleware(t *testing.T) { + var c0, c1, c2 bool + m0 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c0 = true + next(c) + } + }) + + m1 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c1 = true + next(c) + } + }) + + m2 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c2 = true + next(c) + } + }) + f := New(WithMiddleware(m0)) + f.MustHandle(http.MethodGet, "/1", emptyHandler, WithMiddleware(m1)) + f.MustHandle(http.MethodGet, "/2", emptyHandler, WithMiddleware(m2)) + + req := httptest.NewRequest(http.MethodGet, "/1", nil) + w := httptest.NewRecorder() + + f.ServeHTTP(w, req) + assert.True(t, c0) + assert.True(t, c1) + assert.False(t, c2) + c0, c1, c2 = false, false, false + + req.URL.Path = "/2" + f.ServeHTTP(w, req) + assert.True(t, c0) + assert.False(t, c1) + assert.True(t, c2) +} + func TestWithNotFoundHandler(t *testing.T) { notFound := func(c Context) { _ = c.String(http.StatusNotFound, "NOT FOUND\n") diff --git a/response_writer.go b/response_writer.go index aa4ff9f..66841f9 100644 --- a/response_writer.go +++ b/response_writer.go @@ -162,7 +162,7 @@ func (r *recorder) FlushError() error { flusher.Flush() return nil default: - return errNotSupported() + return ErrNotSupported() } } @@ -181,7 +181,7 @@ func (r *recorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hijacker, ok := r.ResponseWriter.(http.Hijacker); ok { return hijacker.Hijack() } - return nil, nil, errNotSupported() + return nil, nil, ErrNotSupported() } type noUnwrap struct { @@ -225,6 +225,7 @@ func relevantCaller() runtime.Frame { return frame } -func errNotSupported() error { +// ErrNotSupported returns an error that Is ErrNotSupported, but is not == to it. +func ErrNotSupported() error { return fmt.Errorf("%w", http.ErrNotSupported) } diff --git a/tree.go b/tree.go index 22507b0..277ecd5 100644 --- a/tree.go +++ b/tree.go @@ -95,24 +95,32 @@ func (t *Tree) Remove(method, path string) error { return nil } -// Has allows to check if the given method and path exactly match a registered route. This function is safe for concurrent -// use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. This function is safe for +// concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Has(method, path string) bool { + return t.Route(method, path) != nil +} + +// Route performs a lookup for a registered route matching the given method and path. It returns the route if a +// match is found or nil otherwise. This function is safe for concurrent use by multiple goroutine and while +// mutation on Tree are ongoing. +// This API is EXPERIMENTAL and is likely to change in future release. +func (t *Tree) Route(method, path string) *Route { nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return false + return nil } c := t.ctx.Get().(*cTx) c.resetNil() n, tsr := t.lookup(nds[index], path, c, true) c.Close() - if n != nil && !tsr { - return n.route.path == path + if n != nil && !tsr && n.route.path == path { + return n.route } - return false + return nil } // Match perform a reverse lookup on the tree for the given method and path and return the matching registered route if any. When