From 241cd1b4e412dfb51ef5aed07b0edea9d99c9b1a Mon Sep 17 00:00:00 2001 From: tigerwill90 Date: Thu, 15 Feb 2024 09:43:57 +0100 Subject: [PATCH] feat(response_writer): use unwrap to safely assert ResponseWriter capabilities --- context.go | 16 ++--- context_test.go | 4 +- fox.go | 46 ++++++++----- fox_test.go | 59 +++++++---------- helpers.go | 2 +- helpers_test.go | 12 +++- response_writer.go | 160 ++++++++++++++++++++++++++++++++++++--------- tree.go | 102 ++++++++++++++--------------- 8 files changed, 248 insertions(+), 153 deletions(-) diff --git a/context.go b/context.go index b3d16c0..9e2b18c 100644 --- a/context.go +++ b/context.go @@ -80,9 +80,6 @@ type Context interface { Tree() *Tree // Fox returns the Router in use to serve the request. Fox() *Router - // Reset resets the Context to its initial state, attaching the provided Router, - // http.ResponseWriter, and *http.Request. - Reset(fox *Router, w http.ResponseWriter, r *http.Request) } // context holds request-related information and allows interaction with the ResponseWriter. @@ -101,17 +98,12 @@ type context struct { rec recorder } -// Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. -// Caution: You should pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to avoid -// wrapping the ResponseWriter within itself. -func (c *context) Reset(fox *Router, w http.ResponseWriter, r *http.Request) { +// reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. +// Caution: the only valid http.ResponseWriter this method may consume is the original own (from ServeHTTP). +func (c *context) reset(fox *Router, w http.ResponseWriter, r *http.Request) { c.rec.reset(w) c.req = r - if r.ProtoMajor == 2 { - c.w = h2Writer{&c.rec} - } else { - c.w = h1Writer{&c.rec} - } + c.w = &c.rec c.fox = fox c.path = "" c.cachedQuery = nil diff --git a/context_test.go b/context_test.go index 8975dd9..796d643 100644 --- a/context_test.go +++ b/context_test.go @@ -169,7 +169,9 @@ func TestContext_Writer(t *testing.T) { assert.Equal(t, http.StatusCreated, c.Writer().Status()) assert.Equal(t, buf, w.Body.Bytes()) assert.Equal(t, len(buf), c.Writer().Size()) - assert.Equal(t, w, c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap()) + rw := c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap() + assert.Implements(t, (*http.Flusher)(nil), rw) + assert.IsType(t, flushWriter{}, rw) assert.True(t, c.Writer().Written()) } diff --git a/fox.go b/fox.go index 6ad2f05..78dd498 100644 --- a/fox.go +++ b/fox.go @@ -71,7 +71,7 @@ func New(opts ...Option) *Router { r.noRoute = applyMiddleware(NoRouteHandler, r.mws, r.noRoute) r.noMethod = applyMiddleware(NoMethodHandler, r.mws, r.noMethod) - r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlash) + r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlashHandler) r.autoOptions = applyMiddleware(OptionsHandler, r.mws, r.autoOptions) r.tree.Store(r.NewTree()) @@ -159,24 +159,38 @@ func (fox *Router) Remove(method, path string) error { return t.Remove(method, path) } -// Reverse perform a lookup on the tree for the given method and path and return the matching registered route if any. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser, +// and a boolean indicating if a trailing slash redirect is recommended. The ContextCloser should always be closed if non-nil. +// This method is primarily intended for integrating the fox router into custom routing solutions. It requires the use of the +// original http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine +// and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func Reverse(t *Tree, method, path string) string { - nds := *t.nodes.Load() - index := findRootNode(method, nds) +func (fox *Router) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { + tree := fox.tree.Load() + + nds := *tree.nodes.Load() + index := findRootNode(r.Method, nds) + if index < 0 { - return "" + return } - c := t.ctx.Get().(*context) - c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) - c.Close() - if n == nil { - return "" + c := tree.ctx.Get().(*context) + c.reset(fox, w, r) + + target := r.URL.Path + if len(r.URL.RawPath) > 0 { + // Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1) + target = r.URL.RawPath } - return n.path + + n, tsr := tree.lookup(nds[index], target, c.params, c.skipNds, false) + if n != nil { + c.path = n.path + return n.handler, c, tsr + } + c.Close() + return nil, nil, tsr } // SkipMethod is used as a return value from WalkFunc to indicate that @@ -226,7 +240,7 @@ func DefaultMethodNotAllowedHandler() HandlerFunc { } } -func defaultRedirectTrailingSlash(c Context) { +func defaultRedirectTrailingSlashHandler(c Context) { req := c.Request() code := http.StatusMovedPermanently @@ -275,7 +289,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { tree := fox.tree.Load() c := tree.ctx.Get().(*context) - c.Reset(fox, w, r) + c.reset(fox, w, r) nds := *tree.nodes.Load() index := findRootNode(r.Method, nds) diff --git a/fox_test.go b/fox_test.go index 41f7b14..4afbbcb 100644 --- a/fox_test.go +++ b/fox_test.go @@ -12,7 +12,6 @@ import ( "net/http/httptest" "reflect" "regexp" - "sort" "strings" "sync" "sync/atomic" @@ -1745,16 +1744,6 @@ func TestTree_RemoveRoot(t *testing.T) { assert.Equal(t, 4, len(*tree.nodes.Load())) } -func TestTree_Methods(t *testing.T) { - tree := New().Tree() - methods := []string{"GET", "POST", "PATCH"} - for _, m := range methods { - require.NoError(t, tree.Handle(m, "/foo/bar", emptyHandler)) - } - sort.Strings(methods) - assert.Equal(t, methods, tree.Methods()) -} - func TestRouterWithAllowedMethod(t *testing.T) { r := New(WithNoMethod(true)) @@ -1862,7 +1851,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) { target: "/foo", path: "/foo", methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, - want: "GET, PUT, TRACE, OPTIONS", + want: "GET, OPTIONS, PUT, TRACE", wantCode: http.StatusNoContent, }, { @@ -1878,7 +1867,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { - c.SetHeader("Allow", strings.Join(c.Tree().LookupMethods(c.Request().URL.Path), ", ")) + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) c.Writer().WriteHeader(http.StatusNoContent) })) } @@ -1887,13 +1876,25 @@ func TestRouterWithAutomaticOptions(t *testing.T) { f.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) assert.Equal(t, tc.want, w.Header().Get("Allow")) - // Reset f.Swap(f.NewTree()) }) } } +func TestTree_Methods(t *testing.T) { + f := New() + for _, rte := range githubAPI { + require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) + } + + methods := f.Tree().Methods("/gists/123/star") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("*") + assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods) +} + func TestRouterWithOptionsHandler(t *testing.T) { f := New(WithOptionsHandler(func(c Context) { assert.Equal(t, "/foo/bar", c.Path()) @@ -1966,9 +1967,9 @@ func TestTree_Lookup(t *testing.T) { require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) } - tree := f.Tree() for _, rte := range githubAPI { - handler, cc, _ := tree.Lookup(rte.method, rte.path, false) + req := httptest.NewRequest(rte.method, rte.path, nil) + handler, cc, _ := f.Lookup(mockResponseWriter{}, req) require.NotNil(t, cc) assert.NotNil(t, handler) @@ -1988,16 +1989,6 @@ func TestTree_Lookup(t *testing.T) { } } -func TestTree_LookupMethods(t *testing.T) { - f := New() - for _, rte := range githubAPI { - require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) - } - - methods := f.Tree().LookupMethods("/gists/123/star") - assert.Equal(t, []string{"GET", "PUT", "DELETE"}, methods) -} - func TestHas(t *testing.T) { routes := []string{ "/foo/bar", @@ -2052,7 +2043,7 @@ func TestHas(t *testing.T) { } } -func TestReverse(t *testing.T) { +func TestMatch(t *testing.T) { routes := []string{ "/foo/bar", "/welcome/{name}", @@ -2092,7 +2083,7 @@ func TestReverse(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, Reverse(r.Tree(), http.MethodGet, tc.path)) + assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) }) } } @@ -2446,20 +2437,14 @@ func ExampleRouter_Tree() { // This example demonstrates how to create a custom middleware that cleans the request path and performs a manual // lookup on the tree. If the cleaned path matches a registered route, the client is redirected with a 301 status // code (Moved Permanently). -func ExampleTree_Lookup() { +func ExampleTree_Match() { redirectFixedPath := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return func(c Context) { req := c.Request() cleanedPath := CleanPath(req.URL.Path) - handler, cc, _ := c.Tree().Lookup(req.Method, cleanedPath, true) - // You should always close a non-nil Context. - if cc != nil { - defer cc.Close() - } - - // 301 redirect and returns. - if handler != nil { + if match := c.Tree().Match(req.Method, cleanedPath); match != "" { + // 301 redirect and returns. req.URL.Path = cleanedPath http.Redirect(c.Writer(), req, req.URL.String(), http.StatusMovedPermanently) return diff --git a/helpers.go b/helpers.go index a782bf3..47a47e7 100644 --- a/helpers.go +++ b/helpers.go @@ -27,7 +27,7 @@ func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *co c.fox = fox c.req = r c.rec.reset(w) - c.w = flushWriter{&c.rec} + c.w = &c.rec return c } diff --git a/helpers_test.go b/helpers_test.go index 21661ad..22be5ef 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -19,7 +19,11 @@ func TestNewTestContext(t *testing.T) { w := httptest.NewRecorder() _, c := NewTestContext(w, req) - flusher, ok := c.Writer().(http.Flusher) + require.Implements(t, (*interface{ Unwrap() http.ResponseWriter })(nil), c.Writer()) + rw := c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap() + assert.IsType(t, flushWriter{}, rw) + + flusher, ok := rw.(http.Flusher) require.True(t, ok) n, err := c.Writer().Write([]byte("foo")) @@ -36,6 +40,12 @@ func TestNewTestContext(t *testing.T) { _, ok = c.Writer().(http.Hijacker) assert.False(t, ok) + _, ok = rw.(http.Hijacker) + assert.False(t, ok) + _, ok = c.Writer().(io.ReaderFrom) assert.False(t, ok) + + _, ok = rw.(io.ReaderFrom) + assert.False(t, ok) } diff --git a/response_writer.go b/response_writer.go index 8d46a3b..5b36e52 100644 --- a/response_writer.go +++ b/response_writer.go @@ -23,24 +23,26 @@ import ( ) var ( - _ http.Flusher = (*h1Writer)(nil) - _ http.Hijacker = (*h1Writer)(nil) - _ io.ReaderFrom = (*h1Writer)(nil) + _ http.ResponseWriter = (*h1Writer)(nil) + _ http.Flusher = (*h1Writer)(nil) + _ http.Hijacker = (*h1Writer)(nil) + _ io.ReaderFrom = (*h1Writer)(nil) ) var ( - _ http.Pusher = (*h2Writer)(nil) - _ http.Flusher = (*h2Writer)(nil) + _ http.ResponseWriter = (*h2Writer)(nil) + _ http.Pusher = (*h2Writer)(nil) + _ http.Flusher = (*h2Writer)(nil) ) var ( - _ ResponseWriter = (*flushWriter)(nil) - _ http.Flusher = (*flushWriter)(nil) + _ http.ResponseWriter = (*flushWriter)(nil) + _ http.Flusher = (*flushWriter)(nil) ) var ( - _ ResponseWriter = (*pushWriter)(nil) - _ http.Pusher = (*pushWriter)(nil) + _ http.ResponseWriter = (*pushWriter)(nil) + _ http.Pusher = (*pushWriter)(nil) ) // ResponseWriter extends http.ResponseWriter and provides methods to retrieve the recorded status code, @@ -74,14 +76,17 @@ func (r *recorder) reset(w http.ResponseWriter) { r.status = http.StatusOK } +// Status recorded after Write and WriteHeader. func (r *recorder) Status() int { return r.status } +// Written returns true if the response has been written. func (r *recorder) Written() bool { return r.size != notWritten } +// Size returns the size of the written response. func (r *recorder) Size() int { if r.size < 0 { return 0 @@ -89,10 +94,32 @@ func (r *recorder) Size() int { return r.size } +// Unwrap returns a compliant http.ResponseWriter, which safely supports additional interfaces such as http.Flusher +// or http.Hijacker. The exact scope of supported interfaces is determined by the capabilities +// of the http.ResponseWriter provided to the ServeHTTP function, func (r *recorder) Unwrap() http.ResponseWriter { - return r.ResponseWriter + switch r.ResponseWriter.(type) { + case interface { + http.Flusher + http.Hijacker + io.ReaderFrom + }: + return h1Writer{r} + case interface { + http.Flusher + http.Pusher + }: + return h2Writer{r} + case http.Pusher: + return pushWriter{r} + case http.Flusher: + return flushWriter{r} + } + return r } +// WriteHeader sends an HTTP response header with the provided +// status code. See http.ResponseWriter for more details. func (r *recorder) WriteHeader(code int) { if r.Written() { caller := relevantCaller() @@ -105,6 +132,8 @@ func (r *recorder) WriteHeader(code int) { r.ResponseWriter.WriteHeader(code) } +// Write writes the data to the connection as part of an HTTP reply. +// See http.ResponseWriter for more details. func (r *recorder) Write(buf []byte) (n int, err error) { if !r.Written() { r.size = 0 @@ -115,6 +144,9 @@ func (r *recorder) Write(buf []byte) (n int, err error) { return } +// WriteString writes the provided string to the underlying connection +// as part of an HTTP reply. The method returns the number of bytes written +// and an error, if any. func (r *recorder) WriteString(s string) (n int, err error) { if !r.Written() { r.size = 0 @@ -127,66 +159,130 @@ func (r *recorder) WriteString(s string) (n int, err error) { } type flushWriter struct { - *recorder + r *recorder +} + +func (w flushWriter) Header() http.Header { + return w.r.Header() +} + +func (w flushWriter) Write(buf []byte) (int, error) { + return w.r.Write(buf) +} + +func (w flushWriter) WriteHeader(statusCode int) { + w.r.WriteHeader(statusCode) +} + +func (w flushWriter) WriteString(s string) (int, error) { + return w.r.WriteString(s) } func (w flushWriter) Flush() { - if !w.recorder.Written() { - w.recorder.size = 0 + if !w.r.Written() { + w.r.size = 0 } - w.recorder.ResponseWriter.(http.Flusher).Flush() + w.r.ResponseWriter.(http.Flusher).Flush() } type h1Writer struct { - *recorder + r *recorder +} + +func (w h1Writer) Header() http.Header { + return w.r.Header() +} + +func (w h1Writer) Write(buf []byte) (int, error) { + return w.r.Write(buf) +} + +func (w h1Writer) WriteHeader(statusCode int) { + w.r.WriteHeader(statusCode) +} + +func (w h1Writer) WriteString(s string) (int, error) { + return w.r.WriteString(s) } func (w h1Writer) ReadFrom(src io.Reader) (n int64, err error) { - if !w.recorder.Written() { - w.recorder.size = 0 + if !w.r.Written() { + w.r.size = 0 } - rf := w.recorder.ResponseWriter.(io.ReaderFrom) + rf := w.r.ResponseWriter.(io.ReaderFrom) n, err = rf.ReadFrom(src) - w.recorder.size += int(n) + w.r.size += int(n) return } func (w h1Writer) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if !w.recorder.Written() { - w.recorder.size = 0 + if !w.r.Written() { + w.r.size = 0 } - return w.recorder.ResponseWriter.(http.Hijacker).Hijack() + return w.r.ResponseWriter.(http.Hijacker).Hijack() } func (w h1Writer) Flush() { - if !w.recorder.Written() { - w.recorder.size = 0 + if !w.r.Written() { + w.r.size = 0 } - w.recorder.ResponseWriter.(http.Flusher).Flush() + w.r.ResponseWriter.(http.Flusher).Flush() } type h2Writer struct { - *recorder + r *recorder +} + +func (w h2Writer) Header() http.Header { + return w.r.Header() +} + +func (w h2Writer) Write(buf []byte) (int, error) { + return w.r.Write(buf) +} + +func (w h2Writer) WriteHeader(statusCode int) { + w.r.WriteHeader(statusCode) +} + +func (w h2Writer) WriteString(s string) (int, error) { + return w.r.WriteString(s) } func (w h2Writer) Push(target string, opts *http.PushOptions) error { - return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts) + return w.r.ResponseWriter.(http.Pusher).Push(target, opts) } func (w h2Writer) Flush() { - if !w.recorder.Written() { - w.recorder.size = 0 + if !w.r.Written() { + w.r.size = 0 } - w.recorder.ResponseWriter.(http.Flusher).Flush() + w.r.ResponseWriter.(http.Flusher).Flush() } type pushWriter struct { - *recorder + r *recorder +} + +func (w pushWriter) Header() http.Header { + return w.r.Header() +} + +func (w pushWriter) Write(buf []byte) (int, error) { + return w.r.Write(buf) +} + +func (w pushWriter) WriteHeader(statusCode int) { + w.r.WriteHeader(statusCode) +} + +func (w pushWriter) WriteString(s string) (int, error) { + return w.r.WriteString(s) } func (w pushWriter) Push(target string, opts *http.PushOptions) error { - return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts) + return w.r.ResponseWriter.(http.Pusher).Push(target, opts) } // noUnwrap hide the Unwrap method of the ResponseWriter. diff --git a/tree.go b/tree.go index 758a41e..746f55e 100644 --- a/tree.go +++ b/tree.go @@ -80,82 +80,78 @@ func (t *Tree) Remove(method, path string) error { return nil } -// Methods returns a sorted slice of HTTP methods that are currently in use to route requests. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. This function is safe for +// concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Methods() []string { - var methods []string +func (t *Tree) Has(method, path string) bool { nds := *t.nodes.Load() - for i := range nds { - if len(nds[i].children) > 0 { - if methods == nil { - methods = make([]string, 0) - } - methods = append(methods, nds[i].key) - } + index := findRootNode(method, nds) + if index < 0 { + return false } - sort.Strings(methods) - return methods + + c := t.ctx.Get().(*context) + c.resetNil() + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() + return n != nil && n.path == path } -// Lookup allow to do manual lookup of a route for the given method and path and return the matched HandlerFunc -// along with a ContextCloser and trailing slash redirect recommendation. If lazy is set to true, wildcard parameter are -// not parsed. You should always close the ContextCloser if NOT nil by calling cc.Close(). Note that the returned -// ContextCloser does not have a router, request and response writer attached (use the Reset method). +// Match perform a lookup on the tree for the given method and path and return the matching registered route if any. // This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Lookup(method, path string, lazy bool) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (t *Tree) Match(method, path string) string { nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return + return "" } c := t.ctx.Get().(*context) c.resetNil() - n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, lazy) - if n != nil { - c.path = n.path - return n.handler, c, tsr + n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + c.Close() + if n == nil { + return "" } - return nil, c, tsr + return n.path } -// LookupMethods lookup and returns all HTTP methods associated with a route that match the given path. This function -// is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Methods returns a sorted list of HTTP methods associated with a given path in the routing tree. If the path is "*", +// it returns all HTTP methods that have at least one route registered in the tree. For a specific path, it returns the methods +// that can route requests to that path. +// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) LookupMethods(path string) (methods []string) { +func (t *Tree) Methods(path string) []string { + var methods []string nds := *t.nodes.Load() - c := t.ctx.Get().(*context) - c.resetNil() - - methods = make([]string, 0) - for i := 0; i < len(nds); i++ { - n, _ := t.lookup(nds[i], path, c.params, c.skipNds, true) - if n != nil { - methods = append(methods, nds[i].key) + if path == "*" { + for i := range nds { + if len(nds[i].children) > 0 { + if methods == nil { + methods = make([]string, 0) + } + methods = append(methods, nds[i].key) + } } - } - c.Close() - return methods -} - -// Has allows to check if the given method and path exactly match a registered route. This function is safe for -// concurrent use by multiple goroutine and while mutation on Tree are ongoing. -// This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Has(method, path string) bool { - nds := *t.nodes.Load() - index := findRootNode(method, nds) - if index < 0 { - return false + } else { + c := t.ctx.Get().(*context) + c.resetNil() + for i := range nds { + n, _ := t.lookup(nds[i], path, c.params, c.skipNds, true) + if n != nil { + if methods == nil { + methods = make([]string, 0) + } + methods = append(methods, nds[i].key) + } + } + c.Close() } - c := t.ctx.Get().(*context) - c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) - c.Close() - return n != nil && n.path == path + sort.Strings(methods) + return methods } // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use