Skip to content

Commit

Permalink
feat(response_writer): use unwrap to safely assert ResponseWriter cap…
Browse files Browse the repository at this point in the history
…abilities
  • Loading branch information
tigerwill90 committed Feb 15, 2024
1 parent 6b8d3ab commit 241cd1b
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 153 deletions.
16 changes: 4 additions & 12 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ type Context interface {
Tree() *Tree
// Fox returns the Router in use to serve the request.
Fox() *Router
// Reset resets the Context to its initial state, attaching the provided Router,
// http.ResponseWriter, and *http.Request.
Reset(fox *Router, w http.ResponseWriter, r *http.Request)
}

// context holds request-related information and allows interaction with the ResponseWriter.
Expand All @@ -101,17 +98,12 @@ type context struct {
rec recorder
}

// Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request.
// Caution: You should pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to avoid
// wrapping the ResponseWriter within itself.
func (c *context) Reset(fox *Router, w http.ResponseWriter, r *http.Request) {
// reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request.
// Caution: the only valid http.ResponseWriter this method may consume is the original own (from ServeHTTP).
func (c *context) reset(fox *Router, w http.ResponseWriter, r *http.Request) {
c.rec.reset(w)
c.req = r
if r.ProtoMajor == 2 {
c.w = h2Writer{&c.rec}
} else {
c.w = h1Writer{&c.rec}
}
c.w = &c.rec
c.fox = fox
c.path = ""
c.cachedQuery = nil
Expand Down
4 changes: 3 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ func TestContext_Writer(t *testing.T) {
assert.Equal(t, http.StatusCreated, c.Writer().Status())
assert.Equal(t, buf, w.Body.Bytes())
assert.Equal(t, len(buf), c.Writer().Size())
assert.Equal(t, w, c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap())
rw := c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap()
assert.Implements(t, (*http.Flusher)(nil), rw)
assert.IsType(t, flushWriter{}, rw)
assert.True(t, c.Writer().Written())
}

Expand Down
46 changes: 30 additions & 16 deletions fox.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func New(opts ...Option) *Router {

r.noRoute = applyMiddleware(NoRouteHandler, r.mws, r.noRoute)
r.noMethod = applyMiddleware(NoMethodHandler, r.mws, r.noMethod)
r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlash)
r.tsrRedirect = applyMiddleware(RedirectHandler, r.mws, defaultRedirectTrailingSlashHandler)
r.autoOptions = applyMiddleware(OptionsHandler, r.mws, r.autoOptions)

r.tree.Store(r.NewTree())
Expand Down Expand Up @@ -159,24 +159,38 @@ func (fox *Router) Remove(method, path string) error {
return t.Remove(method, path)
}

// Reverse perform a lookup on the tree for the given method and path and return the matching registered route if any.
// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing.
// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser,
// and a boolean indicating if a trailing slash redirect is recommended. The ContextCloser should always be closed if non-nil.
// This method is primarily intended for integrating the fox router into custom routing solutions. It requires the use of the
// original http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine
// and while mutation on Tree are ongoing.
// This API is EXPERIMENTAL and is likely to change in future release.
func Reverse(t *Tree, method, path string) string {
nds := *t.nodes.Load()
index := findRootNode(method, nds)
func (fox *Router) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) {
tree := fox.tree.Load()

nds := *tree.nodes.Load()
index := findRootNode(r.Method, nds)

if index < 0 {
return ""
return

Check warning on line 175 in fox.go

View check run for this annotation

Codecov / codecov/patch

fox.go#L175

Added line #L175 was not covered by tests
}

c := t.ctx.Get().(*context)
c.resetNil()
n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true)
c.Close()
if n == nil {
return ""
c := tree.ctx.Get().(*context)
c.reset(fox, w, r)

target := r.URL.Path
if len(r.URL.RawPath) > 0 {
// Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1)
target = r.URL.RawPath
}
return n.path

n, tsr := tree.lookup(nds[index], target, c.params, c.skipNds, false)
if n != nil {
c.path = n.path
return n.handler, c, tsr
}
c.Close()
return nil, nil, tsr

Check warning on line 193 in fox.go

View check run for this annotation

Codecov / codecov/patch

fox.go#L192-L193

Added lines #L192 - L193 were not covered by tests
}

// SkipMethod is used as a return value from WalkFunc to indicate that
Expand Down Expand Up @@ -226,7 +240,7 @@ func DefaultMethodNotAllowedHandler() HandlerFunc {
}
}

func defaultRedirectTrailingSlash(c Context) {
func defaultRedirectTrailingSlashHandler(c Context) {
req := c.Request()

code := http.StatusMovedPermanently
Expand Down Expand Up @@ -275,7 +289,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {

tree := fox.tree.Load()
c := tree.ctx.Get().(*context)
c.Reset(fox, w, r)
c.reset(fox, w, r)

nds := *tree.nodes.Load()
index := findRootNode(r.Method, nds)
Expand Down
59 changes: 22 additions & 37 deletions fox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"net/http/httptest"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -1745,16 +1744,6 @@ func TestTree_RemoveRoot(t *testing.T) {
assert.Equal(t, 4, len(*tree.nodes.Load()))
}

func TestTree_Methods(t *testing.T) {
tree := New().Tree()
methods := []string{"GET", "POST", "PATCH"}
for _, m := range methods {
require.NoError(t, tree.Handle(m, "/foo/bar", emptyHandler))
}
sort.Strings(methods)
assert.Equal(t, methods, tree.Methods())
}

func TestRouterWithAllowedMethod(t *testing.T) {
r := New(WithNoMethod(true))

Expand Down Expand Up @@ -1862,7 +1851,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) {
target: "/foo",
path: "/foo",
methods: []string{"GET", "TRACE", "PUT", "OPTIONS"},
want: "GET, PUT, TRACE, OPTIONS",
want: "GET, OPTIONS, PUT, TRACE",
wantCode: http.StatusNoContent,
},
{
Expand All @@ -1878,7 +1867,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
for _, method := range tc.methods {
require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) {
c.SetHeader("Allow", strings.Join(c.Tree().LookupMethods(c.Request().URL.Path), ", "))
c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", "))
c.Writer().WriteHeader(http.StatusNoContent)
}))
}
Expand All @@ -1887,13 +1876,25 @@ func TestRouterWithAutomaticOptions(t *testing.T) {
f.ServeHTTP(w, req)
assert.Equal(t, tc.wantCode, w.Code)
assert.Equal(t, tc.want, w.Header().Get("Allow"))

// Reset
f.Swap(f.NewTree())
})
}
}

