From c136d0c937afa54dca414a69603bb1570a28879f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 19 Dec 2023 09:01:27 -0800 Subject: [PATCH] quic: avoid panic when PTO expires and implicitly-created streams exist The streams map contains nil entries for implicitly-created streams. (Receiving a packet for stream N implicitly creates all streams of the same type LUCI-TryBot-Result: Go LUCI --- internal/quic/conn_streams.go | 45 ++++++++++++++++++++---------- internal/quic/conn_streams_test.go | 35 +++++++++++++++++++++++ 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index dc82f8b0f..87cfd297e 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -14,8 +14,16 @@ import ( ) type streamsState struct { - queue queue[*Stream] // new, peer-created streams - streams map[streamID]*Stream + queue queue[*Stream] // new, peer-created streams + + // All peer-created streams. + // + // Implicitly created streams are included as an empty entry in the map. + // (For example, if we receive a frame for stream 4, we implicitly create stream 0 and + // insert an empty entry for it to the map.) + // + // The map value is maybeStream rather than *Stream as a reminder that values can be nil. + streams map[streamID]maybeStream // Limits on the number of streams, indexed by streamType. localLimit [streamTypeCount]localStreamLimits @@ -37,8 +45,13 @@ type streamsState struct { queueData streamRing // streams with only flow-controlled frames } +// maybeStream is a possibly nil *Stream. See streamsState.streams. +type maybeStream struct { + s *Stream +} + func (c *Conn) streamsInit() { - c.streams.streams = make(map[streamID]*Stream) + c.streams.streams = make(map[streamID]maybeStream) c.streams.queue = newQueue[*Stream]() c.streams.localLimit[bidiStream].init() c.streams.localLimit[uniStream].init() @@ -52,8 +65,8 @@ func (c *Conn) streamsCleanup() { c.streams.localLimit[bidiStream].connHasClosed() c.streams.localLimit[uniStream].connHasClosed() for _, s := range c.streams.streams { - if s != nil { - s.connHasClosed() + if s.s != nil { + s.s.connHasClosed() } } } @@ -97,7 +110,7 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er // Modify c.streams on the conn's loop. if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) { - c.streams.streams[s.id] = s + c.streams.streams[s.id] = maybeStream{s} }); err != nil { return nil, err } @@ -119,7 +132,7 @@ const ( // streamForID returns the stream with the given id. // If the stream does not exist, it returns nil. func (c *Conn) streamForID(id streamID) *Stream { - return c.streams.streams[id] + return c.streams.streams[id].s } // streamForFrame returns the stream with the given id. @@ -144,9 +157,9 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) } } - s, isOpen := c.streams.streams[id] - if s != nil { - return s + ms, isOpen := c.streams.streams[id] + if ms.s != nil { + return ms.s } num := id.num() @@ -183,10 +196,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) // with the same initiator and type and a lower number. // Add a nil entry to the streams map for each implicitly created stream. for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 { - c.streams.streams[n] = nil + c.streams.streams[n] = maybeStream{} } - s = newStream(c, id) + s := newStream(c, id) s.inmaxbuf = c.config.maxStreamReadBufferSize() s.inwin = c.config.maxStreamReadBufferSize() if id.streamType() == bidiStream { @@ -196,7 +209,7 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) s.inUnlock() s.outUnlock() - c.streams.streams[id] = s + c.streams.streams[id] = maybeStream{s} c.streams.queue.put(s) return s } @@ -400,7 +413,11 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() const pto = true - for _, s := range c.streams.streams { + for _, ms := range c.streams.streams { + s := ms.s + if s == nil { + continue + } const pto = true s.ingate.lock() inOK := s.appendInFramesLocked(w, pnum, pto) diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 90f5cb75c..fb9af47eb 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -522,3 +522,38 @@ func TestStreamsCreateConcurrency(t *testing.T) { t.Errorf("accepted %v streams, want %v", got, want) } } + +func TestStreamsPTOWithImplicitStream(t *testing.T) { + ctx := canceledContext() + tc := newTestConn(t, serverSide, permissiveTransportParameters) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + // Peer creates stream 1, and implicitly creates stream 0. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + }) + + // We accept stream 1 and write data to it. + data := []byte("data") + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("conn.AcceptStream() = %v, want stream", err) + } + s.Write(data) + s.Flush() + tc.wantFrame("data written to stream", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + data: data, + }) + + // PTO expires, and the data is resent. + const pto = true + tc.triggerLossOrPTO(packetType1RTT, true) + tc.wantFrame("data resent after PTO expires", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 1), + data: data, + }) +}