diff --git a/errors.go b/errors.go
index 71f2c463..c789fa33 100644
--- a/errors.go
+++ b/errors.go
@@ -67,6 +67,11 @@ type StreamError struct {
 	Cause    error // optional additional detail
 }
 
+// errFromPeer is a sentinel error value for StreamError.Cause to
+// indicate that the StreamError was sent from the peer over the wire
+// and wasn't locally generated in the Transport.
+var errFromPeer = errors.New("received from peer")
+
 func streamError(id uint32, code ErrCode) StreamError {
 	return StreamError{StreamID: id, Code: code}
 }
diff --git a/transport.go b/transport.go
index b261beb1..dc31cfd7 100644
--- a/transport.go
+++ b/transport.go
@@ -244,6 +244,7 @@ type ClientConn struct {
 	cond            *sync.Cond // hold mu; broadcast on flow/closed changes
 	flow            flow       // our conn-level flow control quota (cs.flow is per stream)
 	inflow          flow       // peer's conn-level flow control
+	doNotReuse      bool       // whether conn is marked to not be reused for any future requests
 	closing         bool
 	closed          bool
 	wantSettingsAck bool       // we sent a SETTINGS frame and haven't heard back
@@ -563,6 +564,10 @@ func canRetryError(err error) bool {
 		return true
 	}
 	if se, ok := err.(StreamError); ok {
+		if se.Code == ErrCodeProtocol && se.Cause == errFromPeer {
+			// See golang/go#47635, golang/go#42777
+			return true
+		}
 		return se.Code == ErrCodeRefusedStream
 	}
 	return false
@@ -714,6 +719,13 @@ func (cc *ClientConn) healthCheck() {
 	}
 }
 
+// SetDoNotReuse marks cc as not reusable for future HTTP requests.
+func (cc *ClientConn) SetDoNotReuse() {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.doNotReuse = true
+}
+
 func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -776,6 +788,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
 	}
 
 	st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
+		!cc.doNotReuse &&
 		int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
 		!cc.tooIdleLocked()
 	st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
@@ -2419,10 +2432,17 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
 		// which closes this, so there
 		// isn't a race.
 	default:
-		err := streamError(cs.ID, f.ErrCode)
-		cs.resetErr = err
+		serr := streamError(cs.ID, f.ErrCode)
+		if f.ErrCode == ErrCodeProtocol {
+			rl.cc.SetDoNotReuse()
+			serr.Cause = errFromPeer
+			// TODO(bradfitz): increment a varz here, once Transport
+			// takes an optional interface-typed field that expvar.Map.Add
+			// implements.
+		}
+		cs.resetErr = serr
 		close(cs.peerReset)
-		cs.bufPipe.CloseWithError(err)
+		cs.bufPipe.CloseWithError(serr)
 		cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
 	}
 	return nil
diff --git a/transport_test.go b/transport_test.go
index 2da7d9de..4412a893 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -4944,3 +4944,104 @@ func TestTransportCloseRequestBody(t *testing.T) {
 		})
 	}
 }
+
+// collectClientsConnPool is a ClientConnPool that wraps lower and
+// collects what calls were made on it.
+type collectClientsConnPool struct {
+	lower ClientConnPool
+
+	mu      sync.Mutex
+	getErrs int
+	got     []*ClientConn
+}
+
+func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
+	cc, err := p.lower.GetClientConn(req, addr)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if err != nil {
+		p.getErrs++
+		return nil, err
+	}
+	p.got = append(p.got, cc)
+	return cc, nil
+}
+
+func (p *collectClientsConnPool) MarkDead(cc *ClientConn) {
+	p.lower.MarkDead(cc)
+}
+
+func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
+	ct := newClientTester(t)
+	pool := &collectClientsConnPool{
+		lower: &clientConnPool{t: ct.tr},
+	}
+	ct.tr.ConnPool = pool
+	done := make(chan struct{})
+	ct.client = func() error {
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		const want = "only one dial allowed in test mode"
+		if got := fmt.Sprint(err); got != want {
+			t.Errorf("didn't dial again: got %#q; want %#q", got, want)
+		}
+		close(done)
+		ct.sc.Close()
+		if res != nil {
+			res.Body.Close()
+		}
+
+		pool.mu.Lock()
+		defer pool.mu.Unlock()
+		if pool.getErrs != 1 {
+			t.Errorf("pool get errors = %v; want 1", pool.getErrs)
+		}
+		if len(pool.got) == 1 {
+			cc := pool.got[0]
+			cc.mu.Lock()
+			if !cc.doNotReuse {
+				t.Error("ClientConn not marked doNotReuse")
+			}
+			cc.mu.Unlock()
+		} else {
+			t.Errorf("pool get success = %v; want 1", len(pool.got))
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var sentErr bool
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-done:
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+				if !sentErr {
+					sentErr = true
+					ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
+					continue
+				}
+				var buf bytes.Buffer
+				enc := hpack.NewEncoder(&buf)
+				// send headers without Trailer header
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      f.StreamID,
+					EndHeaders:    true,
+					EndStream:     true,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+		}
+		return nil
+	}
+	ct.run()
+}
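
Not part of the diff: a minimal sketch of how code outside the package could use the newly exported (*ClientConn).SetDoNotReuse. It assumes a connection created with Transport.NewClientConn and a placeholder server (example.com); within the library itself, the change above calls SetDoNotReuse automatically when the peer resets a stream with a PROTOCOL_ERROR.

package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"net/http"

	"golang.org/x/net/http2"
)

func main() {
	// Dial an HTTP/2-capable server directly; example.com is a placeholder.
	conn, err := tls.Dial("tcp", "example.com:443", &tls.Config{NextProtos: []string{"h2"}})
	if err != nil {
		log.Fatal(err)
	}

	tr := &http2.Transport{}
	cc, err := tr.NewClientConn(conn)
	if err != nil {
		log.Fatal(err)
	}

	req, _ := http.NewRequest("GET", "https://example.com/", nil)
	res, err := cc.RoundTrip(req)
	if err != nil {
		log.Fatal(err)
	}
	res.Body.Close()

	// Retire the connection: requests already in flight finish, but the
	// idle-state check now reports that cc cannot take new requests.
	cc.SetDoNotReuse()
	fmt.Println(cc.CanTakeNewRequest()) // false
}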