Skip to content

Commit

Permalink
Add request body streaming. Fixes #622 (#911)
Browse files Browse the repository at this point in the history
* Add request body streaming. Fixes #622
* Add test cases for StreamRequestBody

Co-authored-by: Kiyon <kiyonlin@163.com>
Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
Co-authored-by: Fiber
  • Loading branch information
3 people authored Feb 6, 2021
1 parent fbe6a2d commit 0956208
Show file tree
Hide file tree
Showing 6 changed files with 501 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ tags
*.fasthttp.gz
*.fasthttp.br
.idea
.DS_Store
148 changes: 148 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ func (resp *Response) BodyInflate() ([]byte, error) {
return inflateData(resp.Body())
}

// RequestBodyStream returns the reader for the request body stream.
//
// It is non-nil only when the request body was read in streaming mode
// (see Server.StreamRequestBody); otherwise it returns nil.
func (ctx *RequestCtx) RequestBodyStream() io.Reader {
	return ctx.Request.bodyStream
}

func inflateData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteInflate(&bb, p)
Expand Down Expand Up @@ -1017,6 +1021,53 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
return req.ContinueReadBody(r, maxBodySize, preParseMultipartForm)
}

// readBodyStream reads the request body in streaming mode: only a prefix of
// the body is buffered and the remainder is consumed lazily from r through
// req.bodyStream.
//
// If getOnly is set and the request is not a GET, ErrGetOnly is returned.
// If the request contains an 'Expect: 100-continue' header, nothing is read
// and the caller is expected to respond and then call ContinueReadBodyStream.
func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
	if getOnly && !req.Header.IsGet() {
		return ErrGetOnly
	}

	if req.MayContinue() {
		// 'Expect: 100-continue' header found. Let the caller decide
		// whether to read the request body or to return
		// StatusExpectationFailed.
		return nil
	}

	// Delegate to ContinueReadBodyStream so the multipart, identity-body and
	// size-limit handling lives in one place. The previous inline copy of
	// that logic silently ignored maxBodySize, bypassing the configured
	// request body limit in streaming mode.
	return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}

// MayContinue returns true if the request contains
// 'Expect: 100-continue' header.
//
Expand Down Expand Up @@ -1081,6 +1132,73 @@ func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseM
return nil
}

// ContinueReadBodyStream reads the request body in streaming mode if the
// request header contains 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// Unlike ContinueReadBody, a body exceeding maxBodySize does NOT produce
// ErrBodyTooLarge: the prefix read so far is kept and the remainder is
// exposed to the handler through req.bodyStream.
func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
	var err error
	contentLength := req.Header.realContentLength()
	if contentLength > 0 {
		if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
			// Pre-read multipart form data of known length.
			// This way we limit memory usage for large file uploads, since their contents
			// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
			req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
			if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
				req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
				if err != nil {
					req.Reset()
				}
				return err
			}
		}
	}

	if contentLength == -2 {
		// identity body has no sense for http requests, since
		// the end of body is determined by connection close.
		// So just ignore request body for requests without
		// 'Content-Length' and 'Transfer-Encoding' headers.
		req.Header.SetContentLength(0)
		return nil
	}

	bodyBuf := req.bodyBuffer()
	bodyBuf.Reset()
	bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
	if err != nil {
		if err == ErrBodyTooLarge {
			// Body exceeds the limit: keep the prefix that was already
			// buffered and stream the rest to the handler.
			// NOTE: bodyRaw must only cover the bytes actually read; the
			// previous code sliced bodyBuf.B up to maxBodySize/cap, which
			// could panic (maxBodySize > cap) or expose garbage past len.
			req.Header.SetContentLength(contentLength)
			req.body = bodyBuf
			req.bodyRaw = bodyBuf.B
			req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
			return nil
		}
		if err == errChunkedStream {
			// Chunked body: nothing was pre-read; the stream decodes the
			// chunks on demand (contentLength -1 flags chunked mode).
			req.body = bodyBuf
			req.bodyRaw = bodyBuf.B
			req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
			return nil
		}
		req.Reset()
		return err
	}

	req.body = bodyBuf
	req.bodyRaw = bodyBuf.B
	req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
	req.Header.SetContentLength(len(bodyBuf.B))
	return nil
}

// Read reads response (including body) from the given r.
//
// io.EOF is returned if r is closed before reading the first header byte.
Expand Down Expand Up @@ -1815,6 +1933,36 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (
return readBodyIdentity(r, maxBodySize, dst)
}

// errChunkedStream signals that the body uses chunked transfer-encoding and
// must be decoded lazily by the request stream instead of being pre-read here.
var errChunkedStream = errors.New("chunked stream")

// readBodyWithStreaming pre-reads a prefix of the request body into dst —
// at most 8KB, and never more than contentLength or maxBodySize — so some
// data is available to the handler immediately; the remainder stays in r
// for the body stream to consume.
//
// It returns ErrBodyTooLarge when contentLength exceeds maxBodySize (the
// pre-read prefix is still returned) and errChunkedStream for chunked
// bodies (contentLength == -1).
func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) {
	if contentLength == -1 {
		// Chunked body: handled in requestStream.Read().
		return b, errChunkedStream
	}

	dst = dst[:0]

	// Clamp the pre-read size to the body length. The previous code used
	// readN = maxBodySize whenever contentLength <= 8KB, asking
	// appendBodyFixedSize for more bytes than the body contains and making
	// it block (or fail) waiting for data that will never arrive.
	readN := maxBodySize
	if readN > contentLength {
		readN = contentLength
	}
	if readN > 8*1024 {
		readN = 8 * 1024
	}

	if contentLength >= 0 && maxBodySize >= contentLength {
		b, err = appendBodyFixedSize(r, dst, readN)
	} else {
		b, err = readBodyIdentity(r, readN, dst)
	}

	if err != nil {
		return b, err
	}
	if contentLength > maxBodySize {
		return b, ErrBodyTooLarge
	}
	return b, nil
}

