From b1d7f56b81b7902d871111b82dec6ba45f854ede Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed, 21 Sep 2022 14:35:08 -0400 Subject: [PATCH] transport: Fix deadlock in transport caused by GOAWAY race with new stream creation (#5652) * transport: Fix deadlock in transport caused by GOAWAY race with new stream creation --- internal/transport/http2_client.go | 23 ++++-- test/clienttester.go | 109 +++++++++++++++++++++++++++++ test/end2end_test.go | 56 ++++++++++++++- 3 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 test/clienttester.go diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 53643fa97477..5c2f35b24e75 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1232,18 +1232,29 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { if upperLimit == 0 { // This is the first GoAway Frame. upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. } + + t.prevGoAwayID = id + if len(t.activeStreams) == 0 { + t.mu.Unlock() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + return + } + + streamsToClose := make([]*Stream, 0) for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. - atomic.StoreUint32(&stream.unprocessed, 1) - t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + if streamID > id && streamID <= upperLimit { + atomic.StoreUint32(&stream.unprocessed, 1) + streamsToClose = append(streamsToClose, stream) + } } } - t.prevGoAwayID = id - active := len(t.activeStreams) t.mu.Unlock() - if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + // Called outside t.mu because closeStream can take controlBuf's mu, which + // could induce deadlock and is not allowed. + for _, stream := range streamsToClose { + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } diff --git a/test/clienttester.go b/test/clienttester.go new file mode 100644 index 000000000000..7e223091164d --- /dev/null +++ b/test/clienttester.go @@ -0,0 +1,109 @@ +/* + * Copyright 2022 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "bytes" + "io" + "net" + "testing" + + "golang.org/x/net/http2" +) + +var ( + clientPreface = []byte(http2.ClientPreface) +) + +func newClientTester(t *testing.T, conn net.Conn) *clientTester { + ct := &clientTester{ + t: t, + conn: conn, + } + ct.fr = http2.NewFramer(conn, conn) + ct.greet() + return ct +} + +type clientTester struct { + t *testing.T + conn net.Conn + fr *http2.Framer +} + +// greet() performs the necessary steps for http2 connection establishment on +// the server side. +func (ct *clientTester) greet() { + ct.wantClientPreface() + ct.wantSettingsFrame() + ct.writeSettingsFrame() + ct.writeSettingsAck() + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + ct.t.Errorf("error reading frame from client side: %v", err) + } + switch f := f.(type) { + case *http2.SettingsFrame: + if f.IsAck() { // HTTP/2 handshake completed. + return + } + default: + ct.t.Errorf("during greet, unexpected frame type %T", f) + } + } +} + +func (ct *clientTester) wantClientPreface() { + preface := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(ct.conn, preface); err != nil { + ct.t.Errorf("Error at server-side while reading preface from client. Err: %v", err) + } + if !bytes.Equal(preface, clientPreface) { + ct.t.Errorf("received bogus greeting from client %q", preface) + } +} + +func (ct *clientTester) wantSettingsFrame() { + frame, err := ct.fr.ReadFrame() + if err != nil { + ct.t.Errorf("error reading initial settings frame from client: %v", err) + } + _, ok := frame.(*http2.SettingsFrame) + if !ok { + ct.t.Errorf("initial frame sent from client is not a settings frame, type %T", frame) + } +} + +func (ct *clientTester) writeSettingsFrame() { + if err := ct.fr.WriteSettings(); err != nil { + ct.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) + } +} + +func (ct *clientTester) writeSettingsAck() { + if err := ct.fr.WriteSettingsAck(); err != nil { + ct.t.Fatalf("Error writing ACK of client's SETTINGS: %v", err) + } +} + +func (ct *clientTester) writeGoAway(maxStreamID uint32, code http2.ErrCode, debugData []byte) { + if err := ct.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + ct.t.Fatalf("Error writing GOAWAY: %v", err) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8a4f11515675..725bcdb641eb 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7407,7 +7407,6 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { return } writer.Flush() // necessary since client is expecting preface before declaring connection fully setup. - var sid uint32 // Loop until conn is closed and framer returns io.EOF for requestNum := 0; ; requestNum = (requestNum + 1) % len(s.responses) { @@ -8130,3 +8129,58 @@ func (s) TestRecvWhileReturningStatus(t *testing.T) { } } } + +// TestGoAwayStreamIDSmallerThanCreatedStreams tests the scenario where a server +// sends a goaway with a stream id that is smaller than some created streams on +// the client, while the client is simultaneously creating new streams. This +// should not induce a deadlock. +func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + + ctCh := testutils.NewChannel() + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("error in lis.Accept(): %v", err) + } + ct := newClientTester(t, conn) + ctCh.Send(ct) + }() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + val, err := ctCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for client transport (should be given after http2 creation)") + } + ct := val.(*clientTester) + + tc := testpb.NewTestServiceClient(cc) + someStreamsCreated := grpcsync.NewEvent() + goAwayWritten := grpcsync.NewEvent() + go func() { + for i := 0; i < 20; i++ { + if i == 10 { + <-goAwayWritten.Done() + } + tc.FullDuplexCall(ctx) + if i == 4 { + someStreamsCreated.Fire() + } + } + }() + + <-someStreamsCreated.Done() + ct.writeGoAway(1, http2.ErrCodeNo, []byte{}) + goAwayWritten.Fire() +}