Skip to content

Commit

Permalink
[internal-branch.go1.17-vendor] http2: close the request body if needed
Browse files Browse the repository at this point in the history
As per client.Do and Request.Body, the transport is responsible to close
the request Body.
If there was an error or non 1xx/2xx status code, the transport will
wait for the body writer to complete. If there is no data available to
read, the body writer will block indefinitely. To prevent this, the body
will be closed if it hasn't already.
If there was a 1xx/2xx status code, the body will be closed eventually.

Updates golang/go#49077

Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278
Reviewed-on: https://go-review.googlesource.com/c/net/+/323689
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Alexander Rakoczy <alex@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-on: https://go-review.googlesource.com/c/net/+/357671
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
  • Loading branch information
fraenkel authored and dmitshur committed Oct 29, 2021
1 parent 136e584 commit 553fb77
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 30 deletions.
59 changes: 29 additions & 30 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) {
}
cc := cs.cc
cc.mu.Lock()
cs.stopReqBody = err
cc.cond.Broadcast()
if cs.stopReqBody == nil {
cs.stopReqBody = err
if cs.req.Body != nil {
cs.req.Body.Close()
}
cc.cond.Broadcast()
}
cc.mu.Unlock()
}

Expand Down Expand Up @@ -1110,40 +1115,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
return res, false, nil
}

handleError := func(err error) (*http.Response, bool, error) {
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err
}

for {
select {
case re := <-readLoopResCh:
return handleReadLoopResponse(re)
case <-respHeaderTimer:
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errTimeout
return handleError(errTimeout)
case <-ctx.Done():
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), ctx.Err()
return handleError(ctx.Err())
case <-req.Cancel:
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errRequestCanceled
return handleError(errRequestCanceled)
case <-cs.peerReset:
// processResetStream already removed the
// stream from the streams map; no need for
Expand Down Expand Up @@ -1290,7 +1283,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
// Request.Body is closed by the Transport,
// and in multiple cases: server replies <=299 and >299
// while still writing request body
cerr := bodyCloser.Close()
var cerr error
cc.mu.Lock()
if cs.stopReqBody == nil {
cs.stopReqBody = errStopReqBodyWrite
cerr = bodyCloser.Close()
}
cc.mu.Unlock()
if err == nil {
err = cerr
}
Expand Down
45 changes: 45 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4899,3 +4899,48 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
}
res.Body.Close()
}

type closeChecker struct {
io.ReadCloser
closed chan struct{}
}

func (rc *closeChecker) Close() error {
close(rc.closed)
return rc.ReadCloser.Close()
}

func TestTransportCloseRequestBody(t *testing.T) {
var statusCode int
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
}, optOnlyServer)
defer st.Close()

tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}

for _, status := range []int{200, 401} {
t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
statusCode = status
pr, pw := io.Pipe()
pipeClosed := make(chan struct{})
req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
if err != nil {
t.Fatal(err)
}
res, err := cc.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
pw.Close()
<-pipeClosed
})
}
}

0 comments on commit 553fb77

Please sign in to comment.