func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) {
dst = dst[:cap(dst)]
if len(dst) == 0 {
Expand Down
23 changes: 21 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,11 @@ type Server struct {
// which will close it when needed.
KeepHijackedConns bool

// StreamRequestBody enables request body streaming,
// and calls the handler sooner when given body is
// larger than the current limit.
StreamRequestBody bool

tlsConfig *tls.Config
nextProtos map[string]ServeHandler

Expand Down Expand Up @@ -2075,7 +2080,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
}
}
//read body
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
if s.StreamRequestBody {
err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
}
}

if err == nil {
Expand Down Expand Up @@ -2150,7 +2159,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
br = acquireReader(ctx)
}

err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
if s.StreamRequestBody {
err = ctx.Request.ContinueReadBodyStream(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
}
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
Expand Down Expand Up @@ -2279,6 +2292,12 @@ func (s *Server) serveConn(c net.Conn) (err error) {
break
}

if ctx.Request.bodyStream != nil {
if rs, ok := ctx.Request.bodyStream.(*requestStream); ok {
releaseRequestStream(rs)
}
}

s.setState(c, StateIdle)

if atomic.LoadInt32(&s.stop) == 1 {
Expand Down
111 changes: 111 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3347,6 +3347,117 @@ func TestMaxBodySizePerRequest(t *testing.T) {
}
}

// TestStreamRequestBody verifies that with StreamRequestBody enabled the
// handler can start consuming the request body before the client has
// finished sending it.
func TestStreamRequestBody(t *testing.T) {
	t.Parallel()

	body1 := strings.Repeat("1", 1<<10)
	body2 := strings.Repeat("2", 1<<20-1<<10)
	totalLen := len(body1) + len(body2)
	part1Read := make(chan struct{})

	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			checkReader(t, ctx.RequestBodyStream(), body1)
			close(part1Read)
			checkReader(t, ctx.RequestBodyStream(), body2)
		},
		DisableKeepalive:  true,
		StreamRequestBody: true,
	}

	conns := fasthttputil.NewPipeConns()
	client, server := conns.Conn1(), conns.Conn2()

	// Send the headers together with only the first chunk of the body.
	request := fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", totalLen, body1)
	if _, err := client.Write([]byte(request)); err != nil {
		t.Error(err)
	}

	done := make(chan error)
	go func() { done <- srv.ServeConn(server) }()

	// The handler must observe part1 without waiting for part2.
	select {
	case <-part1Read:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part1 timeout")
	}

	// Now deliver the remainder of the body.
	if _, err := client.Write([]byte(body2)); err != nil {
		t.Error(err)
	}

	select {
	case err := <-done:
		if err != nil {
			t.Fatalf("Unexpected error from serveConn: %s", err)
		}
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part2 timeout")
	}
}

// TestStreamRequestBodyExceedMaxSize verifies that streaming keeps working
// when the body is larger than MaxRequestBodySize: instead of failing with
// a body-too-large error, the handler receives the whole body via the stream.
func TestStreamRequestBodyExceedMaxSize(t *testing.T) {
	t.Parallel()

	body1 := strings.Repeat("1", 1<<18)
	body2 := strings.Repeat("2", 1<<20-1<<18)
	totalLen := len(body1) + len(body2)
	part1Read := make(chan struct{})

	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			checkReader(t, ctx.RequestBodyStream(), body1)
			close(part1Read)
			checkReader(t, ctx.RequestBodyStream(), body2)
		},
		DisableKeepalive:   true,
		StreamRequestBody:  true,
		MaxRequestBodySize: 1,
	}

	conns := fasthttputil.NewPipeConns()
	client, server := conns.Conn1(), conns.Conn2()

	// Send the headers together with only the first chunk of the body.
	request := fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", totalLen, body1)
	if _, err := client.Write([]byte(request)); err != nil {
		t.Error(err)
	}

	done := make(chan error)
	go func() { done <- srv.ServeConn(server) }()

	// The handler must observe part1 without waiting for part2.
	select {
	case <-part1Read:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part1 timeout")
	}

	// Now deliver the remainder of the body.
	if _, err := client.Write([]byte(body2)); err != nil {
		t.Error(err)
	}

	select {
	case err := <-done:
		if err != nil {
			t.Error(err)
		}
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part2 timeout")
	}
}

func checkReader(t *testing.T, r io.Reader, expected string) {
b := make([]byte, len(expected))
if _, err := io.ReadFull(r, b); err != nil {
t.Fatalf("Unexpected error from reader: %s", err)
}
if string(b) != expected {
t.Fatal("incorrect request body")
}
}

func TestMaxReadTimeoutPerRequest(t *testing.T) {
t.Parallel()

Expand Down
Loading

0 comments on commit 0956208

Please sign in to comment.