From 22fb09708d28ddd2da5b942eb2c0cb1b61189312 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Sat, 11 Sep 2021 18:41:19 -0700 Subject: [PATCH] [internal-branch.go1.17-vendor] http2: avoid blocking while holding ClientConn.mu Operations which examine the state of a ClientConn--notably, the connection pool's check to see if a conn is available to take a new request--need to acquire mu. Blocking while holding mu, such as when writing to the network, blocks these operations. Remove blocking operations from the mutex. Perform network writes with only ClientConn.wmu held. Clarify that wmu guards the per-conn HPACK encoder and buffer. Add a new mutex guarding request creation, covering the critical section starting with allocating a new stream ID and continuing until the stream is created. Fix a locking issue where trailers were written from the HPACK buffer with only wmu held, but headers were encoded into the buffer with only mu held. (Now both encoding and writes occur with wmu held.) Updates golang/go#49077 Change-Id: Ibb313424ed2f32c1aeac4645b76aedf227b597a3 Reviewed-on: https://go-review.googlesource.com/c/net/+/349594 Trust: Damien Neil Run-TryBot: Damien Neil TryBot-Result: Go Bot Reviewed-by: Brad Fitzpatrick Reviewed-on: https://go-review.googlesource.com/c/net/+/357677 Reviewed-by: Dmitri Shuralyov --- transport.go | 171 +++++++++++++++++++++----------- transport_test.go | 243 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 358 insertions(+), 56 deletions(-) diff --git a/transport.go b/transport.go index 74c76da..2e28944 100644 --- a/transport.go +++ b/transport.go @@ -264,22 +264,29 @@ type ClientConn struct { nextStreamID uint32 pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams pings map[[8]byte]chan struct{} // in flight ping data to notification channel - bw *bufio.Writer br *bufio.Reader - fr *Framer lastActive time.Time lastIdle time.Time // time last idle - // Settings from peer: (also guarded by mu) + // Settings from peer: (also guarded by wmu) maxFrameSize uint32 maxConcurrentStreams uint32 peerMaxHeaderListSize uint64 initialWindowSize uint32 + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. + // Write to reqHeaderMu to lock it, read from it to unlock. + // Lock reqmu BEFORE mu or wmu. + reqHeaderMu chan struct{} + + // wmu is held while writing. + // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. + // Only acquire both at the same time when changing peer settings. + wmu sync.Mutex + bw *bufio.Writer + fr *Framer + werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder - - wmu sync.Mutex // held while writing; acquire AFTER mu if holding both - werr error // first write error that has occurred } // clientStream is the state for a single HTTP/2 stream. One of these @@ -398,10 +405,11 @@ func (cs *clientStream) abortRequestBodyWrite(err error) { cc.mu.Lock() if cs.stopReqBody == nil { cs.stopReqBody = err - if cs.req.Body != nil { - cs.req.Body.Close() - } cc.cond.Broadcast() + // Close the body after releasing the mutex, in case it blocks. 
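The reqHeaderMu field introduced above is a 1-element channel rather than a sync.Mutex so that a caller waiting to send a new request can give up when the request is canceled. A minimal standalone sketch of that idiom, with illustrative names (conn, lock, unlock) that are not part of this change:

package main

import (
	"context"
	"errors"
	"fmt"
)

// conn is a stand-in for ClientConn; reqMu plays the role of reqHeaderMu.
type conn struct {
	reqMu chan struct{} // 1-element semaphore: send to lock, receive to unlock
}

func newConn() *conn {
	return &conn{reqMu: make(chan struct{}, 1)}
}

// lock acquires the semaphore, or gives up if ctx is canceled first.
// (The real code also selects on req.Cancel.)
func (c *conn) lock(ctx context.Context) error {
	select {
	case c.reqMu <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// unlock releases the semaphore; call it only after a successful lock.
func (c *conn) unlock() { <-c.reqMu }

func main() {
	c := newConn()
	ctx, cancel := context.WithCancel(context.Background())

	if err := c.lock(ctx); err != nil {
		fmt.Println("lock failed:", err)
		return
	}
	// The lock is held; a second, canceled caller does not wait forever.
	cancel()
	if err := c.lock(ctx); errors.Is(err, context.Canceled) {
		fmt.Println("second caller gave up:", err)
	}
	c.unlock()
}

roundTrip, further down in this patch, acquires reqHeaderMu the same way, selecting on req.Cancel and the request context alongside the channel send.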
+ if body := cs.req.Body; body != nil { + defer body.Close() + } } cc.mu.Unlock() } @@ -666,6 +674,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro singleUse: singleUse, wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), + reqHeaderMu: make(chan struct{}, 1), } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -894,15 +903,18 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { func (cc *ClientConn) sendGoAway() error { cc.mu.Lock() - defer cc.mu.Unlock() - cc.wmu.Lock() - defer cc.wmu.Unlock() - if cc.closing { + closing := cc.closing + cc.closing = true + maxStreamID := cc.nextStreamID + cc.mu.Unlock() + if closing { // GOAWAY sent already return nil } + + cc.wmu.Lock() + defer cc.wmu.Unlock() // Send a graceful shutdown frame to server - maxStreamID := cc.nextStreamID if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil { return err } @@ -910,7 +922,6 @@ func (cc *ClientConn) sendGoAway() error { return err } // Prevent new requests - cc.closing = true return nil } @@ -918,17 +929,22 @@ func (cc *ClientConn) sendGoAway() error { // err is sent to streams. func (cc *ClientConn) closeForError(err error) error { cc.mu.Lock() - defer cc.cond.Broadcast() - defer cc.mu.Unlock() - for id, cs := range cc.streams { + streams := cc.streams + cc.streams = nil + cc.closed = true + cc.mu.Unlock() + + for _, cs := range streams { select { case cs.resc <- resAndError{err: err}: default: } cs.bufPipe.CloseWithError(err) - delete(cc.streams, id) } - cc.closed = true + + cc.mu.Lock() + defer cc.cond.Broadcast() + defer cc.mu.Unlock() return cc.tconn.Close() } @@ -1013,6 +1029,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { } func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) { + ctx := req.Context() if err := checkConnHeaders(req); err != nil { return nil, false, err } @@ -1026,6 +1043,26 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf } hasTrailers := trailers != "" + // Acquire the new-request lock by writing to reqHeaderMu. + // This lock guards the critical section covering allocating a new stream ID + // (requires mu) and creating the stream (requires wmu). 
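sendGoAway and closeForError above both follow the pattern this change applies throughout transport.go: hold mu only long enough to snapshot and update shared state, then do the potentially blocking work (channel sends, pipe closes, network writes) after releasing it. A reduced sketch of that shape, with hypothetical names rather than the real ClientConn fields:

package main

import (
	"errors"
	"fmt"
	"sync"
)

type stream struct{ id uint32 }

// conn is a reduced stand-in for ClientConn.
type conn struct {
	mu      sync.Mutex
	closed  bool
	streams map[uint32]*stream
}

// closeForError snapshots the stream map under mu, then notifies each
// stream without holding the lock, so a slow consumer cannot block other
// goroutines that only need to inspect conn state.
func (c *conn) closeForError(err error, notify func(*stream, error)) {
	c.mu.Lock()
	streams := c.streams // snapshot shared state
	c.streams = nil
	c.closed = true
	c.mu.Unlock()

	for _, s := range streams {
		notify(s, err) // potentially blocking; runs with mu released
	}
}

func main() {
	c := &conn{streams: map[uint32]*stream{1: {id: 1}, 3: {id: 3}}}
	c.closeForError(errors.New("connection lost"), func(s *stream, err error) {
		fmt.Printf("stream %d: %v\n", s.id, err)
	})
}

The same shape reappears below in the read-loop cleanup and in the response-body flow-control paths.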
+ if cc.reqHeaderMu == nil { + panic("RoundTrip on initialized ClientConn") // for tests + } + select { + case cc.reqHeaderMu <- struct{}{}: + case <-req.Cancel: + return nil, false, errRequestCanceled + case <-ctx.Done(): + return nil, false, ctx.Err() + } + reqHeaderMuNeedsUnlock := true + defer func() { + if reqHeaderMuNeedsUnlock { + <-cc.reqHeaderMu + } + }() + cc.mu.Lock() if err := cc.awaitOpenSlotForRequest(req); err != nil { cc.mu.Unlock() @@ -1057,22 +1094,24 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf requestedGzip = true } + cs := cc.newStream() + cs.req = req + cs.trace = httptrace.ContextClientTrace(req.Context()) + cs.requestedGzip = requestedGzip + bodyWriter := cc.t.getBodyWriterState(cs, body) + cs.on100 = bodyWriter.on100 + cc.mu.Unlock() + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is // sent by writeRequestBody below, along with any Trailers, // again in form HEADERS{1}, CONTINUATION{0,}) + cc.wmu.Lock() hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) if err != nil { - cc.mu.Unlock() + cc.wmu.Unlock() return nil, false, err } - cs := cc.newStream() - cs.req = req - cs.trace = httptrace.ContextClientTrace(req.Context()) - cs.requestedGzip = requestedGzip - bodyWriter := cc.t.getBodyWriterState(cs, body) - cs.on100 = bodyWriter.on100 - defer func() { cc.wmu.Lock() werr := cc.werr @@ -1082,24 +1121,24 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf } }() - cc.wmu.Lock() endStream := !hasBody && !hasTrailers - werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) cc.wmu.Unlock() + <-cc.reqHeaderMu // release the new-request lock + reqHeaderMuNeedsUnlock = false traceWroteHeaders(cs.trace) - cc.mu.Unlock() - if werr != nil { + if err != nil { if hasBody { bodyWriter.cancel() } cc.forgetStreamID(cs.ID) // Don't bother sending a RST_STREAM (our write already failed; // no need to keep writing) - traceWroteRequest(cs.trace, werr) + traceWroteRequest(cs.trace, err) // TODO(dneil): An error occurred while writing the headers. // Should we return an error indicating that this request can be retried? - return nil, false, werr + return nil, false, err } var respHeaderTimer <-chan time.Time @@ -1116,7 +1155,6 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf readLoopResCh := cs.resc bodyWritten := false - ctx := req.Context() handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) { res := re.res @@ -1418,19 +1456,17 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( return nil } + cc.wmu.Lock() var trls []byte if hasTrailers { - cc.mu.Lock() trls, err = cc.encodeTrailers(req) - cc.mu.Unlock() if err != nil { + cc.wmu.Unlock() cc.writeStreamReset(cs.ID, ErrCodeInternal, err) cc.forgetStreamID(cs.ID) return err } } - - cc.wmu.Lock() defer cc.wmu.Unlock() // Two ways to send END_STREAM: either with trailers, or @@ -1480,7 +1516,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } -// requires cc.mu be held. +// requires cc.wmu be held. func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() @@ -1668,7 +1704,7 @@ func shouldSendReqContentLength(method string, contentLength int64) bool { } } -// requires cc.mu be held. +// requires cc.wmu be held. 
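In the roundTrip hunk above, header encoding now happens with wmu held, matching how trailers are written, so the per-connection HPACK encoder and hbuf are only ever touched under one lock. A hypothetical sketch of why the encoder and the frame writer need to share that lock (the writer type and its fields are stand-ins, not the real ClientConn):

package main

import (
	"bytes"
	"fmt"
	"sync"

	"golang.org/x/net/http2/hpack"
)

// writer bundles what wmu guards after this change: the HPACK encoder,
// the buffer it encodes into, and the sink the frames go to.
type writer struct {
	mu   sync.Mutex   // analogue of ClientConn.wmu
	hbuf bytes.Buffer // HPACK encoder output
	henc *hpack.Encoder
	out  bytes.Buffer // stand-in for the framer / network connection
}

func newWriter() *writer {
	w := &writer{}
	w.henc = hpack.NewEncoder(&w.hbuf)
	return w
}

// writeHeaders encodes and flushes inside a single critical section.
// The encoder's dynamic table is shared across requests, so encoding
// under one lock while another goroutine flushes hbuf under a different
// lock could interleave two requests' header blocks.
func (w *writer) writeHeaders(fields map[string]string) error {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.hbuf.Reset()
	for k, v := range fields {
		if err := w.henc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
			return err
		}
	}
	_, err := w.out.Write(w.hbuf.Bytes())
	return err
}

func main() {
	w := newWriter()
	if err := w.writeHeaders(map[string]string{":method": "GET", ":path": "/"}); err != nil {
		fmt.Println("encode:", err)
		return
	}
	fmt.Printf("wrote %d bytes of HPACK\n", w.out.Len())
}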
func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) { cc.hbuf.Reset() @@ -1817,7 +1853,11 @@ func (rl *clientConnReadLoop) cleanup() { } else if err == io.EOF { err = io.ErrUnexpectedEOF } - for _, cs := range cc.streams { + cc.closed = true + streams := cc.streams + cc.streams = nil + cc.mu.Unlock() + for _, cs := range streams { cs.bufPipe.CloseWithError(err) // no-op if already closed select { case cs.resc <- resAndError{err: err}: @@ -1825,7 +1865,7 @@ func (rl *clientConnReadLoop) cleanup() { } close(cs.done) } - cc.closed = true + cc.mu.Lock() cc.cond.Broadcast() cc.mu.Unlock() } @@ -2155,8 +2195,6 @@ func (b transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - defer cc.mu.Unlock() - var connAdd, streamAdd int32 // Check the conn-level first, before the stream-level. if v := cc.inflow.available(); v < transportDefaultConnFlow/2 { @@ -2173,6 +2211,8 @@ func (b transportResponseBody) Read(p []byte) (n int, err error) { cs.inflow.add(streamAdd) } } + cc.mu.Unlock() + if connAdd != 0 || streamAdd != 0 { cc.wmu.Lock() defer cc.wmu.Unlock() @@ -2198,19 +2238,25 @@ func (b transportResponseBody) Close() error { if unread > 0 || !serverSentStreamEnd { cc.mu.Lock() - cc.wmu.Lock() if !serverSentStreamEnd { - cc.fr.WriteRSTStream(cs.ID, ErrCodeCancel) cs.didReset = true } // Return connection-level flow control. if unread > 0 { cc.inflow.add(int32(unread)) + } + cc.mu.Unlock() + + cc.wmu.Lock() + if !serverSentStreamEnd { + cc.fr.WriteRSTStream(cs.ID, ErrCodeCancel) + } + // Return connection-level flow control. + if unread > 0 { cc.fr.WriteWindowUpdate(0, uint32(unread)) } cc.bw.Flush() cc.wmu.Unlock() - cc.mu.Unlock() } cs.bufPipe.BreakWithError(errClosedResponseBody) @@ -2288,6 +2334,10 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } if refund > 0 { cc.inflow.add(int32(refund)) + } + cc.mu.Unlock() + + if refund > 0 { cc.wmu.Lock() cc.fr.WriteWindowUpdate(0, uint32(refund)) if !didReset { @@ -2297,7 +2347,6 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc.bw.Flush() cc.wmu.Unlock() } - cc.mu.Unlock() if len(data) > 0 && !didReset { if _, err := cs.bufPipe.Write(data); err != nil { @@ -2358,6 +2407,23 @@ func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error { } func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { + cc := rl.cc + // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. + // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. 
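The transportResponseBody.Read, transportResponseBody.Close, and processData hunks above all split flow-control handling the same way: decide the WINDOW_UPDATE and RST_STREAM work while holding mu, release mu, and only then take wmu to write the frames. A reduced, hypothetical sketch of that split (connFlowTarget and the frame strings are stand-ins for the real framer):

package main

import (
	"fmt"
	"sync"
)

// connFlowTarget stands in for transportDefaultConnFlow.
const connFlowTarget = 1 << 20

// conn is a reduced stand-in for ClientConn.
type conn struct {
	mu        sync.Mutex // guards flow-control accounting
	wmu       sync.Mutex // guards the writer; taken only after mu is released
	connAvail int32      // connection-level receive window still advertised
	sent      []string   // frames "written" to the network
}

// updateWindow tops the connection window back up. The increment is
// computed under mu; the (potentially blocking) frame write happens under
// wmu only, so goroutines that need mu never wait on the network.
func (c *conn) updateWindow() {
	c.mu.Lock()
	var connAdd int32
	if c.connAvail < connFlowTarget/2 {
		connAdd = connFlowTarget - c.connAvail
		c.connAvail += connAdd
	}
	c.mu.Unlock()

	if connAdd != 0 {
		c.wmu.Lock()
		c.sent = append(c.sent, fmt.Sprintf("WINDOW_UPDATE stream=0 incr=%d", connAdd))
		c.wmu.Unlock()
	}
}

func main() {
	c := &conn{connAvail: 1000}
	c.updateWindow()
	fmt.Println(c.sent)
}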
+ cc.wmu.Lock() + defer cc.wmu.Unlock() + + if err := rl.processSettingsNoWrite(f); err != nil { + return err + } + if !f.IsAck() { + cc.fr.WriteSettingsAck() + cc.bw.Flush() + } + return nil +} + +func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -2420,12 +2486,7 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { cc.seenSettings = true } - cc.wmu.Lock() - defer cc.wmu.Unlock() - - cc.fr.WriteSettingsAck() - cc.bw.Flush() - return cc.werr + return nil } func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { diff --git a/transport_test.go b/transport_test.go index 97735fe..ab31640 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/hex" "errors" "flag" "fmt" @@ -3261,7 +3262,8 @@ func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) cc := &ClientConn{ - closed: true, + closed: true, + reqHeaderMu: make(chan struct{}, 1), } _, err := cc.RoundTrip(req) if err != errClientConnUnusable { @@ -4990,6 +4992,245 @@ func (rc *closeChecker) isClosed() error { return nil } +// A blockingWriteConn is a net.Conn that blocks in Write after some number of bytes are written. +type blockingWriteConn struct { + net.Conn + writeOnce sync.Once + writec chan struct{} // closed after the write limit is reached + unblockc chan struct{} // closed to unblock writes + count, limit int +} + +func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn { + return &blockingWriteConn{ + Conn: conn, + limit: limit, + writec: make(chan struct{}), + unblockc: make(chan struct{}), + } +} + +// wait waits until the conn blocks writing the limit+1st byte. +func (c *blockingWriteConn) wait() { + <-c.writec +} + +// unblock unblocks writes to the conn. +func (c *blockingWriteConn) unblock() { + close(c.unblockc) +} + +func (c *blockingWriteConn) Write(b []byte) (n int, err error) { + if c.count+len(b) > c.limit { + c.writeOnce.Do(func() { + close(c.writec) + }) + <-c.unblockc + } + n, err = c.Conn.Write(b) + c.count += n + return n, err +} + +// Write several requests to a ClientConn at the same time, looking for race conditions. 
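blockingWriteConn above gives the tests a net.Conn whose Write stalls after a byte limit, imitating a peer that stops reading. A standalone demo of the same idea over net.Pipe, showing the wait/unblock choreography the tests below rely on (limitConn is an illustrative reduction, not the test type itself):

package main

import (
	"fmt"
	"net"
	"sync"
)

// limitConn blocks Write once limit bytes have been written, until
// unblockc is closed. It is a standalone reduction of blockingWriteConn.
type limitConn struct {
	net.Conn
	once     sync.Once
	blocked  chan struct{} // closed when a write first stalls
	unblockc chan struct{} // close to let stalled writes proceed
	count    int
	limit    int
}

func (c *limitConn) Write(b []byte) (int, error) {
	if c.count+len(b) > c.limit {
		c.once.Do(func() { close(c.blocked) })
		<-c.unblockc
	}
	n, err := c.Conn.Write(b)
	c.count += n
	return n, err
}

func main() {
	client, server := net.Pipe()
	lc := &limitConn{
		Conn:     client,
		blocked:  make(chan struct{}),
		unblockc: make(chan struct{}),
		limit:    4,
	}

	// Drain the server side so writes under the limit complete.
	go func() {
		buf := make([]byte, 1024)
		for {
			if _, err := server.Read(buf); err != nil {
				return
			}
		}
	}()

	done := make(chan struct{})
	go func() {
		defer close(done)
		lc.Write([]byte("hi"))    // 2 bytes: under the limit, completes
		lc.Write([]byte("hello")) // would exceed the limit: stalls here
	}()

	<-lc.blocked // the writer is now stuck, like a request behind a full send buffer
	fmt.Println("write is blocked; other goroutines keep running")
	close(lc.unblockc) // let the stalled write finish
	<-done
	client.Close()
	server.Close()
}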
+// See golang.org/issue/48340 +func TestTransportFrameBufferReuse(t *testing.T) { + filler := hex.EncodeToString([]byte(randString(2048))) + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("Big"), filler; got != want { + t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want) + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("error reading request body: %v", err) + } + if got, want := string(b), filler; got != want { + t.Errorf("request body = %q, want %q", got, want) + } + if got, want := r.Trailer.Get("Big"), filler; got != want { + t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want) + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + var wg sync.WaitGroup + defer wg.Wait() + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(filler)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Big", filler) + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + }() + } + +} + +// Ensure that a request blocking while being written to the underlying net.Conn doesn't +// block access to the ClientConn pool. Test requests blocking while writing headers, the body, +// and trailers. +// See golang.org/issue/32388 +func TestTransportBlockingRequestWrite(t *testing.T) { + filler := hex.EncodeToString([]byte(randString(2048))) + for _, test := range []struct { + name string + req func(url string) (*http.Request, error) + }{{ + name: "headers", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Big", filler) + return req, err + }, + }, { + name: "body", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, strings.NewReader(filler)) + if err != nil { + return nil, err + } + return req, err + }, + }, { + name: "trailer", + req: func(url string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, strings.NewReader("body")) + if err != nil { + return nil, err + } + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + return req, err + }, + }} { + test := test + t.Run(test.name, func(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("Big"); v != "" && v != filler { + t.Errorf("request header mismatch") + } + if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler { + t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler) + } + if v := r.Trailer.Get("Big"); v != "" && v != filler { + t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler) + } + }, optOnlyServer, func(s *Server) { + s.MaxConcurrentStreams = 1 + }) + defer st.Close() + + // This Transport creates connections that block on writes after 1024 bytes. 
+ connc := make(chan *blockingWriteConn, 1) + connCount := 0 + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + connCount++ + c, err := tls.Dial(network, addr, cfg) + wc := newBlockingWriteConn(c, 1024) + select { + case connc <- wc: + default: + } + return wc, err + }, + } + defer tr.CloseIdleConnections() + + // Request 1: A small request to ensure we read the server MaxConcurrentStreams. + { + req, err := http.NewRequest("POST", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + } + + // Request 2: A large request that blocks while being written. + reqc := make(chan struct{}) + go func() { + defer close(reqc) + req, err := test.req(st.ts.URL) + if err != nil { + t.Error(err) + return + } + res, _ := tr.RoundTrip(req) + if res != nil && res.Body != nil { + res.Body.Close() + } + }() + conn := <-connc + conn.wait() // wait for the request to block + + // Request 3: A small request that is sent on a new connection, since request 2 + // is hogging the only available stream on the previous connection. + { + req, err := http.NewRequest("POST", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + } + + // Request 2 should still be blocking at this point. + select { + case <-reqc: + t.Errorf("request 2 unexpectedly completed") + default: + } + + conn.unblock() + <-reqc + + if connCount != 2 { + t.Errorf("created %v connections, want 1", connCount) + } + }) + } +} + func TestTransportCloseRequestBody(t *testing.T) { var statusCode int st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {