Skip to content

Commit

Permalink
feat(txn): improve Context api
Browse files Browse the repository at this point in the history
  • Loading branch information
tigerwill90 committed Nov 17, 2024
1 parent 4fb8fac commit 3c29ca8
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 79 deletions.
31 changes: 17 additions & 14 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ type Context interface {
Params() iter.Seq[Param]
// Param retrieve a matching wildcard parameter by name.
Param(name string) string
// Path returns the request URL path.
Path() string
// Host returns the request host.
Host() string
// QueryParams parses the [http.Request] raw query and returns the corresponding values.
QueryParams() url.Values
// QueryParam returns the first query value associated with the given key.
Expand All @@ -85,8 +89,6 @@ type Context interface {
// Scope returns the [HandlerScope] associated with the current [Context].
// This indicates the scope in which the handler is being executed, such as [RouteHandler], [NoRouteHandler], etc.
Scope() HandlerScope
// Tree is a local copy of the [Tree] in use to serve the request.
Tree() *Tree
// Fox returns the [Router] instance.
Fox() *Router
}
Expand All @@ -99,9 +101,6 @@ type cTx struct {
tsrParams *Params
skipNds *skippedNodes
route *Route

// tree at allocation (read-only, no reset)
tree *Tree
// router at allocation (read-only, no reset)
fox *Router
cachedQuery url.Values
Expand Down Expand Up @@ -227,6 +226,16 @@ func (c *cTx) Param(name string) string {
return ""
}

// Path returns the request URL path.
func (c *cTx) Path() string {
return c.req.URL.Path
}

// Host returns the request host.
func (c *cTx) Host() string {
return c.req.Host
}

// QueryParams parses the [http.Request] raw query and returns the corresponding values.
func (c *cTx) QueryParams() url.Values {
return c.getQueries()
Expand Down Expand Up @@ -295,11 +304,6 @@ func (c *cTx) Redirect(code int, url string) error {
return nil
}

// Tree is a local copy of the [Tree] in use to serve the request.
func (c *cTx) Tree() *Tree {
return c.tree
}

// Fox returns the [Router] instance.
func (c *cTx) Fox() *Router {
return c.fox
Expand All @@ -312,7 +316,6 @@ func (c *cTx) Clone() Context {
rec: c.rec,
req: c.req.Clone(c.req.Context()),
fox: c.fox,
tree: c.tree,
route: c.route,
scope: c.scope,
tsr: c.tsr,
Expand All @@ -339,7 +342,7 @@ func (c *cTx) Clone() Context {
// be closed once no longer needed. This functionality is particularly beneficial for middlewares that need to wrap
// their custom [ResponseWriter] while preserving the state of the original [Context].
func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser {
cp := c.tree.ctx.Get().(*cTx)
cp := c.fox.tree.ctx.Get().(*cTx)
cp.req = r
cp.w = w
cp.route = c.route
Expand All @@ -366,10 +369,10 @@ func (c *cTx) Scope() HandlerScope {
func (c *cTx) Close() {
// Put back the context, if not extended more than max params or max depth, allowing
// the slice to naturally grow within the constraint.
if cap(*c.params) > int(c.tree.maxParams.Load()) || cap(*c.skipNds) > int(c.tree.maxDepth.Load()) {
if cap(*c.params) > int(c.fox.tree.maxParams.Load()) || cap(*c.skipNds) > int(c.fox.tree.maxDepth.Load()) {
return
}
c.tree.ctx.Put(c)
c.fox.tree.ctx.Put(c)
}

func (c *cTx) getQueries() url.Values {
Expand Down
39 changes: 26 additions & 13 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ func TestContext_Route(t *testing.T) {
assert.Equal(t, "/foo", w.Body.String())
}

func TestContext_Path(t *testing.T) {
t.Parallel()
f := New()
f.MustHandle(http.MethodGet, "/{a}", func(c Context) {
_, _ = io.WriteString(c.Writer(), c.Path())
})

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
f.ServeHTTP(w, r)
assert.Equal(t, "/foo", w.Body.String())
}

func TestContext_Host(t *testing.T) {
t.Parallel()
f := New()
f.MustHandle(http.MethodGet, "/{a}", func(c Context) {
_, _ = io.WriteString(c.Writer(), c.Host())
})

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
f.ServeHTTP(w, r)
assert.Equal(t, "example.com", w.Body.String())
}

func TestContext_Annotations(t *testing.T) {
t.Parallel()
f := New()
Expand Down Expand Up @@ -295,19 +321,6 @@ func TestContext_Fox(t *testing.T) {
f.ServeHTTP(w, req)
}

func TestContext_Tree(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/foo", nil)

f := New()
require.NoError(t, onlyError(f.Handle(http.MethodGet, "/foo", func(c Context) {
assert.NotNil(t, c.Tree())
})))

f.ServeHTTP(w, req)
}

func TestContext_Scope(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion fox.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Put back the context, if not extended more than max params or max depth, allowing
// the slice to naturally grow within the constraint.
if cap(*c.params) <= int(fox.tree.maxParams.Load()) && cap(*c.skipNds) <= int(fox.tree.maxDepth.Load()) {
c.tree.ctx.Put(c)
fox.tree.ctx.Put(c)
}
return
}
Expand Down
55 changes: 10 additions & 45 deletions fox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

var emptyHandler = HandlerFunc(func(c Context) {})
var pathHandler = HandlerFunc(func(c Context) { _ = c.String(200, c.Request().URL.Path) })
var pathHandler = HandlerFunc(func(c Context) { _ = c.String(200, c.Path()) })

type mockResponseWriter struct{}

Expand Down Expand Up @@ -739,7 +739,7 @@ func TestParamsRoute(t *testing.T) {
rx := regexp.MustCompile("({|\\*{)[A-z]+[}]")
r := New()
h := func(c Context) {
matches := rx.FindAllString(c.Request().URL.Path, -1)
matches := rx.FindAllString(c.Path(), -1)
for _, match := range matches {
var key string
if strings.HasPrefix(match, "*") {
Expand All @@ -750,8 +750,8 @@ func TestParamsRoute(t *testing.T) {
value := match
assert.Equal(t, value, c.Param(key))
}
assert.Equal(t, c.Request().URL.Path, c.Pattern())
_ = c.String(200, c.Request().URL.Path)
assert.Equal(t, c.Path(), c.Pattern())
_ = c.String(200, c.Path())
}
for _, route := range githubAPI {
require.NoError(t, onlyError(r.Tree().Handle(route.method, route.path, h)))
Expand All @@ -769,7 +769,7 @@ func TestParamsRouteWithDomain(t *testing.T) {
rx := regexp.MustCompile("({|\\*{)[A-z]+[}]")
r := New()
h := func(c Context) {
matches := rx.FindAllString(c.Request().URL.Path, -1)
matches := rx.FindAllString(c.Path(), -1)
for _, match := range matches {
var key string
if strings.HasPrefix(match, "*") {
Expand All @@ -781,8 +781,8 @@ func TestParamsRouteWithDomain(t *testing.T) {
assert.Equal(t, value, c.Param(key))
}

assert.Equal(t, netutil.StripHostPort(c.Request().Host)+c.Request().URL.Path, c.Pattern())
_ = c.String(200, netutil.StripHostPort(c.Request().Host)+c.Request().URL.Path)
assert.Equal(t, netutil.StripHostPort(c.Host())+c.Path(), c.Pattern())
_ = c.String(200, netutil.StripHostPort(c.Host())+c.Path())
}
for _, route := range githubAPI {
require.NoError(t, onlyError(r.Tree().Handle(route.method, "foo.{bar}.com"+route.path, h)))
Expand Down Expand Up @@ -4528,7 +4528,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) {
require.True(t, f.AutoOptionsEnabled())
for _, method := range tc.methods {
require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) {
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", "))
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Fox().Iter().Reverse(c.Fox().Iter().Methods(), c.Host(), c.Path()))), ", "))
c.Writer().WriteHeader(http.StatusNoContent)
})))
}
Expand Down Expand Up @@ -4601,7 +4601,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) {
f := New(WithAutoOptions(true), WithIgnoreTrailingSlash(true))
for _, method := range tc.methods {
require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) {
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", "))
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Fox().Iter().Reverse(c.Fox().Iter().Methods(), c.Host(), c.Path()))), ", "))
c.Writer().WriteHeader(http.StatusNoContent)
})))
}
Expand Down Expand Up @@ -4643,7 +4643,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) {
f := New(WithAutoOptions(true))
for _, method := range tc.methods {
require.NoError(t, onlyError(f.Tree().Handle(method, tc.path, func(c Context) {
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().Host, c.Request().URL.Path))), ", "))
c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Fox().Iter().Reverse(c.Fox().Iter().Methods(), c.Host(), c.Path()))), ", "))
c.Writer().WriteHeader(http.StatusNoContent)
})))
}
Expand Down Expand Up @@ -5510,41 +5510,6 @@ func ExampleWithMiddleware() {
})
}

