From d6a786627a1ca6c879a282949149775500b80231 Mon Sep 17 00:00:00 2001
From: Damien Neil
Date: Thu, 2 Sep 2021 13:22:38 -0700
Subject: [PATCH] [internal-branch.go1.17-vendor] http2: refactor request write flow

Move the entire request write into a new writeRequest function, which
runs as its own goroutine. The writeRequest function handles all
indefinitely-blocking operations (in particular, network writes), as
well as all post-request cleanup: closing the request body, sending a
RST_STREAM when necessary, releasing the concurrency slot held by the
stream, etc.

Consolidates several goroutines used to wait for stream slots, write
the body, and close response bodies. Ensures that RoundTrip does not
block past request cancelation.

Updates golang/go#49077

Change-Id: Iaf8bb3e17de89384b031ec4f324918b5720f5877
Reviewed-on: https://go-review.googlesource.com/c/net/+/353390
Trust: Damien Neil
Trust: Brad Fitzpatrick
Run-TryBot: Damien Neil
TryBot-Result: Go Bot
Reviewed-by: Brad Fitzpatrick
Reviewed-on: https://go-review.googlesource.com/c/net/+/357683
Reviewed-by: Dmitri Shuralyov
---
 client_conn_pool.go |   1 +
 pipe.go             |  11 +
 transport.go        | 961 ++++++++++++++++++++------------------
 transport_test.go   |  96 ++++-
 4 files changed, 516 insertions(+), 553 deletions(-)

diff --git a/client_conn_pool.go b/client_conn_pool.go
index 8fd95bb..7f817e2 100644
--- a/client_conn_pool.go
+++ b/client_conn_pool.go
@@ -84,6 +84,7 @@ func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
 }
 
 func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
+	// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
 	if isConnectionCloseRequest(req) && dialOnMiss {
 		// It gets its own connection.
 		traceGetConn(req, addr)
diff --git a/pipe.go b/pipe.go
index 2a5399e..c15b8a7 100644
--- a/pipe.go
+++ b/pipe.go
@@ -30,6 +30,17 @@ type pipeBuffer interface {
 	io.Reader
 }
 
+// setBuffer initializes the pipe buffer.
+// It has no effect if the pipe is already closed.
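+// The read loop calls setBuffer once response headers arrive and the
+// expected payload size is first known; skipping the update on a closed
+// pipe avoids reviving a stream that has already been aborted or reset.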
+func (p *pipe) setBuffer(b pipeBuffer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil || p.breakErr != nil { + return + } + p.b = b +} + func (p *pipe) Len() int { p.mu.Lock() defer p.mu.Unlock() diff --git a/transport.go b/transport.go index d79ace7..40655a1 100644 --- a/transport.go +++ b/transport.go @@ -297,52 +297,42 @@ type clientStream struct { req *http.Request trace *httptrace.ClientTrace // or nil ID uint32 - resc chan resAndError bufPipe pipe // buffered pipe with the flow-controlled response payload - startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool - on100 func() // optional code to run if get a 100 continue response + + abortOnce sync.Once + abort chan struct{} // closed to signal stream should end immediately + abortErr error // set if abort is closed + + peerClosed chan struct{} // closed when the peer sends an END_STREAM flag + donec chan struct{} // closed after the stream is in the closed state + on100 chan struct{} // buffered; written to if a 100 is received + + respHeaderRecv chan struct{} // closed when headers are received + res *http.Response // set if respHeaderRecv is closed flow flow // guarded by cc.mu inflow flow // guarded by cc.mu bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu - didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu - peerReset chan struct{} // closed on peer reset - resetErr error // populated before peerReset is closed - - done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu + // owned by writeRequest: + sentEndStream bool // sent an END_STREAM flag to the peer + sentHeaders bool // owned by clientConnReadLoop: firstByte bool // got the first response byte pastHeaders bool // got first MetaHeadersFrame (actual headers) pastTrailers bool // got optional second MetaHeadersFrame (trailers) num1xx uint8 // number of 1xx responses seen + readClosed bool // peer sent an END_STREAM flag + readAborted bool // read loop reset the stream trailer http.Header // accumulated trailers resTrailer *http.Header // client's Response.Trailer } -// awaitRequestCancel waits for the user to cancel a request or for the done -// channel to be signaled. A non-nil error is returned only if the request was -// canceled. -func awaitRequestCancel(req *http.Request, done <-chan struct{}) error { - ctx := req.Context() - if req.Cancel == nil && ctx.Done() == nil { - return nil - } - select { - case <-req.Cancel: - return errRequestCanceled - case <-ctx.Done(): - return ctx.Err() - case <-done: - return nil - } -} - var got1xxFuncForTests func(int, textproto.MIMEHeader) error // get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, @@ -354,50 +344,24 @@ func (cs *clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error return traceGot1xxResponseFunc(cs.trace) } -// awaitRequestCancel waits for the user to cancel a request, its context to -// expire, or for the request to be done (any way it might be removed from the -// cc.streams map: peer reset, successful completion, TCP connection breakage, -// etc). If the request is canceled, then cs will be canceled and closed. 
-func (cs *clientStream) awaitRequestCancel(req *http.Request) { - if err := awaitRequestCancel(req, cs.done); err != nil { - cs.cancelStream() - cs.bufPipe.CloseWithError(err) - } -} - -func (cs *clientStream) cancelStream() { - cc := cs.cc - cc.mu.Lock() - didReset := cs.didReset - cs.didReset = true - cc.mu.Unlock() - - if !didReset { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - cc.forgetStreamID(cs.ID) - } +func (cs *clientStream) abortStream(err error) { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) } -// checkResetOrDone reports any error sent in a RST_STREAM frame by the -// server, or errStreamClosed if the stream is complete. -func (cs *clientStream) checkResetOrDone() error { - select { - case <-cs.peerReset: - return cs.resetErr - case <-cs.done: - return errStreamClosed - default: - return nil +func (cs *clientStream) abortStreamLocked(err error) { + cs.abortOnce.Do(func() { + cs.abortErr = err + close(cs.abort) + }) + // TODO(dneil): Clean up tests where cs.cc.cond is nil. + if cs.cc.cond != nil { + // Wake up writeRequestBody if it is waiting on flow control. + cs.cc.cond.Broadcast() } } -func (cs *clientStream) getStartedWrite() bool { - cc := cs.cc - cc.mu.Lock() - defer cc.mu.Unlock() - return cs.startedWrite -} - func (cs *clientStream) abortRequestBodyWrite(err error) { if err == nil { panic("nil error") @@ -407,10 +371,6 @@ func (cs *clientStream) abortRequestBodyWrite(err error) { if cs.stopReqBody == nil { cs.stopReqBody = err cc.cond.Broadcast() - // Close the body after releasing the mutex, in case it blocks. - if body := cs.req.Body; body != nil { - defer body.Close() - } } cc.mu.Unlock() } @@ -499,10 +459,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) traceGotConn(req, cc, reused) - body := req.Body - res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) + res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { - if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { + if req, err = shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { continue @@ -519,11 +478,6 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res } if err != nil { t.vlogf("RoundTrip failure: %v", err) - // If the error occurred after the body write started, - // the body writer will close the body. Otherwise, do so here. - if body != nil && !gotErrAfterReqBodyWrite { - body.Close() - } return nil, err } return res, nil @@ -549,7 +503,7 @@ var ( // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*http.Request, error) { +func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) { if !canRetryError(err) { return nil, err } @@ -562,7 +516,6 @@ func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*htt // If the request body can be reset back to its original // state via the optional req.GetBody, do that. 
if req.GetBody != nil { - req.Body.Close() body, err := req.GetBody() if err != nil { return nil, err @@ -574,10 +527,8 @@ func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*htt // The Request.Body can't reset back to the beginning, but we // don't seem to have started to read from it yet, so reuse - // the request directly. The "afterBodyWrite" means the - // bodyWrite process has started, which becomes true before - // the first Read. - if !afterBodyWrite { + // the request directly. + if err == errClientConnUnusable { return req, nil } @@ -769,10 +720,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - select { - case cs.resc <- resAndError{err: errClientConnGotGoAway}: - default: - } + cs.abortStreamLocked(errClientConnGotGoAway) } } } @@ -864,7 +812,7 @@ func (cc *ClientConn) onIdleTimeout() { func (cc *ClientConn) closeIfIdle() { cc.mu.Lock() - if len(cc.streams) > 0 { + if len(cc.streams) > 0 || cc.streamsReserved > 0 { cc.mu.Unlock() return } @@ -887,7 +835,7 @@ func (cc *ClientConn) isDoNotReuseAndIdle() bool { var shutdownEnterWaitStateHook = func() {} -// Shutdown gracefully close the client connection, waiting for running streams to complete. +// Shutdown gracefully closes the client connection, waiting for running streams to complete. func (cc *ClientConn) Shutdown(ctx context.Context) error { if err := cc.sendGoAway(); err != nil { return err @@ -952,20 +900,10 @@ func (cc *ClientConn) sendGoAway() error { // err is sent to streams. func (cc *ClientConn) closeForError(err error) error { cc.mu.Lock() - streams := cc.streams - cc.streams = nil cc.closed = true - cc.mu.Unlock() - - for _, cs := range streams { - select { - case cs.resc <- resAndError{err: err}: - default: - } - cs.bufPipe.CloseWithError(err) + for _, cs := range cc.streams { + cs.abortStreamLocked(err) } - - cc.mu.Lock() defer cc.cond.Broadcast() defer cc.mu.Unlock() return cc.tconn.Close() @@ -1059,66 +997,119 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { - resp, _, err := cc.roundTrip(req) - return resp, err -} - -func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) { ctx := req.Context() - if err := checkConnHeaders(req); err != nil { - cc.decrStreamReservations() - return nil, false, err + cs := &clientStream{ + cc: cc, + req: req, + trace: httptrace.ContextClientTrace(req.Context()), + peerClosed: make(chan struct{}), + abort: make(chan struct{}), + respHeaderRecv: make(chan struct{}), + donec: make(chan struct{}), } - if cc.idleTimer != nil { - cc.idleTimer.Stop() + go cs.doRequest() + + waitDone := func() error { + select { + case <-cs.donec: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-req.Cancel: + return errRequestCanceled + } } - trailers, err := commaSeparatedTrailers(req) - if err != nil { - cc.decrStreamReservations() - return nil, false, err + for { + select { + case <-cs.respHeaderRecv: + res := cs.res + if res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. 
This is a
+				// heuristic to avoid adding knobs to Transport. Hopefully
+				// we can keep it.
+				cs.abortRequestBodyWrite(errStopReqBodyWrite)
+			}
+			res.Request = req
+			res.TLS = cc.tlsState
+			if res.Body == noBody && actualContentLength(req) == 0 {
+				// If there isn't a request or response body still being
+				// written, then wait for the stream to be closed before
+				// RoundTrip returns.
+				if err := waitDone(); err != nil {
+					return nil, err
+				}
+			}
+			return res, nil
+		case <-cs.abort:
+			waitDone()
+			return nil, cs.abortErr
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-req.Cancel:
+			return nil, errRequestCanceled
+		}
+	}
+}
+
+// doRequest runs for the duration of the request lifetime.
+//
+// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
+func (cs *clientStream) doRequest() {
+	err := cs.writeRequest()
+	cs.cleanupWriteRequest(err)
+}
+
+// writeRequest sends a request.
+//
+// It returns nil after the request is written, the response read,
+// and the request stream is half-closed by the peer.
+//
+// It returns non-nil if the request ends otherwise.
+// If the returned error is StreamError, the error Code may be used in resetting the stream.
+func (cs *clientStream) writeRequest() (err error) {
+	cc := cs.cc
+	req := cs.req
+	ctx := req.Context()
+
+	if err := checkConnHeaders(cs.req); err != nil {
+		return err
 	}
-	hasTrailers := trailers != ""
 
 	// Acquire the new-request lock by writing to reqHeaderMu.
 	// This lock guards the critical section covering allocating a new stream ID
 	// (requires mu) and creating the stream (requires wmu).
 	if cc.reqHeaderMu == nil {
-		panic("RoundTrip on initialized ClientConn") // for tests
+		panic("RoundTrip on uninitialized ClientConn") // for tests
 	}
 	select {
 	case cc.reqHeaderMu <- struct{}{}:
 	case <-req.Cancel:
-		cc.decrStreamReservations()
-		return nil, false, errRequestCanceled
+		return errRequestCanceled
 	case <-ctx.Done():
-		cc.decrStreamReservations()
-		return nil, false, ctx.Err()
+		return ctx.Err()
 	}
-	reqHeaderMuNeedsUnlock := true
-	defer func() {
-		if reqHeaderMuNeedsUnlock {
-			<-cc.reqHeaderMu
-		}
-	}()
 
 	cc.mu.Lock()
-	cc.decrStreamReservationsLocked()
-	if req.URL == nil {
-		cc.mu.Unlock()
-		return nil, false, errNilRequestURL
+	if cc.idleTimer != nil {
+		cc.idleTimer.Stop()
 	}
-	if err := cc.awaitOpenSlotForRequest(req); err != nil {
+	cc.decrStreamReservationsLocked()
+	if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil {
 		cc.mu.Unlock()
-		return nil, false, err
+		<-cc.reqHeaderMu
+		return err
 	}
-
-	body := req.Body
-	contentLen := actualContentLength(req)
-	hasBody := contentLen != 0
+	cc.addStreamLocked(cs) // assigns stream ID
+	cc.mu.Unlock()
 
 	// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
-	var requestedGzip bool
 	if !cc.t.disableCompression() &&
 		req.Header.Get("Accept-Encoding") == "" &&
 		req.Header.Get("Range") == "" &&
@@ -1135,185 +1126,218 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
 		// We don't request gzip if the request is for a range, since
 		// auto-decoding a portion of a gzipped document will just fail
 		// anyway.
See https://golang.org/issue/8923
-		requestedGzip = true
+		cs.requestedGzip = true
 	}
 
-	cs := cc.newStream()
-	cs.req = req
-	cs.trace = httptrace.ContextClientTrace(req.Context())
-	cs.requestedGzip = requestedGzip
-	bodyWriter := cc.t.getBodyWriterState(cs, body)
-	cs.on100 = bodyWriter.on100
-	cc.mu.Unlock()
+	continueTimeout := cc.t.expectContinueTimeout()
+	if continueTimeout != 0 {
+		if !httpguts.HeaderValuesContainsToken(cs.req.Header["Expect"], "100-continue") {
+			continueTimeout = 0
+		} else {
+			cs.on100 = make(chan struct{}, 1)
+		}
+	}
+
+	err = cs.encodeAndWriteHeaders()
+	<-cc.reqHeaderMu
+	if err != nil {
+		return err
+	}
+
+	hasBody := actualContentLength(cs.req) != 0
+	if !hasBody {
+		cs.sentEndStream = true
+	} else {
+		if continueTimeout != 0 {
+			traceWait100Continue(cs.trace)
+			timer := time.NewTimer(continueTimeout)
+			select {
+			case <-timer.C:
+				err = nil
+			case <-cs.on100:
+				err = nil
+			case <-cs.abort:
+				err = cs.abortErr
+			case <-ctx.Done():
+				err = ctx.Err()
+			case <-req.Cancel:
+				err = errRequestCanceled
+			}
+			timer.Stop()
+			if err != nil {
+				traceWroteRequest(cs.trace, err)
+				return err
+			}
+		}
+
+		if err = cs.writeRequestBody(req.Body); err != nil {
+			if err != errStopReqBodyWrite {
+				traceWroteRequest(cs.trace, err)
+				return err
+			}
+		} else {
+			cs.sentEndStream = true
+		}
+	}
+
+	traceWroteRequest(cs.trace, err)
+
+	var respHeaderTimer <-chan time.Time
+	var respHeaderRecv chan struct{}
+	if d := cc.responseHeaderTimeout(); d != 0 {
+		timer := time.NewTimer(d)
+		defer timer.Stop()
+		respHeaderTimer = timer.C
+		respHeaderRecv = cs.respHeaderRecv
+	}
+	// Wait until the peer half-closes its end of the stream,
+	// or until the request is aborted (via context, error, or otherwise),
+	// whichever comes first.
+	for {
+		select {
+		case <-cs.peerClosed:
+			return nil
+		case <-respHeaderTimer:
+			return errTimeout
+		case <-respHeaderRecv:
+			respHeaderRecv = nil
+			respHeaderTimer = nil // keep waiting for END_STREAM
+		case <-cs.abort:
+			return cs.abortErr
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-req.Cancel:
+			return errRequestCanceled
+		}
+	}
+}
+
+func (cs *clientStream) encodeAndWriteHeaders() error {
+	cc := cs.cc
+	req := cs.req
+	ctx := req.Context()
+
+	cc.wmu.Lock()
+	defer cc.wmu.Unlock()
+
+	// If the request was canceled while waiting for cc.mu, just quit.
+	select {
+	case <-cs.abort:
+		return cs.abortErr
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-req.Cancel:
+		return errRequestCanceled
+	default:
+	}
+
+	// Encode headers.
+	//
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
 	// sent by writeRequestBody below, along with any Trailers,
 	// again in form HEADERS{1}, CONTINUATION{0,})
-	cc.wmu.Lock()
-	hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
+	trailers, err := commaSeparatedTrailers(cs.req)
 	if err != nil {
-		cc.wmu.Unlock()
-		return nil, false, err
+		return err
+	}
+	hasTrailers := trailers != ""
+	contentLen := actualContentLength(cs.req)
+	hasBody := contentLen != 0
+	hdrs, err := cc.encodeHeaders(cs.req, cs.requestedGzip, trailers, contentLen)
+	if err != nil {
+		return err
 	}
 
-	defer func() {
-		cc.wmu.Lock()
-		werr := cc.werr
-		cc.wmu.Unlock()
-		if werr != nil {
-			cc.Close()
-		}
-	}()
-
+	// Write the request.
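+	// A single HEADERS frame (plus CONTINUATION frames as needed) opens the
+	// stream; when there is no body and no trailers, the same frame also
+	// carries END_STREAM, half-closing our side of the stream in one write.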
endStream := !hasBody && !hasTrailers + cs.sentHeaders = true err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) - cc.wmu.Unlock() - <-cc.reqHeaderMu // release the new-request lock - reqHeaderMuNeedsUnlock = false traceWroteHeaders(cs.trace) + return err +} - if err != nil { - if hasBody { - bodyWriter.cancel() +// cleanupWriteRequest performs post-request tasks. +// +// If err (the result of writeRequest) is non-nil and the stream is not closed, +// cleanupWriteRequest will send a reset to the peer. +func (cs *clientStream) cleanupWriteRequest(err error) { + cc := cs.cc + req := cs.req + + if cs.ID == 0 { + // We were canceled before creating the stream, so return our reservation. + cc.decrStreamReservations() + } + + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + if req.Body != nil { + if e := req.Body.Close(); err == nil { + err = e } - cc.forgetStreamID(cs.ID) - // Don't bother sending a RST_STREAM (our write already failed; - // no need to keep writing) - traceWroteRequest(cs.trace, err) - // TODO(dneil): An error occurred while writing the headers. - // Should we return an error indicating that this request can be retried? - return nil, false, err } - var respHeaderTimer <-chan time.Time - if hasBody { - bodyWriter.scheduleBodyWrite() - } else { - traceWroteRequest(cs.trace, nil) - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C - } - } - - readLoopResCh := cs.resc - bodyWritten := false - - handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) { - res := re.res - if re.err != nil || res.StatusCode > 299 { - // On error or status code 3xx, 4xx, 5xx, etc abort any - // ongoing write, assuming that the server doesn't care - // about our request body. If the server replied with 1xx or - // 2xx, however, then assume the server DOES potentially - // want our body (e.g. full-duplex streaming: - // golang.org/issue/13444). If it turns out the server - // doesn't, they'll RST_STREAM us soon enough. This is a - // heuristic to avoid adding knobs to Transport. Hopefully - // we can keep it. - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWrite) - if hasBody && !bodyWritten { - <-bodyWriter.resc + if err != nil && cs.sentEndStream { + // If the connection is closed immediately after the response is read, + // we may be aborted before finishing up here. If the stream was closed + // cleanly on both sides, there is no error. 
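+		// (We only treat the error as spurious when we sent END_STREAM
+		// ourselves; if our half of the stream is still open, the request
+		// did not complete and the error must stand.)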
+		select {
+		case <-cs.peerClosed:
+			err = nil
+		default:
+		}
+	}
+	if err != nil {
+		cs.abortStream(err) // possibly redundant, but harmless
+		if cs.sentHeaders {
+			if se, ok := err.(StreamError); ok {
+				if se.Cause != errFromPeer {
+					cc.writeStreamReset(cs.ID, se.Code, err)
+				}
+			} else {
+				cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
 			}
 		}
-		if re.err != nil {
-			cc.forgetStreamID(cs.ID)
-			return nil, cs.getStartedWrite(), re.err
+		cs.bufPipe.CloseWithError(err) // no-op if already closed
+	} else {
+		if cs.sentHeaders && !cs.sentEndStream {
+			cc.writeStreamReset(cs.ID, ErrCodeNo, nil)
 		}
-		res.Request = req
-		res.TLS = cc.tlsState
-		return res, false, nil
+		cs.bufPipe.CloseWithError(errRequestCanceled)
 	}
-
-	handleError := func(err error) (*http.Response, bool, error) {
-		if !hasBody || bodyWritten {
-			cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-		} else {
-			bodyWriter.cancel()
-			cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
-			<-bodyWriter.resc
-		}
+	if cs.ID != 0 {
 		cc.forgetStreamID(cs.ID)
-		return nil, cs.getStartedWrite(), err
 	}
+	close(cs.donec)
 
-	for {
-		select {
-		case re := <-readLoopResCh:
-			return handleReadLoopResponse(re)
-		case <-respHeaderTimer:
-			return handleError(errTimeout)
-		case <-ctx.Done():
-			return handleError(ctx.Err())
-		case <-req.Cancel:
-			return handleError(errRequestCanceled)
-		case <-cs.peerReset:
-			// processResetStream already removed the
-			// stream from the streams map; no need for
-			// forgetStreamID.
-			return nil, cs.getStartedWrite(), cs.resetErr
-		case err := <-bodyWriter.resc:
-			bodyWritten = true
-			// Prefer the read loop's response, if available. Issue 16102.
-			select {
-			case re := <-readLoopResCh:
-				return handleReadLoopResponse(re)
-			default:
-			}
-			if err != nil {
-				cc.forgetStreamID(cs.ID)
-				return nil, cs.getStartedWrite(), err
-			}
-			if d := cc.responseHeaderTimeout(); d != 0 {
-				timer := time.NewTimer(d)
-				defer timer.Stop()
-				respHeaderTimer = timer.C
-			}
-		}
+	cc.wmu.Lock()
+	werr := cc.werr
+	cc.wmu.Unlock()
+	if werr != nil {
+		cc.Close()
 	}
 }
 
-// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
+// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams.
 // Must hold cc.mu.
-func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
-	var waitingForConn chan struct{}
-	var waitingForConnErr error // guarded by cc.mu
+func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
 	for {
 		cc.lastActive = time.Now()
 		if cc.closed || !cc.canTakeNewRequestLocked() {
-			if waitingForConn != nil {
-				close(waitingForConn)
-			}
 			return errClientConnUnusable
 		}
 		cc.lastIdle = time.Time{}
 		if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
-			if waitingForConn != nil {
-				close(waitingForConn)
-			}
 			return nil
 		}
-		// Unfortunately, we cannot wait on a condition variable and channel at
-		// the same time, so instead, we spin up a goroutine to check if the
-		// request is canceled while we wait for a slot to open in the connection.
- if waitingForConn == nil { - waitingForConn = make(chan struct{}) - go func() { - if err := awaitRequestCancel(req, waitingForConn); err != nil { - cc.mu.Lock() - waitingForConnErr = err - cc.cond.Broadcast() - cc.mu.Unlock() - } - }() - } cc.pendingRequests++ cc.cond.Wait() cc.pendingRequests-- - if waitingForConnErr != nil { - return waitingForConnErr + select { + case <-cs.abort: + return cs.abortErr + default: } } } @@ -1340,10 +1364,6 @@ func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize cc.fr.WriteContinuation(streamID, endHeaders, chunk) } } - // TODO(bradfitz): this Flush could potentially block (as - // could the WriteHeaders call(s) above), which means they - // wouldn't respond to Request.Cancel being readable. That's - // rare, but this should probably be in a goroutine. cc.bw.Flush() return cc.werr } @@ -1385,28 +1405,10 @@ func (cs *clientStream) frameScratchBufferLen(maxFrameSize int) int { var bufPool sync.Pool // of *[]byte -func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { +func (cs *clientStream) writeRequestBody(body io.Reader) (err error) { cc := cs.cc sentEnd := false // whether we sent the final DATA frame w/ END_STREAM - defer func() { - traceWroteRequest(cs.trace, err) - // TODO: write h12Compare test showing whether - // Request.Body is closed by the Transport, - // and in multiple cases: server replies <=299 and >299 - // while still writing request body - var cerr error - cc.mu.Lock() - if cs.stopReqBody == nil { - cs.stopReqBody = errStopReqBodyWrite - cerr = bodyCloser.Close() - } - cc.mu.Unlock() - if err == nil { - err = cerr - } - }() - req := cs.req hasTrailers := req.Trailer != nil remainLen := actualContentLength(req) @@ -1447,7 +1449,6 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( } if remainLen < 0 { err = errReqBodyTooLong - cc.writeStreamReset(cs.ID, ErrCodeCancel, err) return err } } @@ -1455,7 +1456,6 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( sawEOF = true err = nil } else if err != nil { - cc.writeStreamReset(cs.ID, ErrCodeCancel, err) return err } @@ -1467,7 +1467,6 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( case err == errStopReqBodyWrite: return err case err == errStopReqBodyWriteAndCancel: - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) return err case err != nil: return err @@ -1506,8 +1505,6 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( trls, err = cc.encodeTrailers(req) if err != nil { cc.wmu.Unlock() - cc.writeStreamReset(cs.ID, ErrCodeInternal, err) - cc.forgetStreamID(cs.ID) return err } } @@ -1532,6 +1529,8 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( // if the stream is dead. func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { cc := cs.cc + req := cs.req + ctx := req.Context() cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1541,8 +1540,14 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if cs.stopReqBody != nil { return 0, cs.stopReqBody } - if err := cs.checkResetOrDone(); err != nil { - return 0, err + select { + case <-cs.abort: + return 0, cs.abortErr + case <-ctx.Done(): + return 0, ctx.Err() + case <-req.Cancel: + return 0, errRequestCanceled + default: } if a := cs.flow.available(); a > 0 { take := a @@ -1798,51 +1803,51 @@ type resAndError struct { } // requires cc.mu be held. 
-func (cc *ClientConn) newStream() *clientStream { - cs := &clientStream{ - cc: cc, - ID: cc.nextStreamID, - resc: make(chan resAndError, 1), - peerReset: make(chan struct{}), - done: make(chan struct{}), - } +func (cc *ClientConn) addStreamLocked(cs *clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) cs.inflow.add(transportDefaultStreamFlow) cs.inflow.setConnFlow(&cc.inflow) + cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs - return cs + if cs.ID == 0 { + panic("assigned stream ID 0") + } } func (cc *ClientConn) forgetStreamID(id uint32) { - cc.streamByID(id, true) -} - -func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream { cc.mu.Lock() - defer cc.mu.Unlock() - cs := cc.streams[id] - if andRemove && cs != nil && !cc.closed { - cc.lastActive = time.Now() - delete(cc.streams, id) - if len(cc.streams) == 0 && cc.idleTimer != nil { - cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = time.Now() - } - close(cs.done) - // Wake up checkResetOrDone via clientStream.awaitFlowControl and - // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + slen := len(cc.streams) + delete(cc.streams, id) + if len(cc.streams) != slen-1 { + panic("forgetting unknown stream id") + } + cc.lastActive = time.Now() + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + // Wake up writeRequestBody via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { + if VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) + } + cc.closed = true + defer cc.tconn.Close() } - return cs + + cc.mu.Unlock() } // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type clientConnReadLoop struct { - _ incomparable - cc *ClientConn - closeWhenIdle bool + _ incomparable + cc *ClientConn } // readLoop runs in its own goroutine and reads and dispatches frames. 
@@ -1903,26 +1908,15 @@ func (rl *clientConnReadLoop) cleanup() { err = io.ErrUnexpectedEOF } cc.closed = true - streams := cc.streams - cc.streams = nil - cc.mu.Unlock() - for _, cs := range streams { - cs.bufPipe.CloseWithError(err) // no-op if already closed - select { - case cs.resc <- resAndError{err: err}: - default: - } - close(cs.done) + for _, cs := range cc.streams { + cs.abortStreamLocked(err) } - cc.mu.Lock() cc.cond.Broadcast() cc.mu.Unlock() } func (rl *clientConnReadLoop) run() error { cc := rl.cc - rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse - gotReply := false // ever saw a HEADERS reply gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout var t *time.Timer @@ -1939,9 +1933,7 @@ func (rl *clientConnReadLoop) run() error { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(StreamError); ok { - if cs := cc.streamByID(se.StreamID, false); cs != nil { - cs.cc.writeStreamReset(cs.ID, se.Code, err) - cs.cc.forgetStreamID(cs.ID) + if cs := rl.streamByID(se.StreamID); cs != nil { if se.Cause == nil { se.Cause = cc.fr.errDetail } @@ -1961,22 +1953,16 @@ func (rl *clientConnReadLoop) run() error { } gotSettings = true } - maybeIdle := false // whether frame might transition us to idle switch f := f.(type) { case *MetaHeadersFrame: err = rl.processHeaders(f) - maybeIdle = true - gotReply = true case *DataFrame: err = rl.processData(f) - maybeIdle = true case *GoAwayFrame: err = rl.processGoAway(f) - maybeIdle = true case *RSTStreamFrame: err = rl.processResetStream(f) - maybeIdle = true case *SettingsFrame: err = rl.processSettings(f) case *PushPromiseFrame: @@ -1994,38 +1980,24 @@ func (rl *clientConnReadLoop) run() error { } return err } - if rl.closeWhenIdle && gotReply && maybeIdle { - cc.closeIfIdle() - } } } func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { - cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if cs == nil { // We'd get here if we canceled a request while the // server had its response still in flight. So if this // was just something we canceled, ignore it. return nil } - if f.StreamEnded() { - // Issue 20521: If the stream has ended, streamByID() causes - // clientStream.done to be closed, which causes the request's bodyWriter - // to be closed with an errStreamClosed, which may be received by - // clientConn.RoundTrip before the result of processing these headers. - // Deferring stream closure allows the header processing to occur first. - // clientConn.RoundTrip may still receive the bodyWriter error first, but - // the fix for issue 16102 prioritises any response. - // - // Issue 22413: If there is no request body, we should close the - // stream before writing to cs.resc so that the stream is closed - // immediately once RoundTrip returns. - if cs.req.Body != nil { - defer cc.forgetStreamID(f.StreamID) - } else { - cc.forgetStreamID(f.StreamID) - } + if cs.readClosed { + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeProtocol, + Cause: errors.New("protocol error: headers after END_STREAM"), + }) + return nil } if !cs.firstByte { if cs.trace != nil { @@ -2049,9 +2021,11 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { return err } // Any other error type is a stream error. 
- cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err) - cc.forgetStreamID(cs.ID) - cs.resc <- resAndError{err: err} + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeProtocol, + Cause: err, + }) return nil // return nil from process* funcs to keep conn alive } if res == nil { @@ -2059,7 +2033,11 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { return nil } cs.resTrailer = &res.Trailer - cs.resc <- resAndError{res: res} + cs.res = res + close(cs.respHeaderRecv) + if f.StreamEnded() { + rl.endStream(cs) + } return nil } @@ -2121,6 +2099,9 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra } if statusCode >= 100 && statusCode <= 199 { + if f.StreamEnded() { + return nil, errors.New("1xx informational response with END_STREAM flag") + } cs.num1xx++ const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http if cs.num1xx > max1xxResponses { @@ -2133,8 +2114,9 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra } if statusCode == 100 { traceGot100Continue(cs.trace) - if cs.on100 != nil { - cs.on100() // forces any write delay timer to fire + select { + case cs.on100 <- struct{}{}: + default: } } cs.pastHeaders = false // do it all again @@ -2163,10 +2145,9 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra return res, nil } - cs.bufPipe = pipe{b: &dataBuffer{expected: res.ContentLength}} + cs.bufPipe.setBuffer(&dataBuffer{expected: res.ContentLength}) cs.bytesRemain = res.ContentLength res.Body = transportResponseBody{cs} - go cs.awaitRequestCancel(cs.req) if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") @@ -2226,7 +2207,7 @@ func (b transportResponseBody) Read(p []byte) (n int, err error) { n = int(cs.bytesRemain) if err == nil { err = errors.New("net/http: server replied with more than declared Content-Length; truncated") - cc.writeStreamReset(cs.ID, ErrCodeProtocol, err) + cs.abortStream(err) } cs.readErr = err return int(cs.bytesRemain), err @@ -2282,14 +2263,9 @@ func (b transportResponseBody) Close() error { cs := b.cs cc := cs.cc - serverSentStreamEnd := cs.bufPipe.Err() == io.EOF unread := cs.bufPipe.Len() - - if unread > 0 || !serverSentStreamEnd { + if unread > 0 { cc.mu.Lock() - if !serverSentStreamEnd { - cs.didReset = true - } // Return connection-level flow control. if unread > 0 { cc.inflow.add(int32(unread)) @@ -2297,9 +2273,6 @@ func (b transportResponseBody) Close() error { cc.mu.Unlock() cc.wmu.Lock() - if !serverSentStreamEnd { - cc.fr.WriteRSTStream(cs.ID, ErrCodeCancel) - } // Return connection-level flow control. 
if unread > 0 { cc.fr.WriteWindowUpdate(0, uint32(unread)) @@ -2309,16 +2282,21 @@ func (b transportResponseBody) Close() error { } cs.bufPipe.BreakWithError(errClosedResponseBody) - cc.forgetStreamID(cs.ID) + cs.abortStream(errClosedResponseBody) + + select { + case <-cs.donec: + case <-cs.req.Context().Done(): + return cs.req.Context().Err() + case <-cs.req.Cancel: + return errRequestCanceled + } return nil } func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, f.StreamEnded()) - if f.StreamEnded() && cc.isDoNotReuseAndIdle() { - rl.closeWhenIdle = true - } + cs := rl.streamByID(f.StreamID) data := f.Data() if cs == nil { cc.mu.Lock() @@ -2347,6 +2325,14 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { } return nil } + if cs.readClosed { + cc.logf("protocol error: received DATA after END_STREAM") + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeProtocol, + }) + return nil + } if !cs.firstByte { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ @@ -2378,12 +2364,18 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { if pad := int(f.Length) - len(data); pad > 0 { refund += pad } - // Return len(data) now if the stream is already closed, - // since data will never be read. - didReset := cs.didReset - if didReset { - refund += len(data) + + didReset := false + var err error + if len(data) > 0 { + if _, err = cs.bufPipe.Write(data); err != nil { + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset = true + refund += len(data) + } } + if refund > 0 { cc.inflow.add(int32(refund)) if !didReset { @@ -2402,11 +2394,9 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc.wmu.Unlock() } - if len(data) > 0 && !didReset { - if _, err := cs.bufPipe.Write(data); err != nil { - rl.endStreamError(cs, err) - return err - } + if err != nil { + rl.endStreamError(cs, err) + return nil } } @@ -2419,24 +2409,26 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { func (rl *clientConnReadLoop) endStream(cs *clientStream) { // TODO: check that any declared content-length matches, like // server.go's (*stream).endStream method. 
- rl.endStreamError(cs, nil) + if !cs.readClosed { + cs.readClosed = true + cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) + close(cs.peerClosed) + } } func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) { - var code func() - if err == nil { - err = io.EOF - code = cs.copyTrailers - } - if isConnectionCloseRequest(cs.req) { - rl.closeWhenIdle = true - } - cs.bufPipe.closeWithErrorAndCode(err, code) + cs.readAborted = true + cs.abortStream(err) +} - select { - case cs.resc <- resAndError{err: err}: - default: +func (rl *clientConnReadLoop) streamByID(id uint32) *clientStream { + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs := rl.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs } + return nil } func (cs *clientStream) copyTrailers() { @@ -2545,7 +2537,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if f.StreamID != 0 && cs == nil { return nil } @@ -2565,33 +2557,19 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { } func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { - cc := rl.cc - cs := cc.streamByID(f.StreamID, true) + cs := rl.streamByID(f.StreamID) if cs == nil { // TODO: return error if server tries to RST_STEAM an idle stream return nil } - if cc.isDoNotReuseAndIdle() { - rl.closeWhenIdle = true - } - select { - case <-cs.peerReset: - // Already reset. - // This is the only goroutine - // which closes this, so there - // isn't a race. - default: - serr := streamError(cs.ID, f.ErrCode) - if f.ErrCode == ErrCodeProtocol { - rl.cc.SetDoNotReuse() - serr.Cause = errFromPeer - rl.closeWhenIdle = true - } - cs.resetErr = serr - close(cs.peerReset) - cs.bufPipe.CloseWithError(serr) - cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl + serr := streamError(cs.ID, f.ErrCode) + serr.Cause = errFromPeer + if f.ErrCode == ErrCodeProtocol { + rl.cc.SetDoNotReuse() } + cs.abortStream(serr) + + cs.bufPipe.CloseWithError(serr) return nil } @@ -2747,87 +2725,6 @@ type errorReader struct{ err error } func (r errorReader) Read(p []byte) (int, error) { return 0, r.err } -// bodyWriterState encapsulates various state around the Transport's writing -// of the request body, particularly regarding doing delayed writes of the body -// when the request contains "Expect: 100-continue". -type bodyWriterState struct { - cs *clientStream - timer *time.Timer // if non-nil, we're doing a delayed write - fnonce *sync.Once // to call fn with - fn func() // the code to run in the goroutine, writing the body - resc chan error // result of fn's execution - delay time.Duration // how long we should delay a delayed write for -} - -func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s bodyWriterState) { - s.cs = cs - if body == nil { - return - } - resc := make(chan error, 1) - s.resc = resc - s.fn = func() { - cs.cc.mu.Lock() - cs.startedWrite = true - cs.cc.mu.Unlock() - resc <- cs.writeRequestBody(body, cs.req.Body) - } - s.delay = t.expectContinueTimeout() - if s.delay == 0 || - !httpguts.HeaderValuesContainsToken( - cs.req.Header["Expect"], - "100-continue") { - return - } - s.fnonce = new(sync.Once) - - // Arm the timer with a very large duration, which we'll - // intentionally lower later. 
It has to be large now because
-	// we need a handle to it before writing the headers, but the
-	// s.delay value is defined to not start until after the
-	// request headers were written.
-	const hugeDuration = 365 * 24 * time.Hour
-	s.timer = time.AfterFunc(hugeDuration, func() {
-		s.fnonce.Do(s.fn)
-	})
-	return
-}
-
-func (s bodyWriterState) cancel() {
-	if s.timer != nil {
-		if s.timer.Stop() {
-			s.resc <- nil
-		}
-	}
-}
-
-func (s bodyWriterState) on100() {
-	if s.timer == nil {
-		// If we didn't do a delayed write, ignore the server's
-		// bogus 100 continue response.
-		return
-	}
-	s.timer.Stop()
-	go func() { s.fnonce.Do(s.fn) }()
-}
-
-// scheduleBodyWrite starts writing the body, either immediately (in
-// the common case) or after the delay timeout. It should not be
-// called until after the headers have been written.
-func (s bodyWriterState) scheduleBodyWrite() {
-	if s.timer == nil {
-		// We're not doing a delayed write (see
-		// getBodyWriterState), so just start the writing
-		// goroutine immediately.
-		go s.fn()
-		return
-	}
-	traceWait100Continue(s.cs.trace)
-	if s.timer.Stop() {
-		s.timer.Reset(s.delay)
-	}
-}
-
 // isConnectionCloseRequest reports whether req should use its own
 // connection for a single request and then close the connection.
 func isConnectionCloseRequest(req *http.Request) bool {
diff --git a/transport_test.go b/transport_test.go
index 721c6c9..dfb556b 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -923,6 +923,10 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) {
 					return err
 				}
 			}
+		case *RSTStreamFrame:
+			if status == 200 {
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
 		default:
 			return fmt.Errorf("Unexpected client frame %v", f)
 		}
@@ -1745,14 +1749,17 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) {
 	ct.client = func() error {
 		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
 		res, err := ct.tr.RoundTrip(req)
+		if e, ok := err.(StreamError); ok {
+			err = e.Cause
+		}
 		if err != errResponseHeaderListSize {
+			size := int64(0)
 			if res != nil {
 				res.Body.Close()
-			}
-			size := int64(0)
-			for k, vv := range res.Header {
-				for _, v := range vv {
-					size += int64(len(k)) + int64(len(v)) + 32
+				for k, vv := range res.Header {
+					for _, v := range vv {
+						size += int64(len(k)) + int64(len(v)) + 32
+					}
 				}
 			}
 			return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
@@ -1877,8 +1884,9 @@ func TestTransportBodyReadErrorType(t *testing.T) {
 		doPanic <- true
 		buf := make([]byte, 100)
 		n, err := res.Body.Read(buf)
+		got, ok := err.(StreamError)
 		want := StreamError{StreamID: 0x1, Code: 0x2}
-		if !reflect.DeepEqual(want, err) {
+		if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
 			t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
 		}
 	}
@@ -2849,27 +2857,36 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
 		}
 
 		waitingFor := "RSTStreamFrame"
-		for {
+		sawRST := false
+		sawWUF := false
+		for !sawRST || !sawWUF {
 			f, err := ct.fr.ReadFrame()
 			if err != nil {
 				return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err)
 			}
-			if _, ok := f.(*SettingsFrame); ok {
-				continue
-			}
-			switch waitingFor {
-			case "RSTStreamFrame":
-				if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel {
+			switch f := f.(type) {
+			case *SettingsFrame:
+			case *RSTStreamFrame:
+				if sawRST {
+					return fmt.Errorf("saw second RSTStreamFrame: %v", summarizeFrame(f))
+				}
+				if f.ErrCode != ErrCodeCancel {
					return
fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - waitingFor = "WindowUpdateFrame" - case "WindowUpdateFrame": - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 { - return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f)) + sawRST = true + case *WindowUpdateFrame: + if sawWUF { + return fmt.Errorf("saw second WindowUpdateFrame: %v", summarizeFrame(f)) + } + if f.Increment != 4999 { + return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - return nil + sawWUF = true + default: + return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) } } + return nil } ct.run() } @@ -3800,7 +3817,7 @@ func TestTransportResponseDataBeforeHeaders(t *testing.T) { return err } switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: + case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: case *HeadersFrame: switch f.StreamID { case 1: @@ -4498,8 +4515,7 @@ func TestTransportUsesGetBodyWhenPresent(t *testing.T) { }, } - afterBodyWrite := false // pretend we haven't read+written the body yet - req2, err := shouldRetryRequest(req, errClientConnUnusable, afterBodyWrite) + req2, err := shouldRetryRequest(req, errClientConnUnusable) if err != nil { t.Fatal(err) } @@ -5301,8 +5317,10 @@ func TestClientConnReservations(t *testing.T) { reqHeaderMu: make(chan struct{}, 1), streams: make(map[uint32]*clientStream), maxConcurrentStreams: initialMaxConcurrentStreams, + nextStreamID: 1, t: &Transport{}, } + cc.cond = sync.NewCond(&cc.mu) n := 0 for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n++ @@ -5334,3 +5352,39 @@ func TestClientConnReservations(t *testing.T) { t.Errorf("after reset, reservations = %v; want %v", n2, n) } } + +func TestTransportTimeoutServerHangs(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req = req.WithContext(ctx) + req.Header.Add("Big", strings.Repeat("a", 1<<20)) + _, err = ct.tr.RoundTrip(req) + if err == nil { + return errors.New("error should not be nil") + } + if ne, ok := err.(net.Error); !ok || !ne.Timeout() { + return fmt.Errorf("error should be a net error timeout: %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + select { + case <-time.After(5 * time.Second): + case <-clientDone: + } + return nil + } + ct.run() +}