diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index a5b7513f412..9097385e1a6 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -527,6 +527,9 @@ const minBatchSize = 1000 // As an optimization, to increase the batch size for each flush, loopy yields the processor, once // if the batch size is too low to give stream goroutines a chance to fill it up. func (l *loopyWriter) run() (err error) { + // Always flush the writer before exiting in case there are pending frames + // to be sent. + defer l.framer.writer.Flush() for { it, err := l.cbuf.get(true) if err != nil { @@ -759,7 +762,7 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { return err } } - if l.side == clientSide && l.draining && len(l.estdStreams) == 0 { + if l.draining && len(l.estdStreams) == 0 { return errors.New("finished processing active streams while in draining mode") } return nil @@ -814,7 +817,6 @@ func (l *loopyWriter) goAwayHandler(g *goAway) error { } func (l *loopyWriter) closeConnectionHandler() error { - l.framer.writer.Flush() // Exit loopyWriter entirely by returning an error here. This will lead to // the transport closing the connection, and, ultimately, transport // closure. diff --git a/test/gracefulstop_test.go b/test/gracefulstop_test.go index a5a8448ad2f..15e0611a219 100644 --- a/test/gracefulstop_test.go +++ b/test/gracefulstop_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -164,3 +165,53 @@ func (s) TestGracefulStop(t *testing.T) { cancel() wg.Wait() } + +func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) { + // This test ensures that a server closes the connections to its clients + // when the final stream has completed after a GOAWAY. + + handlerCalled := make(chan struct{}) + gracefulStopCalled := make(chan struct{}) + + ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + close(handlerCalled) // Initiate call to GracefulStop. + <-gracefulStopCalled // Wait for GOAWAYs to be received by the client. + return nil + }} + + te := newTest(t, tcpClearEnv) + te.startServer(ts) + defer te.tearDown() + + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false) + + <-handlerCalled // Wait for the server to invoke its handler. + + // Gracefully stop the server. + gracefulStopDone := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(gracefulStopDone) + }() + st.wantGoAway(http2.ErrCodeNo) // Server sends a GOAWAY due to GracefulStop. + pf := st.wantPing() // Server sends a ping to verify client receipt. + st.writePing(true, pf.Data) // Send ping ack to confirm. + st.wantGoAway(http2.ErrCodeNo) // Wait for subsequent GOAWAY to indicate no new stream processing. + + close(gracefulStopCalled) // Unblock server handler. + + fr := st.wantAnyFrame() // Wait for trailer. + hdr, ok := fr.(*http2.MetaHeadersFrame) + if !ok { + t.Fatalf("Received unexpected frame of type (%T) from server: %v; want HEADERS", fr, fr) + } + if !hdr.StreamEnded() { + t.Fatalf("Received unexpected HEADERS frame from server: %v; want END_STREAM set", fr) + } + + st.wantRSTStream(http2.ErrCodeNo) // Server should send RST_STREAM because client did not half-close. + + <-gracefulStopDone // Wait for GracefulStop to return. + }) +} diff --git a/test/servertester.go b/test/servertester.go index bf7bd8b214e..3701a0e094d 100644 --- a/test/servertester.go +++ b/test/servertester.go @@ -138,19 +138,46 @@ func (st *serverTester) writeSettingsAck() { } } +func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting an RST frame: %v", err) + } + gaf, ok := f.(*http2.GoAwayFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f) + } + if gaf.ErrCode != errCode { + st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String()) + } + return gaf +} + +func (st *serverTester) wantPing() *http2.PingFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting an RST frame: %v", err) + } + pf, ok := f.(*http2.PingFrame) + if !ok { + st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f) + } + return pf +} + func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting an RST frame: %v", err) } - sf, ok := f.(*http2.RSTStreamFrame) + rf, ok := f.(*http2.RSTStreamFrame) if !ok { st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f) } - if sf.ErrCode != errCode { - st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String()) + if rf.ErrCode != errCode { + st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String()) } - return sf + return rf } func (st *serverTester) wantSettings() *http2.SettingsFrame { diff --git a/test/stream_cleanup_test.go b/test/stream_cleanup_test.go index 83dd68549e9..53298ea372f 100644 --- a/test/stream_cleanup_test.go +++ b/test/stream_cleanup_test.go @@ -46,7 +46,7 @@ func (s) TestStreamCleanup(t *testing.T) { return &testpb.Empty{}, nil }, } - if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil { + if err := ss.Start(nil, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() @@ -79,7 +79,7 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) { }) }, } - if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil { + if err := ss.Start(nil, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() @@ -132,6 +132,6 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) { case <-gracefulStopDone: timer.Stop() case <-timer.C: - t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC") + t.Fatalf("s.GracefulStop() didn't finish within 1 second after the last RPC") } }