Skip to content

Commit

Permalink
feat: support context params wrapping (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbenhoud authored Jun 14, 2023
1 parent 53b2a9b commit a66e6fa
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 28 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ func handle(c fox.Context) {
Fox itself implements the `http.Handler` interface which make easy to chain any compatible middleware before the router. Moreover, the router
provides convenient `fox.WrapF`, `fox.WrapH` and `fox.WrapM` adapter to be use with `http.Handler`.

The route parameters are being accessed by the wrapped handler through the `fox.Context` when the adapter `fox.WrapF` and `fox.WrapH` are used.

Wrapping an `http.Handler`
```go
articles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
20 changes: 17 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
package fox

import (
ctx "context"
netcontext "context"
"fmt"
"io"
"net/http"
Expand All @@ -25,7 +25,7 @@ type ContextCloser interface {
// (see Clone method).
type Context interface {
// Ctx returns the context associated with the current request.
Ctx() ctx.Context
Ctx() netcontext.Context
// Request returns the current *http.Request.
Request() *http.Request
// SetRequest sets the *http.Request.
Expand Down Expand Up @@ -186,7 +186,7 @@ func (c *context) TeeWriter(w io.Writer) {
}

// Ctx returns the context associated with the current request.
func (c *context) Ctx() ctx.Context {
func (c *context) Ctx() netcontext.Context {
return c.req.Context()
}

Expand Down Expand Up @@ -321,15 +321,29 @@ func (c *context) getQueries() url.Values {
}

// WrapF is an adapter for wrapping http.HandlerFunc and returns a HandlerFunc function.
// The route parameters are being accessed by the wrapped handler through the context.
func WrapF(f http.HandlerFunc) HandlerFunc {
return func(c Context) {
if len(c.Params()) > 0 {
ctx := netcontext.WithValue(c.Ctx(), paramsKey, c.Params().Clone())
f.ServeHTTP(c.Writer(), c.Request().WithContext(ctx))
return
}

f.ServeHTTP(c.Writer(), c.Request())
}
}

// WrapH is an adapter for wrapping http.Handler and returns a HandlerFunc function.
// The route parameters are being accessed by the wrapped handler through the context.
func WrapH(h http.Handler) HandlerFunc {
return func(c Context) {
if len(c.Params()) > 0 {
ctx := netcontext.WithValue(c.Ctx(), paramsKey, c.Params().Clone())
h.ServeHTTP(c.Writer(), c.Request().WithContext(ctx))
return
}

h.ServeHTTP(c.Writer(), c.Request())
}
}
Expand Down
132 changes: 113 additions & 19 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ import (
"compress/gzip"
netcontext "context"
"crypto/rand"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
)

func TestContext_QueryParams(t *testing.T) {
Expand Down Expand Up @@ -480,28 +481,121 @@ func TestContext_TeeWriter_h2(t *testing.T) {

func TestWrapF(t *testing.T) {
t.Parallel()
wrapped := WrapF(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))
})

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
_, c := NewTestContext(w, r)
wrapped(c)
assert.Equal(t, "fox", w.Body.String())
cases := []struct {
name string
handler func(p Params) http.HandlerFunc
params *Params
}{
{
name: "wrap handlerFunc without context params",
handler: func(expectedParams Params) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))
})
},
},
{
name: "wrap handlerFunc with context params",
handler: func(expectedParams Params) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))

p := ParamsFromContext(r.Context())

assert.Equal(t, expectedParams, p)
})
},
params: &Params{
{
Key: "foo",
Value: "bar",
},
},
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
_, c := NewTestContext(w, r)

params := make(Params, 0)
if tc.params != nil {
params = tc.params.Clone()
c.(*context).params = &params
}

WrapF(tc.handler(params))(c)

assert.Equal(t, "fox", w.Body.String())
})
}

}

func TestWrapH(t *testing.T) {
t.Parallel()
wrapped := WrapH(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))
}))

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
_, c := NewTestContext(w, r)
wrapped(c)
assert.Equal(t, "fox", w.Body.String())
cases := []struct {
name string
handler func(p Params) http.Handler
params *Params
}{
{
name: "wrap handler without context params",
handler: func(expectedParams Params) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))
})
},
},
{
name: "wrap handler with context params",
handler: func(expectedParams Params) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("fox"))

p := ParamsFromContext(r.Context())

assert.Equal(t, expectedParams, p)
})
},
params: &Params{
{
Key: "foo",
Value: "bar",
},
},
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
_, c := NewTestContext(w, r)

params := make(Params, 0)
if tc.params != nil {
params = tc.params.Clone()
c.(*context).params = &params
}

WrapH(tc.handler(params))(c)

assert.Equal(t, "fox", w.Body.String())
})
}
}

func TestWrapM(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
ErrInvalidRoute = errors.New("invalid route")
ErrDiscardedResponseWriter = errors.New("discarded response writer")
ErrInvalidRedirectCode = errors.New("invalid redirect code")
ErrInvalidCtxParams = errors.New("unable to get params from context")
)

// RouteConflictError is a custom error type used to represent conflicts when
Expand Down
7 changes: 4 additions & 3 deletions fox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ package fox
import (
"bytes"
"fmt"
fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"log"
"math/rand"
Expand All @@ -23,6 +20,10 @@ import (
"sync/atomic"
"testing"
"time"

fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var emptyHandler = HandlerFunc(func(c Context) {})
Expand Down
5 changes: 3 additions & 2 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ package fox

import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWrapFlushWriter(t *testing.T) {
Expand Down
23 changes: 23 additions & 0 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

package fox

import netcontext "context"

// paramsKey is the key that holds the Params in a context.Context.
var paramsKey = struct{}{}

type Param struct {
Key string
Value string
Expand All @@ -21,9 +26,27 @@ func (p Params) Get(name string) string {
return ""
}

// Has checks whether the parameter exists by name.
func (p Params) Has(name string) bool {
for i := range p {
if p[i].Key == name {
return true
}
}

return false
}

// Clone make a copy of Params.
func (p Params) Clone() Params {
cloned := make(Params, len(p))
copy(cloned, p)
return cloned
}

// ParamsFromContext allows extracting params from the given context.
func ParamsFromContext(ctx netcontext.Context) Params {
p, _ := ctx.Value(paramsKey).(Params)

return p
}
74 changes: 73 additions & 1 deletion params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
package fox

import (
"github.com/stretchr/testify/assert"
netcontext "context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParams_Get(t *testing.T) {
Expand Down Expand Up @@ -39,3 +41,73 @@ func TestParams_Clone(t *testing.T) {
)
assert.Equal(t, params, params.Clone())
}

func TestParams_Has(t *testing.T) {
t.Parallel()

params := make(Params, 0, 2)
params = append(params,
Param{
Key: "foo",
Value: "bar",
},
Param{
Key: "john",
Value: "doe",
},
)

assert.True(t, params.Has("foo"))
assert.True(t, params.Has("john"))
assert.False(t, params.Has("jane"))
}

func TestParamsFromContext(t *testing.T) {
t.Parallel()

cases := []struct {
name string
ctx netcontext.Context
expectedParams Params
}{
{
name: "empty context",
ctx: netcontext.Background(),
expectedParams: nil,
},
{
name: "context with params",
ctx: func() netcontext.Context {
params := make(Params, 0, 2)
params = append(params,
Param{
Key: "foo",
Value: "bar",
},
)
return netcontext.WithValue(netcontext.Background(), paramsKey, params)
}(),
expectedParams: func() Params {
params := make(Params, 0, 2)
params = append(params,
Param{
Key: "foo",
Value: "bar",
},
)
return params
}(),
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

params := ParamsFromContext(tc.ctx)
assert.Equal(t, tc.expectedParams, params)
})
}
}

0 comments on commit a66e6fa

Please sign in to comment.