From 97384c11dd0db63357820b2cfcb44c40fbc3116a Mon Sep 17 00:00:00 2001
From: Damien Neil
Date: Tue, 29 Aug 2023 13:18:00 -0700
Subject: [PATCH] quic: remove streams from the conn when done

When a stream has been fully shut down--the peer has closed its end
and acked every frame we will send for it--remove it from the Conn's
set of active streams.

We do the actual removal on the conn's loop, so stream cleanup can
access conn state without worrying about locking.

For golang/go#58547

Change-Id: Id9715693649929b07d303f0c4b3a782d135f0326
Reviewed-on: https://go-review.googlesource.com/c/net/+/524296
Reviewed-by: Jonathan Amsterdam
LUCI-TryBot-Result: Go LUCI
---
 internal/quic/atomic_bits.go       |  33 +++++++
 internal/quic/conn_streams.go      |  62 +++++++++----
 internal/quic/conn_streams_test.go |  89 ++++++++++++++++++
 internal/quic/conn_test.go         |   2 +
 internal/quic/stream.go            | 140 +++++++++++++++++++++++------
 internal/quic/stream_test.go       |  33 +++++++
 6 files changed, 315 insertions(+), 44 deletions(-)
 create mode 100644 internal/quic/atomic_bits.go

diff --git a/internal/quic/atomic_bits.go b/internal/quic/atomic_bits.go
new file mode 100644
index 0000000000..e1e2594d15
--- /dev/null
+++ b/internal/quic/atomic_bits.go
@@ -0,0 +1,33 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import "sync/atomic"
+
+// atomicBits is an atomic uint32 that supports setting individual bits.
+type atomicBits[T ~uint32] struct {
+	bits atomic.Uint32
+}
+
+// set sets the bits in mask to the corresponding bits in v.
+// It returns the new value.
+func (a *atomicBits[T]) set(v, mask T) T {
+	if v&^mask != 0 {
+		panic("BUG: bits in v are not in mask")
+	}
+	for {
+		o := a.bits.Load()
+		n := (o &^ uint32(mask)) | uint32(v)
+		if a.bits.CompareAndSwap(o, n) {
+			return T(n)
+		}
+	}
+}
+
+func (a *atomicBits[T]) load() T {
+	return T(a.bits.Load())
+}
diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go
index dd35e34cf6..0ede284e23 100644
--- a/internal/quic/conn_streams.go
+++ b/internal/quic/conn_streams.go
@@ -185,24 +185,46 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool)
 	for {
 		s := c.streams.sendHead
 		const pto = false
-		if !s.appendInFrames(w, pnum, pto) {
-			return false
+
+		state := s.state.load()
+		if state&streamInSend != 0 {
+			s.ingate.lock()
+			ok := s.appendInFramesLocked(w, pnum, pto)
+			state = s.inUnlockNoQueue()
+			if !ok {
+				return false
+			}
 		}
-		avail := w.avail()
-		if !s.appendOutFrames(w, pnum, pto) {
-			// We've sent some data for this stream, but it still has more to send.
-			// If the stream got a reasonable chance to put data in a packet,
-			// advance sendHead to the next stream in line, to avoid starvation.
-			// We'll come back to this stream after going through the others.
-			//
-			// If the packet was already mostly out of space, leave sendHead alone
-			// and come back to this stream again on the next packet.
-			if avail > 512 {
-				c.streams.sendHead = s.next
-				c.streams.sendTail = s
+
+		if state&streamOutSend != 0 {
+			avail := w.avail()
+			s.outgate.lock()
+			ok := s.appendOutFramesLocked(w, pnum, pto)
+			state = s.outUnlockNoQueue()
+			if !ok {
+				// We've sent some data for this stream, but it still has more to send.
+				// If the stream got a reasonable chance to put data in a packet,
+				// advance sendHead to the next stream in line, to avoid starvation.
+				// We'll come back to this stream after going through the others.
+				//
+				// If the packet was already mostly out of space, leave sendHead alone
+				// and come back to this stream again on the next packet.
+				if avail > 512 {
+					c.streams.sendHead = s.next
+					c.streams.sendTail = s
+				}
+				return false
 			}
-			return false
 		}
+
+		if state == streamInDone|streamOutDone {
+			// Stream is finished, remove it from the conn.
+			s.state.set(streamConnRemoved, streamConnRemoved)
+			delete(c.streams.streams, s.id)
+
+			// TODO: Provide the peer with additional stream quota (MAX_STREAMS).
+		}
+
 		next := s.next
 		s.next = nil
 		if (next == s) != (s == c.streams.sendTail) {
@@ -231,10 +253,16 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
 	defer c.streams.sendMu.Unlock()
 	for _, s := range c.streams.streams {
 		const pto = true
-		if !s.appendInFrames(w, pnum, pto) {
+		s.ingate.lock()
+		inOK := s.appendInFramesLocked(w, pnum, pto)
+		s.inUnlockNoQueue()
+		if !inOK {
 			return false
 		}
-		if !s.appendOutFrames(w, pnum, pto) {
+		s.outgate.lock()
+		outOK := s.appendOutFramesLocked(w, pnum, pto)
+		s.outUnlockNoQueue()
+		if !outOK {
 			return false
 		}
 	}
diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go
index 877dbb94fc..9bbc994b11 100644
--- a/internal/quic/conn_streams_test.go
+++ b/internal/quic/conn_streams_test.go
@@ -8,6 +8,8 @@ package quic
 
 import (
 	"context"
+	"fmt"
+	"io"
 	"testing"
 )
 
@@ -253,3 +255,90 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
 		}
 	}
 }
+
+func TestStreamsShutdown(t *testing.T) {
+	// These tests verify that a stream is removed from the Conn's map of live streams
+	// after it is fully shut down.
+	//
+	// Each case consists of a setup step, after which one stream should exist,
+	// and a shutdown step, after which no streams should remain in the Conn.
+	for _, test := range []struct {
+		name     string
+		side     streamSide
+		styp     streamType
+		setup    func(*testing.T, *testConn, *Stream)
+		shutdown func(*testing.T, *testConn, *Stream)
+	}{{
+		name: "closed",
+		side: localStream,
+		styp: uniStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseContext(canceledContext())
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeAckForAll()
+		},
+	}, {
+		name: "local close",
+		side: localStream,
+		styp: bidiStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeFrames(packetType1RTT, debugFrameResetStream{
+				id: s.id,
+			})
+			s.CloseContext(canceledContext())
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeAckForAll()
+		},
+	}, {
+		name: "remote reset",
+		side: localStream,
+		styp: bidiStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseContext(canceledContext())
+			tc.wantIdle("all frames after CloseContext are ignored")
+			tc.writeAckForAll()
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeFrames(packetType1RTT, debugFrameResetStream{
+				id: s.id,
+			})
+		},
+	}, {
+		name: "local close",
+		side: remoteStream,
+		styp: uniStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			ctx := canceledContext()
+			tc.writeFrames(packetType1RTT, debugFrameStream{
+				id:  s.id,
+				fin: true,
+			})
+			if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
+				t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
+			}
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseRead()
+		},
+	}} {
+		name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name)
+		t.Run(name, func(t *testing.T) {
+			tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp,
+				permissiveTransportParameters)
+			tc.ignoreFrame(frameTypeStreamBase)
+			tc.ignoreFrame(frameTypeStopSending)
+			test.setup(t, tc, s)
+			tc.wantIdle("conn should be idle after setup")
+			if got, want := len(tc.conn.streams.streams), 1; got != want {
+				t.Fatalf("after setup: %v streams in Conn's map; want %v", got, want)
+			}
+			test.shutdown(t, tc, s)
+			tc.wantIdle("conn should be idle after shutdown")
+			if got, want := len(tc.conn.streams.streams), 0; got != want {
+				t.Fatalf("after shutdown: %v streams in Conn's map; want %v", got, want)
+			}
+		})
+	}
+}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index d8c44558dc..ea720d5754 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -394,6 +394,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
 // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
 // last one received.
 func (tc *testConn) writeAckForAll() {
+	tc.t.Helper()
 	if tc.lastPacket == nil {
 		return
 	}
@@ -405,6 +406,7 @@ func (tc *testConn) writeAckForAll() {
 // writeAckForLatest sends the Conn a datagram containing an ack for the
 // most recent packet received.
 func (tc *testConn) writeAckForLatest() {
+	tc.t.Helper()
 	if tc.lastPacket == nil {
 		return
 	}
diff --git a/internal/quic/stream.go b/internal/quic/stream.go
index 1033cbb401..2dbf4461ba 100644
--- a/internal/quic/stream.go
+++ b/internal/quic/stream.go
@@ -49,9 +49,38 @@ type Stream struct {
 	outresetcode uint64          // reset code to send in RESET_STREAM
 	outdone      chan struct{}   // closed when all data sent
 
+	// Atomic stream state bits.
+	//
+	// These bits provide a fast way to coordinate between the
+	// send and receive sides of the stream, and the conn's loop.
+	//
+	// streamIn* bits must be set with ingate held.
+	// streamOut* bits must be set with outgate held.
+	// streamConn* bits are set by the conn's loop.
+	state atomicBits[streamState]
+
 	prev, next *Stream // guarded by streamsState.sendMu
 }
 
+type streamState uint32
+
+const (
+	// streamInSend and streamOutSend are set when there are
+	// frames to send for the inbound or outbound sides of the stream.
+	// For example, MAX_STREAM_DATA or STREAM_DATA_BLOCKED.
+	streamInSend = streamState(1 << iota)
+	streamOutSend
+
+	// streamInDone and streamOutDone are set when the inbound or outbound
+	// sides of the stream are finished. When both are set, the stream
+	// can be removed from the Conn and forgotten.
+	streamInDone
+	streamOutDone
+
+	// streamConnRemoved is set when the stream has been removed from the conn.
+	streamConnRemoved
+)
+
 // newStream returns a new stream.
 //
 // The stream's ingate and outgate are locked.
@@ -289,15 +318,34 @@ func (s *Stream) CloseWrite() {
 // that the stream was terminated abruptly.
 // Any blocked writes will be unblocked and return errors.
 //
-// Reset sends the application protocol error code to the peer.
+// Reset sends the application protocol error code, which must be
+// less than 2^62, to the peer.
 // It does not wait for the peer to acknowledge receipt of the error.
 // Use CloseContext to wait for the peer's acknowledgement.
+//
+// Reset does not affect reads.
+// Use CloseRead to abort reads on the stream.
 func (s *Stream) Reset(code uint64) {
+	const userClosed = true
+	s.resetInternal(code, userClosed)
+}
+
+func (s *Stream) resetInternal(code uint64, userClosed bool) {
 	s.outgate.lock()
 	defer s.outUnlock()
+	if s.IsReadOnly() {
+		return
+	}
+	if userClosed {
+		// Mark that the user closed the stream.
+		s.outclosed.set()
+	}
 	if s.outreset.isSet() {
 		return
 	}
+	if code > maxVarint {
+		code = maxVarint
+	}
 	// We could check here to see if the stream is closed and the
 	// peer has acked all the data and the FIN, but sending an
 	// extra RESET_STREAM in this case is harmless.
@@ -310,44 +358,67 @@ func (s *Stream) Reset(code uint64) {
 
 // inUnlock unlocks s.ingate.
 // It sets the gate condition if reads from s will not block.
-// If s has receive-related frames to write, it notifies the Conn.
+// If s has receive-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
 func (s *Stream) inUnlock() {
-	if s.inUnlockNoQueue() {
+	state := s.inUnlockNoQueue()
+	if state&streamInSend != 0 || state == streamInDone|streamOutDone {
 		s.conn.queueStreamForSend(s)
 	}
 }
 
 // inUnlockNoQueue is inUnlock,
 // but reports whether s has frames to write rather than notifying the Conn.
-func (s *Stream) inUnlockNoQueue() (shouldSend bool) {
+func (s *Stream) inUnlockNoQueue() streamState {
 	canRead := s.inset.contains(s.in.start) || // data available to read
 		s.insize == s.in.start || // at EOF
 		s.inresetcode != -1 || // reset by peer
 		s.inclosed.isSet() // closed locally
 	defer s.ingate.unlock(canRead)
-	return s.insendmax.shouldSend() || // STREAM_MAX_DATA
-		s.inclosed.shouldSend() // STOP_SENDING
+	var state streamState
+	switch {
+	case s.IsWriteOnly():
+		state = streamInDone
+	case s.inresetcode != -1: // reset by peer
+		fallthrough
+	case s.in.start == s.insize: // all data received and read
+		// We don't increase MAX_STREAMS until the user calls CloseRead or Close,
+		// so the receive side is not finished until inclosed is set.
+		if s.inclosed.isSet() {
+			state = streamInDone
+		}
+	case s.insendmax.shouldSend(): // STREAM_MAX_DATA
+		state = streamInSend
+	case s.inclosed.shouldSend(): // STOP_SENDING
+		state = streamInSend
+	}
+	const mask = streamInDone | streamInSend
+	return s.state.set(state, mask)
 }
 
 // outUnlock unlocks s.outgate.
 // It sets the gate condition if writes to s will not block.
-// If s has send-related frames to write, it notifies the Conn.
+// If s has send-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
 func (s *Stream) outUnlock() {
-	if s.outUnlockNoQueue() {
+	state := s.outUnlockNoQueue()
+	if state&streamOutSend != 0 || state == streamInDone|streamOutDone {
 		s.conn.queueStreamForSend(s)
 	}
 }
 
 // outUnlockNoQueue is outUnlock,
 // but reports whether s has frames to write rather than notifying the Conn.
-func (s *Stream) outUnlockNoQueue() (shouldSend bool) {
+func (s *Stream) outUnlockNoQueue() streamState {
 	isDone := s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) || // all data acked
 		s.outreset.isSet() // reset locally
 	if isDone {
 		select {
 		case <-s.outdone:
 		default:
-			close(s.outdone)
+			if !s.IsReadOnly() {
+				close(s.outdone)
+			}
 		}
 	}
 	lim := min(s.out.start+s.outmaxbuf, s.outwin)
@@ -355,14 +426,32 @@ func (s *Stream) outUnlockNoQueue() (shouldSend bool) {
 		s.outclosed.isSet() || // closed locally
 		s.outreset.isSet() // reset locally
 	defer s.outgate.unlock(canWrite)
-	if s.outreset.isSet() {
-		// If the stream is reset locally, the only frame we'll send is RESET_STREAM.
-		return s.outreset.shouldSend()
-	}
-	return len(s.outunsent) > 0 || // STREAM frame with data
-		s.outclosed.shouldSend() || // STREAM frame with FIN bit
-		s.outopened.shouldSend() || // STREAM frame with no data
-		s.outblocked.shouldSend() // STREAM_DATA_BLOCKED
+	var state streamState
+	switch {
+	case s.IsReadOnly():
+		state = streamOutDone
+	case s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end): // all data sent and acked
+		fallthrough
+	case s.outreset.isReceived(): // RESET_STREAM sent and acked
+		// We don't increase MAX_STREAMS until the user calls CloseWrite or Close,
+		// so the send side is not finished until outclosed is set.
+		if s.outclosed.isSet() {
+			state = streamOutDone
+		}
+	case s.outreset.shouldSend(): // RESET_STREAM
+		state = streamOutSend
+	case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged
+	case len(s.outunsent) > 0: // STREAM frame with data
+		state = streamOutSend
+	case s.outclosed.shouldSend(): // STREAM frame with FIN bit
+		state = streamOutSend
+	case s.outopened.shouldSend(): // STREAM frame with no data
+		state = streamOutSend
+	case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED
+		state = streamOutSend
+	}
+	const mask = streamOutDone | streamOutSend
+	return s.state.set(state, mask)
 }
 
 // handleData handles data received in a STREAM frame.
@@ -431,7 +520,8 @@ func (s *Stream) checkStreamBounds(end int64, fin bool) error {
 func (s *Stream) handleStopSending(code uint64) error {
 	// Peer requests that we reset this stream.
 	// https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4
-	s.Reset(code)
+	const userReset = false
+	s.resetInternal(code, userReset)
 	return nil
 }
 
@@ -504,14 +594,12 @@ func (s *Stream) ackOrLossData(pnum packetNumber, start, end int64, fin bool, fa
 	}
 }
 
-// appendInFrames appends STOP_SENDING and MAX_STREAM_DATA frames
+// appendInFramesLocked appends STOP_SENDING and MAX_STREAM_DATA frames
 // to the current packet.
 //
 // It returns true if no more frames need appending,
 // false if not everything fit in the current packet.
-func (s *Stream) appendInFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
-	s.ingate.lock()
-	defer s.inUnlockNoQueue()
+func (s *Stream) appendInFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
 	if s.inclosed.shouldSendPTO(pto) {
 		// We don't currently have an API for setting the error code.
 		// Just send zero.
@@ -534,14 +622,12 @@ func (s *Stream) appendInFrames(w *packetWriter, pnum packetNumber, pto bool) bo
 	return true
 }
 
-// appendOutFrames appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames
+// appendOutFramesLocked appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames
 // to the current packet.
 //
 // It returns true if no more frames need appending,
 // false if not everything fit in the current packet.
-func (s *Stream) appendOutFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
-	s.outgate.lock()
-	defer s.outUnlockNoQueue()
+func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
 	if s.outreset.isSet() {
 		// RESET_STREAM
 		if s.outreset.shouldSendPTO(pto) {
diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go
index 79377c6a4a..e22e0432ef 100644
--- a/internal/quic/stream_test.go
+++ b/internal/quic/stream_test.go
@@ -1111,6 +1111,24 @@ func TestStreamPeerResetFollowedByData(t *testing.T) {
 	})
 }
 
+func TestStreamResetInvalidCode(t *testing.T) {
+	tc, s := newTestConnAndLocalStream(t, serverSide, uniStream)
+	s.Reset(1 << 62)
+	tc.wantFrame("reset with invalid code sends a RESET_STREAM anyway",
+		packetType1RTT, debugFrameResetStream{
+			id: s.id,
+			// The code we send here isn't specified,
+			// so this could really be any value.
+			code: (1 << 62) - 1,
+		})
+}
+
+func TestStreamResetReceiveOnly(t *testing.T) {
+	tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream)
+	s.Reset(0)
+	tc.wantIdle("resetting a receive-only stream has no effect")
+}
+
 func TestStreamPeerStopSendingForActiveStream(t *testing.T) {
 	// "An endpoint that receives a STOP_SENDING frame MUST send a RESET_STREAM frame if
 	// the stream is in the "Ready" or "Send" state."
@@ -1145,6 +1163,21 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) {
 	})
 }
 
+type streamSide string
+
+const (
+	localStream  = streamSide("local")
+	remoteStream = streamSide("remote")
+)
+
+func newTestConnAndStream(t *testing.T, side connSide, sside streamSide, styp streamType, opts ...any) (*testConn, *Stream) {
+	if sside == localStream {
+		return newTestConnAndLocalStream(t, side, styp, opts...)
+	} else {
+		return newTestConnAndRemoteStream(t, side, styp, opts...)
+	}
+}
+
 func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) {
 	t.Helper()
 	ctx := canceledContext()
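
Editor's note (not part of the patch): the sketch below is a minimal, self-contained illustration of how the atomicBits helper and the stream state bits introduced above are meant to interact -- each side of the stream publishes its "done" bit under its own gate, and the conn's loop removes the stream once both bits are set. The atomicBits implementation is copied verbatim from atomic_bits.go; the simplified bit constants and the main function are illustrative assumptions, not code from the change.

// Illustrative sketch only. Build with go1.21 or later.
package main

import (
	"fmt"
	"sync/atomic"
)

type streamState uint32

const (
	streamInSend = streamState(1 << iota)
	streamOutSend
	streamInDone
	streamOutDone
	streamConnRemoved
)

// atomicBits is copied from the patch: an atomic uint32 that supports
// setting individual bits under a mask via a compare-and-swap loop.
type atomicBits[T ~uint32] struct {
	bits atomic.Uint32
}

func (a *atomicBits[T]) set(v, mask T) T {
	if v&^mask != 0 {
		panic("BUG: bits in v are not in mask")
	}
	for {
		o := a.bits.Load()
		n := (o &^ uint32(mask)) | uint32(v)
		if a.bits.CompareAndSwap(o, n) {
			return T(n)
		}
	}
}

func (a *atomicBits[T]) load() T {
	return T(a.bits.Load())
}

func main() {
	var state atomicBits[streamState]

	// The receive side (with ingate held in the real code) marks itself done
	// and clears any pending "frames to send" bit for its direction.
	state.set(streamInDone, streamInDone|streamInSend)

	// The send side (with outgate held) does the same once all data is acked.
	after := state.set(streamOutDone, streamOutDone|streamOutSend)

	// The conn's loop removes the stream only when both sides are finished.
	if after == streamInDone|streamOutDone {
		state.set(streamConnRemoved, streamConnRemoved)
		fmt.Println("stream can be removed from the conn")
	}
	fmt.Printf("final state bits: %b\n", state.load())
}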