From 434956a1a8671fa67f9ec468cda2fd83937227b7 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 3 Nov 2023 16:43:48 -0700 Subject: [PATCH] quic: include more detail in connection close errors When closing a connection with an error, include a reason string in the CONNECTION_CLOSE frame as well as the error code, when the code isn't sufficient to explain the error. Change-Id: I055a4e11b222e87d1ff01d8c45fcb7cc17fe4196 Reviewed-on: https://go-review.googlesource.com/c/net/+/539342 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- internal/quic/conn_close.go | 4 +- internal/quic/conn_flow.go | 5 ++- internal/quic/conn_id.go | 61 ++++++++++++++++++++---------- internal/quic/conn_recv.go | 40 ++++++++++++++++---- internal/quic/conn_send.go | 2 +- internal/quic/conn_streams.go | 10 ++++- internal/quic/conn_test.go | 44 ++++++++++++++++++++- internal/quic/crypto_stream.go | 5 ++- internal/quic/errors.go | 10 ++++- internal/quic/listener.go | 2 +- internal/quic/packet_protection.go | 2 +- internal/quic/stream.go | 20 ++++++++-- internal/quic/stream_limits.go | 5 ++- internal/quic/stream_test.go | 3 +- internal/quic/transport_params.go | 26 ++++++------- 15 files changed, 178 insertions(+), 61 deletions(-) diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go index b8b86fd6f..daf425b76 100644 --- a/internal/quic/conn_close.go +++ b/internal/quic/conn_close.go @@ -156,7 +156,7 @@ func (c *Conn) enterDraining(err error) { if c.isDraining() { return } - if e, ok := c.lifetime.localErr.(localTransportError); ok && transportError(e) != errNo { + if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo { // If we've terminated the connection due to a peer protocol violation, // record the final error on the connection as our reason for termination. c.lifetime.finalErr = c.lifetime.localErr @@ -220,7 +220,7 @@ func (c *Conn) Wait(ctx context.Context) error { // Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text. func (c *Conn) Abort(err error) { if err == nil { - err = localTransportError(errNo) + err = localTransportError{code: errNo} } c.sendMsg(func(now time.Time, c *Conn) { c.abort(now, err) diff --git a/internal/quic/conn_flow.go b/internal/quic/conn_flow.go index 4f1ab6eaf..8b69ef7db 100644 --- a/internal/quic/conn_flow.go +++ b/internal/quic/conn_flow.go @@ -90,7 +90,10 @@ func (c *Conn) shouldUpdateFlowControl(credit int64) bool { func (c *Conn) handleStreamBytesReceived(n int64) error { c.streams.inflow.usedLimit += n if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit { - return localTransportError(errFlowControl) + return localTransportError{ + code: errFlowControl, + reason: "stream exceeded flow control limit", + } } return nil } diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index b77ad8edf..439c22123 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -210,25 +210,40 @@ func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p trans // the transient remote connection ID we chose (client) // or is empty (server). if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) { - return localTransportError(errTransportParameter) + return localTransportError{ + code: errTransportParameter, + reason: "original_destination_connection_id mismatch", + } } s.originalDstConnID = nil // we have no further need for this // Verify retry_source_connection_id matches the value from // the server's Retry packet (when one was sent), or is empty. if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) { - return localTransportError(errTransportParameter) + return localTransportError{ + code: errTransportParameter, + reason: "retry_source_connection_id mismatch", + } } s.retrySrcConnID = nil // we have no further need for this // Verify initial_source_connection_id matches the first remote connection ID. if len(s.remote) == 0 || s.remote[0].seq != 0 { - return localTransportError(errInternal) + return localTransportError{ + code: errInternal, + reason: "remote connection id missing", + } } if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) { - return localTransportError(errTransportParameter) + return localTransportError{ + code: errTransportParameter, + reason: "initial_source_connection_id mismatch", + } } if len(p.statelessResetToken) > 0 { if c.side == serverSide { - return localTransportError(errTransportParameter) + return localTransportError{ + code: errTransportParameter, + reason: "client sent stateless_reset_token", + } } token := statelessResetToken(p.statelessResetToken) s.remote[0].resetToken = token @@ -255,17 +270,6 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) }, } } - case ptype == packetTypeInitial && c.side == serverSide: - if len(s.remote) == 0 { - // We're a server connection processing the first Initial packet - // from the client. Set the client's connection ID. - s.remote = append(s.remote, remoteConnID{ - connID: connID{ - seq: 0, - cid: cloneBytes(srcConnID), - }, - }) - } case ptype == packetTypeHandshake && c.side == serverSide: if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired { // We're a server connection processing the first Handshake packet from @@ -294,7 +298,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re // Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID // frame as a connection error of type PROTOCOL_VIOLATION." // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6 - return localTransportError(errProtocolViolation) + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID from peer with zero-length DCID", + } } if retire > s.retireRemotePriorTo { @@ -316,7 +323,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re } if rcid.seq == seq { if !bytes.Equal(rcid.cid, cid) { - return localTransportError(errProtocolViolation) + return localTransportError{ + code: errProtocolViolation, + reason: "NEW_CONNECTION_ID does not match prior id", + } } have = true // yes, we've seen this sequence number } @@ -350,7 +360,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re // Retired connection IDs (including newly-retired ones) do not count // against the limit. // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 - return localTransportError(errConnectionIDLimit) + return localTransportError{ + code: errConnectionIDLimit, + reason: "active_connection_id_limit exceeded", + } } // "An endpoint SHOULD limit the number of connection IDs it has retired locally @@ -360,7 +373,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re // Set a limit of four times the active_connection_id_limit for // the total number of remote connection IDs we keep state for locally. if len(s.remote) > 4*activeConnIDLimit { - return localTransportError(errConnectionIDLimit) + return localTransportError{ + code: errConnectionIDLimit, + reason: "too many unacknowledged RETIRE_CONNECTION_ID frames", + } } return nil @@ -375,7 +391,10 @@ func (s *connIDState) retireRemote(rcid *remoteConnID) { func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { if seq >= s.nextLocalSeq { - return localTransportError(errProtocolViolation) + return localTransportError{ + code: errProtocolViolation, + reason: "RETIRE_CONNECTION_ID for unissued sequence number", + } } for i := range s.local { if s.local[i].seq == seq { diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index e966b7ef5..8fa3a3906 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -79,12 +79,18 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa if buf[0]&reservedLongBits != 0 { // Reserved header bits must be 0. // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "reserved header bits are not zero", + }) return -1 } if p.version != quicVersion1 { // The peer has changed versions on us mid-handshake? - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "protocol version changed during handshake", + }) return -1 } @@ -129,7 +135,10 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { if buf[0]&reserved1RTTBits != 0 { // Reserved header bits must be 0. // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "reserved header bits are not zero", + }) return -1 } @@ -222,7 +231,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, // "An endpoint MUST treat receipt of a packet containing no frames // as a connection error of type PROTOCOL_VIOLATION." // https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3 - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "packet contains no frames", + }) return false } // frameOK verifies that ptype is one of the packets in mask. @@ -232,7 +244,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, // that is not permitted as a connection error of type // PROTOCOL_VIOLATION." // https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3 - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "frame not allowed in packet", + }) return false } return true @@ -347,7 +362,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, n = c.handleHandshakeDoneFrame(now, space, payload) } if n < 0 { - c.abort(now, localTransportError(errFrameEncoding)) + c.abort(now, localTransportError{ + code: errFrameEncoding, + reason: "frame encoding error", + }) return false } payload = payload[n:] @@ -360,7 +378,10 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { if end > c.loss.nextNumber(space) { // Acknowledgement of a packet we never sent. - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "acknowledgement for unsent packet", + }) return } c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss) @@ -521,7 +542,10 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa if c.side == serverSide { // Clients should never send HANDSHAKE_DONE. // https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4 - c.abort(now, localTransportError(errProtocolViolation)) + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "client sent HANDSHAKE_DONE", + }) return -1 } if !c.isClosingOrDraining() { diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 64e5d7548..22e780479 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -328,7 +328,7 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err c.lifetime.connCloseSentTime = now switch e := err.(type) { case localTransportError: - c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "") + c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason) case *ApplicationError: if space != appDataSpace { // "CONNECTION_CLOSE frames signaling application errors (type 0x1d) diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index a0793297e..83ab5554c 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -127,7 +127,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) if (id.initiator() == c.side) != (ftype == sendStream) { // Received an invalid frame for unidirectional stream. // For example, a RESET_STREAM frame for a send-only stream. - c.abort(now, localTransportError(errStreamState)) + c.abort(now, localTransportError{ + code: errStreamState, + reason: "invalid frame for unidirectional stream", + }) return nil } } @@ -148,7 +151,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) } // Received a frame for a stream that should be originated by us, // but which we never created. - c.abort(now, localTransportError(errStreamState)) + c.abort(now, localTransportError{ + code: errStreamState, + reason: "received frame for unknown stream", + }) return nil } else { // if isOpen, this is a stream that was implicitly opened by a diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 248be9641..c70c58ef0 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -594,6 +594,20 @@ func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { } } +func datagramEqual(a, b *testDatagram) bool { + if a.paddedSize != b.paddedSize || + a.addr != b.addr || + len(a.packets) != len(b.packets) { + return false + } + for i := range a.packets { + if !packetEqual(a.packets[i], b.packets[i]) { + return false + } + } + return true +} + // wantPacket indicates that we expect the Conn to send a packet. func (tc *testConn) wantPacket(expectation string, want *testPacket) { tc.t.Helper() @@ -603,6 +617,25 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) { } } +func packetEqual(a, b *testPacket) bool { + ac := *a + ac.frames = nil + bc := *b + bc.frames = nil + if !reflect.DeepEqual(ac, bc) { + return false + } + if len(a.frames) != len(b.frames) { + return false + } + for i := range a.frames { + if !frameEqual(a.frames[i], b.frames[i]) { + return false + } + } + return true +} + // wantFrame indicates that we expect the Conn to send a frame. func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) { tc.t.Helper() @@ -613,11 +646,20 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu if gotType != wantType { tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) } - if !reflect.DeepEqual(got, want) { + if !frameEqual(got, want) { tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want) } } +func frameEqual(a, b debugFrame) bool { + switch af := a.(type) { + case debugFrameConnectionCloseTransport: + bf, ok := b.(debugFrameConnectionCloseTransport) + return ok && af.code == bf.code + } + return reflect.DeepEqual(a, b) +} + // wantFrameType indicates that we expect the Conn to send a frame, // although we don't care about the contents. func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) { diff --git a/internal/quic/crypto_stream.go b/internal/quic/crypto_stream.go index 8aa8f7b82..a4dcb32eb 100644 --- a/internal/quic/crypto_stream.go +++ b/internal/quic/crypto_stream.go @@ -30,7 +30,10 @@ type cryptoStream struct { func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error { end := off + int64(len(b)) if end-s.inset.min() > cryptoBufferSize { - return localTransportError(errCryptoBufferExceeded) + return localTransportError{ + code: errCryptoBufferExceeded, + reason: "crypto buffer exceeded", + } } s.inset.add(off, end) if off == s.in.start { diff --git a/internal/quic/errors.go b/internal/quic/errors.go index 8e01bb7cb..954793cfc 100644 --- a/internal/quic/errors.go +++ b/internal/quic/errors.go @@ -83,10 +83,16 @@ func (e transportError) String() string { } // A localTransportError is an error sent to the peer. -type localTransportError transportError +type localTransportError struct { + code transportError + reason string +} func (e localTransportError) Error() string { - return "closed connection: " + transportError(e).String() + if e.reason == "" { + return fmt.Sprintf("closed connection: %v", e.code) + } + return fmt.Sprintf("closed connection: %v: %q", e.code, e.reason) } // A peerTransportError is an error received from the peer. diff --git a/internal/quic/listener.go b/internal/quic/listener.go index 24484eb6f..8b31dcbe8 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -107,7 +107,7 @@ func (l *Listener) Close(ctx context.Context) error { if !l.closing { l.closing = true for c := range l.conns { - c.Abort(localTransportError(errNo)) + c.Abort(localTransportError{code: errNo}) } if len(l.conns) == 0 { l.udpConn.Close() diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go index 7b141ac49..1f939f491 100644 --- a/internal/quic/packet_protection.go +++ b/internal/quic/packet_protection.go @@ -441,7 +441,7 @@ func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumbe if err != nil { k.authFailures++ if k.authFailures >= aeadIntegrityLimit(k.r.suite) { - return nil, 0, localTransportError(errAEADLimitReached) + return nil, 0, localTransportError{code: errAEADLimitReached} } return nil, 0, err } diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 89036b19b..58d84ed1b 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -567,19 +567,31 @@ func (s *Stream) handleReset(code uint64, finalSize int64) error { func (s *Stream) checkStreamBounds(end int64, fin bool) error { if end > s.inwin { // The peer sent us data past the maximum flow control window we gave them. - return localTransportError(errFlowControl) + return localTransportError{ + code: errFlowControl, + reason: "stream flow control window exceeded", + } } if s.insize != -1 && end > s.insize { // The peer sent us data past the final size of the stream they previously gave us. - return localTransportError(errFinalSize) + return localTransportError{ + code: errFinalSize, + reason: "data received past end of stream", + } } if fin && s.insize != -1 && end != s.insize { // The peer changed the final size of the stream. - return localTransportError(errFinalSize) + return localTransportError{ + code: errFinalSize, + reason: "final size of stream changed", + } } if fin && end < s.in.end { // The peer has previously sent us data past the final size. - return localTransportError(errFinalSize) + return localTransportError{ + code: errFinalSize, + reason: "end of stream occurs before prior data", + } } return nil } diff --git a/internal/quic/stream_limits.go b/internal/quic/stream_limits.go index 6eda7883b..2f42cf418 100644 --- a/internal/quic/stream_limits.go +++ b/internal/quic/stream_limits.go @@ -66,7 +66,10 @@ func (lim *remoteStreamLimits) init(maxOpen int64) { func (lim *remoteStreamLimits) open(id streamID) error { num := id.num() if num >= lim.max { - return localTransportError(errStreamLimit) + return localTransportError{ + code: errStreamLimit, + reason: "stream limit exceeded", + } } if num >= lim.opened { lim.opened = num + 1 diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 7c1377fae..9bf2b5871 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "reflect" "strings" "testing" ) @@ -848,7 +847,7 @@ func TestStreamOffsetTooLarge(t *testing.T) { got, _ := tc.readFrame() want1 := debugFrameConnectionCloseTransport{code: errFrameEncoding} want2 := debugFrameConnectionCloseTransport{code: errFlowControl} - if !reflect.DeepEqual(got, want1) && !reflect.DeepEqual(got, want2) { + if !frameEqual(got, want1) && !frameEqual(got, want2) { t.Fatalf("STREAM offset exceeds 2^62-1\ngot: %v\nwant: %v\n or: %v", got, want1, want2) } } diff --git a/internal/quic/transport_params.go b/internal/quic/transport_params.go index dc76d1650..3cc56f4e4 100644 --- a/internal/quic/transport_params.go +++ b/internal/quic/transport_params.go @@ -169,12 +169,12 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { for len(params) > 0 { id, n := consumeVarint(params) if n < 0 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } params = params[n:] val, n := consumeVarintBytes(params) if n < 0 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } params = params[n:] n = 0 @@ -193,14 +193,14 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { p.maxIdleTimeout = time.Duration(v) * time.Millisecond case paramStatelessResetToken: if len(val) != 16 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } p.statelessResetToken = val n = 16 case paramMaxUDPPayloadSize: p.maxUDPPayloadSize, n = consumeVarintInt64(val) if p.maxUDPPayloadSize < 1200 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } case paramInitialMaxData: p.initialMaxData, n = consumeVarintInt64(val) @@ -213,32 +213,32 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { case paramInitialMaxStreamsBidi: p.initialMaxStreamsBidi, n = consumeVarintInt64(val) if p.initialMaxStreamsBidi > maxStreamsLimit { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } case paramInitialMaxStreamsUni: p.initialMaxStreamsUni, n = consumeVarintInt64(val) if p.initialMaxStreamsUni > maxStreamsLimit { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } case paramAckDelayExponent: var v uint64 v, n = consumeVarint(val) if v > 20 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } p.ackDelayExponent = int8(v) case paramMaxAckDelay: var v uint64 v, n = consumeVarint(val) if v >= 1<<14 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } p.maxAckDelay = time.Duration(v) * time.Millisecond case paramDisableActiveMigration: p.disableActiveMigration = true case paramPreferredAddress: if len(val) < 4+2+16+2+1 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } p.preferredAddrV4 = netip.AddrPortFrom( netip.AddrFrom4(*(*[4]byte)(val[:4])), @@ -253,18 +253,18 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { var nn int p.preferredAddrConnID, nn = consumeUint8Bytes(val) if nn < 0 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } val = val[nn:] if len(val) != 16 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } p.preferredAddrResetToken = val val = nil case paramActiveConnectionIDLimit: p.activeConnIDLimit, n = consumeVarintInt64(val) if p.activeConnIDLimit < 2 { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } case paramInitialSourceConnectionID: p.initialSrcConnID = val @@ -276,7 +276,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { n = len(val) } if n != len(val) { - return p, localTransportError(errTransportParameter) + return p, localTransportError{code: errTransportParameter} } } return p, nil