func TestTree_Methods(t *testing.T) {
f := New()
for _, rte := range githubAPI {
require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler))
}

methods := f.Tree().Methods("/gists/123/star")
assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods)

methods = f.Tree().Methods("*")
assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods)
}

func TestRouterWithOptionsHandler(t *testing.T) {
f := New(WithOptionsHandler(func(c Context) {
assert.Equal(t, "/foo/bar", c.Path())
Expand Down Expand Up @@ -1966,9 +1967,9 @@ func TestTree_Lookup(t *testing.T) {
require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler))
}

tree := f.Tree()
for _, rte := range githubAPI {
handler, cc, _ := tree.Lookup(rte.method, rte.path, false)
req := httptest.NewRequest(rte.method, rte.path, nil)
handler, cc, _ := f.Lookup(mockResponseWriter{}, req)
require.NotNil(t, cc)
assert.NotNil(t, handler)

Expand All @@ -1988,16 +1989,6 @@ func TestTree_Lookup(t *testing.T) {
}
}

func TestTree_LookupMethods(t *testing.T) {
f := New()
for _, rte := range githubAPI {
require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler))
}

methods := f.Tree().LookupMethods("/gists/123/star")
assert.Equal(t, []string{"GET", "PUT", "DELETE"}, methods)
}

func TestHas(t *testing.T) {
routes := []string{
"/foo/bar",
Expand Down Expand Up @@ -2052,7 +2043,7 @@ func TestHas(t *testing.T) {
}
}

func TestReverse(t *testing.T) {
func TestMatch(t *testing.T) {
routes := []string{
"/foo/bar",
"/welcome/{name}",
Expand Down Expand Up @@ -2092,7 +2083,7 @@ func TestReverse(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, Reverse(r.Tree(), http.MethodGet, tc.path))
assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path))
})
}
}
Expand Down Expand Up @@ -2446,20 +2437,14 @@ func ExampleRouter_Tree() {
// This example demonstrates how to create a custom middleware that cleans the request path and performs a manual
// lookup on the tree. If the cleaned path matches a registered route, the client is redirected with a 301 status
// code (Moved Permanently).
func ExampleTree_Lookup() {
func ExampleTree_Match() {
redirectFixedPath := MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
return func(c Context) {
req := c.Request()

cleanedPath := CleanPath(req.URL.Path)
handler, cc, _ := c.Tree().Lookup(req.Method, cleanedPath, true)
// You should always close a non-nil Context.
if cc != nil {
defer cc.Close()
}

// 301 redirect and returns.
if handler != nil {
if match := c.Tree().Match(req.Method, cleanedPath); match != "" {
// 301 redirect and returns.
req.URL.Path = cleanedPath
http.Redirect(c.Writer(), req, req.URL.String(), http.StatusMovedPermanently)
return
Expand Down
2 changes: 1 addition & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *co
c.fox = fox
c.req = r
c.rec.reset(w)
c.w = flushWriter{&c.rec}
c.w = &c.rec
return c
}

Expand Down
12 changes: 11 additions & 1 deletion helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ func TestNewTestContext(t *testing.T) {
w := httptest.NewRecorder()
_, c := NewTestContext(w, req)

flusher, ok := c.Writer().(http.Flusher)
require.Implements(t, (*interface{ Unwrap() http.ResponseWriter })(nil), c.Writer())
rw := c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap()
assert.IsType(t, flushWriter{}, rw)

flusher, ok := rw.(http.Flusher)
require.True(t, ok)

n, err := c.Writer().Write([]byte("foo"))
Expand All @@ -36,6 +40,12 @@ func TestNewTestContext(t *testing.T) {
_, ok = c.Writer().(http.Hijacker)
assert.False(t, ok)

_, ok = rw.(http.Hijacker)
assert.False(t, ok)

_, ok = c.Writer().(io.ReaderFrom)
assert.False(t, ok)

_, ok = rw.(io.ReaderFrom)
assert.False(t, ok)
}
Loading

0 comments on commit 241cd1b

Please sign in to comment.