// This example demonstrates some important considerations when using the Tree API.
func ExampleRouter_Tree() {
r := New()

// Each tree as its own sync.Mutex that is used to lock write on the tree. Since the router tree may be swapped at
// any given time, you MUST always copy the pointer locally, This ensures that you do not inadvertently cause a
// deadlock by locking/unlocking the wrong tree.
tree := r.Tree()
upsert := func(method, path string, handler HandlerFunc) (*Route, error) {
tree.Lock()
defer tree.Unlock()
if tree.Has(method, path) {
return tree.Update(method, path, handler)
}
return tree.Handle(method, path, handler)
}

_, _ = upsert(http.MethodGet, "/foo/bar", func(c Context) {
// Note the tree accessible from fox.Context is already a local copy so the golden rule above does not apply.
c.Tree().Lock()
defer c.Tree().Unlock()
_ = c.String(200, "foo bar")
})

// Bad, instead make a local copy of the tree!
_ = func(method, path string, handler HandlerFunc) (*Route, error) {
r.Tree().Lock()
defer r.Tree().Unlock()
if r.Tree().Has(method, path) {
return r.Tree().Update(method, path, handler)
}
return r.Tree().Handle(method, path, handler)
}
}

// This example demonstrates how to create a custom middleware that cleans the request path and performs a manual
// lookup on the tree. If the cleaned path matches a registered route, the client is redirected to the valid path.
func ExampleRouter_Lookup() {
Expand Down
2 changes: 2 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func LoggerWithHandler(handler slog.Handler) MiddlewareFunc {
ipStr,
slog.Int("status", c.Writer().Status()),
slog.String("method", req.Method),
slog.String("host", c.Host()),
slog.String("path", c.Request().URL.String()),
slog.Duration("latency", roundLatency(latency)),
)
Expand All @@ -55,6 +56,7 @@ func LoggerWithHandler(handler slog.Handler) MiddlewareFunc {
ipStr,
slog.Int("status", c.Writer().Status()),
slog.String("method", req.Method),
slog.String("host", c.Host()),
slog.String("path", c.Request().URL.String()),
slog.Duration("latency", roundLatency(latency)),
slog.String("location", location),
Expand Down
8 changes: 4 additions & 4 deletions logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ func TestLoggerWithHandler(t *testing.T) {
{
name: "should log info level",
req: httptest.NewRequest(http.MethodGet, "/success", nil),
want: "time=time level=INFO msg=192.0.2.1 status=200 method=GET path=/success latency=latency\n",
want: "time=time level=INFO msg=192.0.2.1 status=200 method=GET host=example.com path=/success latency=latency\n",
},
{
name: "should log error level",
req: httptest.NewRequest(http.MethodGet, "/failure", nil),
want: "time=time level=ERROR msg=192.0.2.1 status=500 method=GET path=/failure latency=latency\n",
want: "time=time level=ERROR msg=192.0.2.1 status=500 method=GET host=example.com path=/failure latency=latency\n",
},
{
name: "should log warn level",
req: httptest.NewRequest(http.MethodGet, "/foobar", nil),
want: "time=time level=WARN msg=192.0.2.1 status=404 method=GET path=/foobar latency=latency\n",
want: "time=time level=WARN msg=192.0.2.1 status=404 method=GET host=example.com path=/foobar latency=latency\n",
},
{
name: "should log debug level",
req: httptest.NewRequest(http.MethodGet, "/success/", nil),
want: "time=time level=DEBUG msg=192.0.2.1 status=301 method=GET path=/success/ latency=latency location=../success\n",
want: "time=time level=DEBUG msg=192.0.2.1 status=301 method=GET host=example.com path=/success/ latency=latency location=../success\n",
},
}

Expand Down
2 changes: 0 additions & 2 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,8 +1269,6 @@ func (t *Tree) allocateContext() *cTx {
skipNds: &skipNds,
tsrParams: &tsrParams,
// This is a read only value, no reset
tree: t,
// This is a read only value, no reset
fox: t.fox,
}
}
Expand Down

0 comments on commit 3c29ca8

Please sign in to comment.