diff --git a/clientconn.go b/clientconn.go index 4a408d621692..19763f8eddfa 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1140,10 +1140,15 @@ func (cc *ClientConn) Close() error { <-cc.resolverWrapper.serializer.Done() <-cc.balancerWrapper.serializer.Done() - + var wg sync.WaitGroup for ac := range conns { - ac.tearDown(ErrClientConnClosing) + wg.Add(1) + go func(ac *addrConn) { + defer wg.Done() + ac.tearDown(ErrClientConnClosing) + }(ac) } + wg.Wait() cc.addTraceEvent("deleted") // TraceEvent needs to be called before RemoveEntry, as TraceEvent may add // trace reference to the entity being deleted, and thus prevent it from being diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index ba42e51129ed..62b81885d8ef 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -86,9 +86,9 @@ type http2Client struct { writerDone chan struct{} // sync point to enable testing. // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. - goAway chan struct{} - - framer *framer + goAway chan struct{} + keepaliveDone chan struct{} // Closed when the keepalive goroutine exits. + framer *framer // controlBuf delivers all the control related tasks (e.g., window // updates, reset streams, and various settings) to the controller. // Do not access controlBuf with mu held. @@ -335,6 +335,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts readerDone: make(chan struct{}), writerDone: make(chan struct{}), goAway: make(chan struct{}), + keepaliveDone: make(chan struct{}), framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize), fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, @@ -1029,6 +1030,12 @@ func (t *http2Client) Close(err error) { } t.cancel() t.conn.Close() + // Waits for the reader and keepalive goroutines to exit before returning to + // ensure all resources are cleaned up before Close can return. + <-t.readerDone + if t.keepaliveEnabled { + <-t.keepaliveDone + } channelz.RemoveEntry(t.channelz.ID) var st *status.Status if len(goAwayDebugMessage) > 0 { @@ -1316,11 +1323,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { t.controlBuf.put(pingAck) } -func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { +func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error { t.mu.Lock() if t.state == closing { t.mu.Unlock() - return + return nil } if f.ErrCode == http2.ErrCodeEnhanceYourCalm && string(f.DebugData()) == "too_many_pings" { // When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug @@ -1332,8 +1339,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 == 0 { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id)) - return + return connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id) } // A client can receive multiple GoAways from the server (see // https://github.com/grpc/grpc-go/issues/1387). The idea is that the first @@ -1350,8 +1356,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) - return + return connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID) } default: t.setGoAwayReason(f) @@ -1375,8 +1380,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.prevGoAwayID = id if len(t.activeStreams) == 0 { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) - return + return connectionErrorf(true, nil, "received goaway and there are no active streams") } streamsToClose := make([]*Stream, 0) @@ -1393,6 +1397,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { for _, stream := range streamsToClose { t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } + return nil } // setGoAwayReason sets the value of t.goAwayReason based @@ -1628,7 +1633,13 @@ func (t *http2Client) readServerPreface() error { // network connection. If the server preface is not read successfully, an // error is pushed to errCh; otherwise errCh is closed with no error. func (t *http2Client) reader(errCh chan<- error) { - defer close(t.readerDone) + var errClose error + defer func() { + close(t.readerDone) + if errClose != nil { + t.Close(errClose) + } + }() if err := t.readServerPreface(); err != nil { errCh <- err @@ -1669,7 +1680,7 @@ func (t *http2Client) reader(errCh chan<- error) { continue } // Transport error. - t.Close(connectionErrorf(true, err, "error reading from server: %v", err)) + errClose = connectionErrorf(true, err, "error reading from server: %v", err) return } switch frame := frame.(type) { @@ -1684,7 +1695,7 @@ func (t *http2Client) reader(errCh chan<- error) { case *http2.PingFrame: t.handlePing(frame) case *http2.GoAwayFrame: - t.handleGoAway(frame) + errClose = t.handleGoAway(frame) case *http2.WindowUpdateFrame: t.handleWindowUpdate(frame) default: @@ -1697,6 +1708,13 @@ func (t *http2Client) reader(errCh chan<- error) { // keepalive running in a separate goroutine makes sure the connection is alive by sending pings. func (t *http2Client) keepalive() { + var err error + defer func() { + close(t.keepaliveDone) + if err != nil { + t.Close(err) + } + }() p := &ping{data: [8]byte{}} // True iff a ping has been sent, and no data has been received since then. outstandingPing := false @@ -1720,7 +1738,7 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) + err = connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout") return } t.mu.Lock() diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index 393a4540396f..ad377e6b241b 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -44,6 +44,7 @@ import ( ) const defaultTestTimeout = 10 * time.Second +const defaultTestShortTimeout = 10 * time.Millisecond // TestMaxConnectionIdle tests that a server will send GoAway to an idle // client. An idle client is one who doesn't make any RPC calls for a duration diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 65efb30c4bb6..4752c785b59d 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -2781,6 +2781,89 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { } } +// readHangingConn is a wrapper around net.Conn that makes the Read() hang when +// Close() is called. +type readHangingConn struct { + net.Conn + readHangConn chan struct{} // Read() hangs until this channel is closed by Close(). + closed *atomic.Bool // Set to true when Close() is called. +} + +func (hc *readHangingConn) Read(b []byte) (n int, err error) { + n, err = hc.Conn.Read(b) + if hc.closed.Load() { + <-hc.readHangConn // hang the read till we want + } + return n, err +} + +func (hc *readHangingConn) Close() error { + hc.closed.Store(true) + return hc.Conn.Close() +} + +// Tests that closing a client transport does not return until the reader +// goroutine exits. +func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + server := setUpServerOnly(t, 0, &ServerConfig{}, normal) + defer server.stop() + addr := resolver.Address{Addr: "localhost:" + server.port} + + isReaderHanging := &atomic.Bool{} + readHangConn := make(chan struct{}) + copts := ConnectOptions{ + Dialer: func(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil + }, + ChannelzParent: channelzSubChannel(t), + } + + // Create a client transport with a custom dialer that hangs the Read() + // after Close(). + ct, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) + if err != nil { + t.Fatalf("Failed to create transport: %v", err) + } + + if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + // Closing the client transport will result in the underlying net.Conn being + // closed, which will result in readHangingConn.Read() to hang. This will + // stall the exit of the reader goroutine, and will stall client + // transport's Close from returning. + transportClosed := make(chan struct{}) + go func() { + ct.Close(errors.New("manually closed by client")) + close(transportClosed) + }() + + // Wait for a short duration and ensure that the client transport's Close() + // does not return. + select { + case <-transportClosed: + t.Fatal("Transport closed before reader completed") + case <-time.After(defaultTestShortTimeout): + } + + // Closing the channel will unblock the reader goroutine and will ensure + // that the client transport's Close() returns. + close(readHangConn) + select { + case <-transportClosed: + case <-time.After(defaultTestTimeout): + t.Fatal("Timeout when waiting for transport to close") + } +} + // hangingConn is a net.Conn wrapper for testing, simulating hanging connections // after a GOAWAY frame is sent, of which Write operations pause until explicitly // signaled or a timeout occurs.