Skip to content

Commit

Permalink
feat: support custom formvalue function (#1453)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou authored Dec 25, 2022
1 parent 2a572e0 commit b788e66
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 14 deletions.
68 changes: 54 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,12 @@ type Server struct {
// instead.
TLSConfig *tls.Config

// FormValueFunc, which is used by RequestCtx.FormValue and support for customising
// the behaviour of the RequestCtx.FormValue function.
//
// NetHttpFormValueFunc gives a FormValueFunc func implementation that is consistent with net/http.
FormValueFunc FormValueFunc

nextProtos map[string]ServeHandler

concurrency uint32
Expand Down Expand Up @@ -604,6 +610,7 @@ type RequestCtx struct {

hijackHandler HijackHandler
hijackNoResponse bool
formValueFunc FormValueFunc
}

// HijackHandler must process the hijacked connection c.
Expand Down Expand Up @@ -1108,23 +1115,54 @@ func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) {
//
// The returned value is valid until your request handler returns.
func (ctx *RequestCtx) FormValue(key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
if ctx.formValueFunc != nil {
return ctx.formValueFunc(ctx, key)
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
return defaultFormValue(ctx, key)
}

type FormValueFunc func(*RequestCtx, string) []byte

var (
defaultFormValue = func(ctx *RequestCtx, key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
return nil
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])

// NetHttpFormValueFunc gives consistent behavior with net/http. POST and PUT body parameters take precedence over URL query string values.
NetHttpFormValueFunc = func(ctx *RequestCtx, key string) []byte {
v := ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
v = ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
return nil
}
return nil
}
)

// IsGet returns true if request method is GET.
func (ctx *RequestCtx) IsGet() bool {
Expand Down Expand Up @@ -2638,7 +2676,9 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
} else {
ctx = v.(*RequestCtx)
}

if s.FormValueFunc != nil {
ctx.formValueFunc = s.FormValueFunc
}
ctx.c = c

return ctx
Expand Down
36 changes: 36 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,20 @@ func TestRequestCtxFormValue(t *testing.T) {
}
}

func TestSetStandardFormValueFunc(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?aaa=bbb")
req.SetBodyString("aaa=port")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
ctx.formValueFunc = NetHttpFormValueFunc
v := ctx.FormValue("aaa")
if string(v) != "port" {
t.Fatalf("unexpected value %q. Expecting %q", v, "port")
}
}
func TestRequestCtxUserValue(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -3287,6 +3301,28 @@ func TestServeConnSingleRequest(t *testing.T) {
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}

func TestServerSetFormValueFunc(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa", ctx.FormValue("aaa"))
},
FormValueFunc: func(ctx *RequestCtx, s string) []byte {
return []byte(s)
},
}

rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")

if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}

br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "aaa")
}

func TestServeConnMultiRequests(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit b788e66

Please sign in to comment.