From 938a9fb94e41285a443b0882dbc46f2a4c6ed484 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 15 Jan 2025 09:45:14 -0800 Subject: [PATCH] internal/http3: add request/response body transfer For golang/go#70914 Change-Id: I372458214fe73f8156e0ec291168b043c10221e6 Reviewed-on: https://go-review.googlesource.com/c/net/+/644915 Reviewed-by: Brad Fitzpatrick LUCI-TryBot-Result: Go LUCI Auto-Submit: Damien Neil Reviewed-by: Jonathan Amsterdam --- internal/http3/body.go | 142 +++++++++++ internal/http3/http3_test.go | 7 + internal/http3/roundtrip.go | 124 +++++++++- internal/http3/roundtrip_test.go | 390 +++++++++++++++++++++++++++++++ internal/http3/transport_test.go | 56 +++++ 5 files changed, 707 insertions(+), 12 deletions(-) create mode 100644 internal/http3/body.go diff --git a/internal/http3/body.go b/internal/http3/body.go new file mode 100644 index 000000000..cdde482ef --- /dev/null +++ b/internal/http3/body.go @@ -0,0 +1,142 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package http3 + +import ( + "errors" + "fmt" + "io" + "sync" +) + +// A bodyWriter writes a request or response body to a stream +// as a series of DATA frames. +type bodyWriter struct { + st *stream + remain int64 // -1 when content-length is not known + flush bool // flush the stream after every write + name string // "request" or "response" +} + +func (w *bodyWriter) Write(p []byte) (n int, err error) { + if w.remain >= 0 && int64(len(p)) > w.remain { + return 0, &streamError{ + code: errH3InternalError, + message: w.name + " body longer than specified content length", + } + } + w.st.writeVarint(int64(frameTypeData)) + w.st.writeVarint(int64(len(p))) + n, err = w.st.Write(p) + if w.remain >= 0 { + w.remain -= int64(n) + } + if w.flush && err == nil { + err = w.st.Flush() + } + if err != nil { + err = fmt.Errorf("writing %v body: %w", w.name, err) + } + return n, err +} + +func (w *bodyWriter) Close() error { + if w.remain > 0 { + return errors.New(w.name + " body shorter than specified content length") + } + return nil +} + +// A bodyReader reads a request or response body from a stream. +type bodyReader struct { + st *stream + + mu sync.Mutex + remain int64 + err error +} + +func (r *bodyReader) Read(p []byte) (n int, err error) { + // The HTTP/1 and HTTP/2 implementations both permit concurrent reads from a body, + // in the sense that the race detector won't complain. + // Use a mutex here to provide the same behavior. + r.mu.Lock() + defer r.mu.Unlock() + if r.err != nil { + return 0, r.err + } + defer func() { + if err != nil { + r.err = err + } + }() + if r.st.lim == 0 { + // We've finished reading the previous DATA frame, so end it. + if err := r.st.endFrame(); err != nil { + return 0, err + } + } + // Read the next DATA frame header, + // if we aren't already in the middle of one. + for r.st.lim < 0 { + ftype, err := r.st.readFrameHeader() + if err == io.EOF && r.remain > 0 { + return 0, &streamError{ + code: errH3MessageError, + message: "body shorter than content-length", + } + } + if err != nil { + return 0, err + } + switch ftype { + case frameTypeData: + if r.remain >= 0 && r.st.lim > r.remain { + return 0, &streamError{ + code: errH3MessageError, + message: "body longer than content-length", + } + } + // Fall out of the loop and process the frame body below. + case frameTypeHeaders: + // This HEADERS frame contains the message trailers. + if r.remain > 0 { + return 0, &streamError{ + code: errH3MessageError, + message: "body shorter than content-length", + } + } + // TODO: Fill in Request.Trailer. + if err := r.st.discardFrame(); err != nil { + return 0, err + } + return 0, io.EOF + default: + if err := r.st.discardUnknownFrame(ftype); err != nil { + return 0, err + } + } + } + // We are now reading the content of a DATA frame. + // Fill the read buffer or read to the end of the frame, + // whichever comes first. + if int64(len(p)) > r.st.lim { + p = p[:r.st.lim] + } + n, err = r.st.Read(p) + if r.remain > 0 { + r.remain -= int64(n) + } + return n, err +} + +func (r *bodyReader) Close() error { + // Unlike the HTTP/1 and HTTP/2 body readers (at the time of this comment being written), + // calling Close concurrently with Read will interrupt the read. + r.st.stream.CloseRead() + return nil +} diff --git a/internal/http3/http3_test.go b/internal/http3/http3_test.go index 281c0cd54..f490ad3f0 100644 --- a/internal/http3/http3_test.go +++ b/internal/http3/http3_test.go @@ -73,3 +73,10 @@ func unhex(s string) []byte { } return b } + +// testReader implements io.Reader. +type testReader struct { + readFunc func([]byte) (int, error) +} + +func (r testReader) Read(p []byte) (n int, err error) { return r.readFunc(p) } diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 9042c15bf..b24a30308 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -7,13 +7,60 @@ package http3 import ( + "errors" "io" "net/http" "strconv" + "sync" "golang.org/x/net/internal/httpcommon" ) +type roundTripState struct { + cc *ClientConn + st *stream + + // Request body, provided by the caller. + onceCloseReqBody sync.Once + reqBody io.ReadCloser + + reqBodyWriter bodyWriter + + // Response.Body, provided to the caller. + respBody bodyReader + + errOnce sync.Once + err error +} + +// abort terminates the RoundTrip. +// It returns the first fatal error encountered by the RoundTrip call. +func (rt *roundTripState) abort(err error) error { + rt.errOnce.Do(func() { + rt.err = err + switch e := err.(type) { + case *connectionError: + rt.cc.abort(e) + case *streamError: + rt.st.stream.CloseRead() + rt.st.stream.Reset(uint64(e.code)) + default: + rt.st.stream.CloseRead() + rt.st.stream.Reset(uint64(errH3NoError)) + } + }) + return rt.err +} + +// closeReqBody closes the Request.Body, at most once. +func (rt *roundTripState) closeReqBody() { + if rt.reqBody != nil { + rt.onceCloseReqBody.Do(func() { + rt.reqBody.Close() + }) + } +} + // RoundTrip sends a request on the connection. func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) { // Each request gets its own QUIC stream. @@ -21,17 +68,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) if err != nil { return nil, err } + rt := &roundTripState{ + cc: cc, + st: st, + } defer func() { - switch e := err.(type) { - case nil: - case *connectionError: - cc.abort(e) - case *streamError: - st.stream.CloseRead() - st.stream.Reset(uint64(e.code)) - default: - st.stream.CloseRead() - st.stream.Reset(uint64(errH3NoError)) + if err != nil { + err = rt.abort(err) } }() @@ -64,7 +107,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) } if encr.HasBody { - // TODO: Send the request body. + // TODO: Defer sending the request body when "Expect: 100-continue" is set. + rt.reqBody = req.Body + rt.reqBodyWriter.st = st + rt.reqBodyWriter.remain = httpcommon.ActualContentLength(req) + rt.reqBodyWriter.flush = true + rt.reqBodyWriter.name = "request" + go copyRequestBody(rt) } // Read the response headers. @@ -91,6 +140,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) if err != nil { return nil, err } + rt.respBody.st = st + rt.respBody.remain = contentLength resp := &http.Response{ Proto: "HTTP/3.0", ProtoMajor: 3, @@ -98,7 +149,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) StatusCode: statusCode, Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode), ContentLength: contentLength, - Body: io.NopCloser(nil), // TODO: read the response body + Body: (*transportResponseBody)(rt), } // TODO: Automatic Content-Type: gzip decoding. return resp, nil @@ -114,6 +165,55 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) } } +func copyRequestBody(rt *roundTripState) { + defer rt.closeReqBody() + _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody) + if closeErr := rt.reqBodyWriter.Close(); err == nil { + err = closeErr + } + if err != nil { + // Something went wrong writing the body. + rt.abort(err) + } else { + // We wrote the whole body. + rt.st.stream.CloseWrite() + } +} + +// transportResponseBody is the Response.Body returned by RoundTrip. +type transportResponseBody roundTripState + +// Read is Response.Body.Read. +func (b *transportResponseBody) Read(p []byte) (n int, err error) { + return b.respBody.Read(p) +} + +var errRespBodyClosed = errors.New("response body closed") + +// Close is Response.Body.Close. +// Closing the response body is how the caller signals that they're done with a request. +func (b *transportResponseBody) Close() error { + rt := (*roundTripState)(b) + // Close the request body, which should wake up copyRequestBody if it's + // currently blocked reading the body. + rt.closeReqBody() + // Close the request stream, since we're done with the request. + // Reset closes the sending half of the stream. + rt.st.stream.Reset(uint64(errH3NoError)) + // respBody.Close is responsible for closing the receiving half. + err := rt.respBody.Close() + if err == nil { + err = errRespBodyClosed + } + err = rt.abort(err) + if err == errRespBodyClosed { + // No other errors occurred before closing Response.Body, + // so consider this a successful request. + return nil + } + return err +} + func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) { clens := h["Content-Length"] if len(clens) == 0 { diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index 34397c07f..533b750a5 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -7,8 +7,14 @@ package http3 import ( + "bytes" + "errors" + "io" "net/http" "testing" + "testing/synctest" + + "golang.org/x/net/quic" ) func TestRoundTripSimple(t *testing.T) { @@ -230,3 +236,387 @@ func TestRoundTripCrumbledCookiesInResponse(t *testing.T) { }) }) } + +func TestRoundTripResponseBody(t *testing.T) { + // These tests consist of a series of steps, + // where each step is either something arriving on the response stream + // or the client reading from the request body. + type ( + // HEADERS frame arrives on the response stream (headers or trailers). + receiveHeaders http.Header + // DATA frame header arrives on the response stream. + receiveDataHeader struct { + size int64 + } + // DATA frame content arrives on the response stream. + receiveData struct { + size int64 + } + // Some other frame arrives on the response stream. + receiveFrame struct { + ftype frameType + data []byte + } + // Response stream closed, ending the body. + receiveEOF struct{} + // Client reads from Response.Body. + wantBody struct { + size int64 + eof bool + } + wantError struct{} + ) + for _, test := range []struct { + name string + respHeader http.Header + steps []any + wantError bool + }{{ + name: "no content length", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "valid content length", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + "content-length": []string{"10"}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "data frame exceeds content length", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + "content-length": []string{"5"}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantError{}, + }, + }, { + name: "data frame after all content read", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + "content-length": []string{"5"}, + }, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveDataHeader{size: 1}, + receiveData{size: 1}, + wantError{}, + }, + }, { + name: "content length too long", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + "content-length": []string{"10"}, + }, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveEOF{}, + wantBody{size: 5}, + wantError{}, + }, + }, { + name: "stream ended by trailers", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveHeaders{ + "x-trailer": []string{"value"}, + }, + wantBody{size: 5, eof: true}, + }, + }, { + name: "trailers and content length too long", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + "content-length": []string{"10"}, + }, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveHeaders{ + "x-trailer": []string{"value"}, + }, + wantError{}, + }, + }, { + name: "unknown frame before headers", + steps: []any{ + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "unknown frame after headers", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "invalid frame", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveFrame{ + ftype: frameTypeSettings, // not a valid frame on this stream + data: []byte{1, 2, 3, 4}, + }, + wantError{}, + }, + }, { + name: "data frame consumed by several reads", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveDataHeader{size: 16}, + receiveData{size: 16}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + wantBody{size: 2}, + }, + }, { + name: "read multiple frames", + steps: []any{ + receiveHeaders{ + ":status": []string{"200"}, + }, + receiveDataHeader{size: 2}, + receiveData{size: 2}, + receiveDataHeader{size: 4}, + receiveData{size: 4}, + receiveDataHeader{size: 8}, + receiveData{size: 8}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + }, + }} { + runSynctestSubtest(t, test.name, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + var ( + bytesSent int + bytesReceived int + ) + for _, step := range test.steps { + switch step := step.(type) { + case receiveHeaders: + st.writeHeaders(http.Header(step)) + case receiveDataHeader: + t.Logf("receive DATA frame header: size=%v", step.size) + st.writeVarint(int64(frameTypeData)) + st.writeVarint(step.size) + st.Flush() + case receiveData: + t.Logf("receive DATA frame content: size=%v", step.size) + for range step.size { + st.stream.stream.WriteByte(byte(bytesSent)) + bytesSent++ + } + st.Flush() + case receiveFrame: + st.writeVarint(int64(step.ftype)) + st.writeVarint(int64(len(step.data))) + st.Write(step.data) + st.Flush() + case receiveEOF: + t.Logf("receive EOF on request stream") + st.stream.stream.CloseWrite() + case wantBody: + t.Logf("read %v bytes from response body", step.size) + want := make([]byte, step.size) + for i := range want { + want[i] = byte(bytesReceived) + bytesReceived++ + } + got := make([]byte, step.size) + n, err := rt.response().Body.Read(got) + got = got[:n] + if !bytes.Equal(got, want) { + t.Errorf("resp.Body.Read:") + t.Errorf(" got: {%x}", got) + t.Fatalf(" want: {%x}", want) + } + if err != nil { + if step.eof && err == io.EOF { + continue + } + t.Fatalf("resp.Body.Read: unexpected error %v", err) + } + if step.eof { + if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err != io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err) + } + } + case wantError: + if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err == nil || err == io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err) + } + default: + t.Fatalf("unknown test step %T", step) + } + } + }) + } +} + +func TestRoundTripRequestBodySent(t *testing.T) { + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + bodyr, bodyw := io.Pipe() + + req, _ := http.NewRequest("GET", "https://example.tld/", bodyr) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + bodyw.Write([]byte{0, 1, 2, 3, 4}) + st.wantData([]byte{0, 1, 2, 3, 4}) + + bodyw.Write([]byte{5, 6, 7}) + st.wantData([]byte{5, 6, 7}) + + bodyw.Close() + st.wantClosed("request body sent") + + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + }) + rt.wantStatus(200) + rt.response().Body.Close() + }) +} + +func TestRoundTripRequestBodyErrors(t *testing.T) { + for _, test := range []struct { + name string + body io.Reader + contentLength int64 + }{{ + name: "too short", + contentLength: 10, + body: bytes.NewReader([]byte{0, 1, 2, 3, 4}), + }, { + name: "too long", + contentLength: 5, + body: bytes.NewReader([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + }, { + name: "read error", + body: io.MultiReader( + bytes.NewReader([]byte{0, 1, 2, 3, 4}), + &testReader{ + readFunc: func([]byte) (int, error) { + return 0, errors.New("read error") + }, + }, + ), + }} { + runSynctestSubtest(t, test.name, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", test.body) + req.ContentLength = test.contentLength + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + + // The Transport should send some number of frames before detecting an + // error in the request body and aborting the request. + synctest.Wait() + for { + _, err := st.readFrameHeader() + if err != nil { + var code quic.StreamErrorCode + if !errors.As(err, &code) { + t.Fatalf("request stream closed with error %v: want QUIC stream error", err) + } + break + } + if err := st.discardFrame(); err != nil { + t.Fatalf("discardFrame: %v", err) + } + } + + // RoundTrip returns with an error. + rt.wantError("request fails due to body error") + }) + } +} + +func TestRoundTripRequestBodyErrorAfterHeaders(t *testing.T) { + runSynctest(t, func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + bodyr, bodyw := io.Pipe() + req, _ := http.NewRequest("GET", "https://example.tld/", bodyr) + req.ContentLength = 10 + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + + // Server sends response headers, and RoundTrip returns. + // The request body hasn't been sent yet. + st.wantHeaders(nil) + st.writeHeaders(http.Header{ + ":status": []string{"200"}, + }) + rt.wantStatus(200) + + // Write too many bytes to the request body, triggering a request error. + bodyw.Write(make([]byte, req.ContentLength+1)) + + //io.Copy(io.Discard, st) + st.wantError(quic.StreamErrorCode(errH3InternalError)) + + if err := rt.response().Body.Close(); err == nil { + t.Fatalf("Response.Body.Close() = %v, want error", err) + } + }) +} diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index a61c9a661..dd034c658 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go @@ -7,9 +7,11 @@ package http3 import ( + "bytes" "context" "errors" "fmt" + "io" "maps" "net/http" "reflect" @@ -284,6 +286,9 @@ func (ts *testQUICStream) wantHeaders(want http.Header) { } if want == nil { + if err := ts.discardFrame(); err != nil { + ts.t.Fatalf("discardFrame: %v", err) + } return } @@ -296,6 +301,9 @@ func (ts *testQUICStream) wantHeaders(want http.Header) { if diff := diffHeaders(got, want); diff != "" { ts.t.Fatalf("unexpected response headers:\n%v", diff) } + if err := ts.endFrame(); err != nil { + ts.t.Fatalf("endFrame: %v", err) + } } func (ts *testQUICStream) encodeHeaders(h http.Header) []byte { @@ -323,6 +331,53 @@ func (ts *testQUICStream) writeHeaders(h http.Header) { } } +func (ts *testQUICStream) wantData(want []byte) { + ts.t.Helper() + synctest.Wait() + ftype, err := ts.readFrameHeader() + if err != nil { + ts.t.Fatalf("want DATA frame, got error: %v", err) + } + if ftype != frameTypeData { + ts.t.Fatalf("want DATA frame, got: %v", ftype) + } + got, err := ts.readFrameData() + if err != nil { + ts.t.Fatalf("error reading DATA frame: %v", err) + } + if !bytes.Equal(got, want) { + ts.t.Fatalf("got data: {%x}, want {%x}", got, want) + } + if err := ts.endFrame(); err != nil { + ts.t.Fatalf("endFrame: %v", err) + } +} + +func (ts *testQUICStream) wantClosed(reason string) { + ts.t.Helper() + synctest.Wait() + ftype, err := ts.readFrameHeader() + if err != io.EOF { + ts.t.Fatalf("%v: want io.EOF, got %v %v", reason, ftype, err) + } +} + +func (ts *testQUICStream) wantError(want quic.StreamErrorCode) { + ts.t.Helper() + synctest.Wait() + _, err := ts.stream.stream.ReadByte() + if err == nil { + ts.t.Fatalf("successfully read from stream; want stream error code %v", want) + } + var got quic.StreamErrorCode + if !errors.As(err, &got) { + ts.t.Fatalf("stream error = %v; want %v", err, want) + } + if got != want { + ts.t.Fatalf("stream error code = %v; want %v", got, want) + } +} + func (ts *testQUICStream) writePushPromise(pushID int64, h http.Header) { ts.t.Helper() headers := ts.encodeHeaders(h) @@ -453,6 +508,7 @@ func (rt *testRoundTrip) err() error { func (rt *testRoundTrip) wantError(reason string) { rt.t.Helper() + synctest.Wait() if !rt.done() { rt.t.Fatalf("%v: RoundTrip is not done; want it to have returned an error", reason) }