Skip to content

Commit

Permalink
http2: close the request body if needed
Browse files Browse the repository at this point in the history
As per client.Do and Request.Body, the transport is responsible to close
the request Body.
If there was an error or non 1xx/2xx status code, the transport will
wait for the body writer to complete. If there is no data available to
read, the body writer will block indefinitely. To prevent this, the body
will be closed if it hasn't already.
If there was a 1xx/2xx status code, the body will be closed eventually.

Updates golang/go#43989

Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278
Reviewed-on: https://go-review.googlesource.com/c/net/+/323689
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Alexander Rakoczy <alex@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
  • Loading branch information
fraenkel authored and toothrot committed Aug 25, 2021
1 parent 60bc85c commit e898025
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 30 deletions.
59 changes: 29 additions & 30 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) {
}
cc := cs.cc
cc.mu.Lock()
cs.stopReqBody = err
cc.cond.Broadcast()
if cs.stopReqBody == nil {
cs.stopReqBody = err
if cs.req.Body != nil {
cs.req.Body.Close()
}
cc.cond.Broadcast()
}
cc.mu.Unlock()
}

Expand Down Expand Up @@ -1110,40 +1115,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
return res, false, nil
}

handleError := func(err error) (*http.Response, bool, error) {
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err
}

for {
select {
case re := <-readLoopResCh:
return handleReadLoopResponse(re)
case <-respHeaderTimer:
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errTimeout
return handleError(errTimeout)
case <-ctx.Done():
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), ctx.Err()
return handleError(ctx.Err())
case <-req.Cancel:
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
<-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errRequestCanceled
return handleError(errRequestCanceled)
case <-cs.peerReset:
// processResetStream already removed the
// stream from the streams map; no need for
Expand Down Expand Up @@ -1290,7 +1283,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
// Request.Body is closed by the Transport,
// and in multiple cases: server replies <=299 and >299
// while still writing request body
cerr := bodyCloser.Close()
var cerr error
cc.mu.Lock()
if cs.stopReqBody == nil {
cs.stopReqBody = errStopReqBodyWrite
cerr = bodyCloser.Close()
}
cc.mu.Unlock()
if err == nil {
err = cerr
}
Expand Down
45 changes: 45 additions & 0 deletions http2/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4899,3 +4899,48 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) {
}
res.Body.Close()
}

type closeChecker struct {
io.ReadCloser
closed chan struct{}
}

func (rc *closeChecker) Close() error {
close(rc.closed)
return rc.ReadCloser.Close()
}

func TestTransportCloseRequestBody(t *testing.T) {
var statusCode int
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
}, optOnlyServer)
defer st.Close()

tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}

for _, status := range []int{200, 401} {
t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
statusCode = status
pr, pw := io.Pipe()
pipeClosed := make(chan struct{})
req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
if err != nil {
t.Fatal(err)
}
res, err := cc.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
pw.Close()
<-pipeClosed
})
}
}

0 comments on commit e898025

Please sign in to comment.