diff --git a/conn.go b/conn.go index cd749784c..338f793ad 100644 --- a/conn.go +++ b/conn.go @@ -881,7 +881,11 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh } } else { switch { - case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF): + case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): + case errors.Is(err, recordlayer.ErrInvalidPacketLength): + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + continue default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() diff --git a/conn_test.go b/conn_test.go index 946e4ab43..ea3c842f7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -401,6 +401,71 @@ func TestHandshakeWithAlert(t *testing.T) { } } +func TestHandshakeWithInvalidRecord(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientErr := make(chan result, 1) + ca, cb := dpipe.Pipe() + caWithInvalidRecord := &connWithCallback{Conn: ca} + + var msgSeq atomic.Int32 + // Send invalid record after first message + caWithInvalidRecord.onWrite = func(b []byte) { + if msgSeq.Add(1) == 2 { + if _, err := ca.Write([]byte{0x01, 0x02}); err != nil { + t.Fatal(err) + } + } + } + go func() { + client, err := testClient(ctx, caWithInvalidRecord, &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + }, true) + clientErr <- result{client, err} + }() + + server, errServer := testServer(ctx, cb, &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + }, true) + + errClient := <-clientErr + + defer func() { + if server != nil { + if err := server.Close(); err != nil { + t.Fatal(err) + } + } + + if errClient.c != nil { + if err := errClient.c.Close(); err != nil { + t.Fatal(err) + } + } + }() + + if errServer != nil { + t.Fatalf("Server failed(%v)", errServer) + } + + if errClient.err != nil { + t.Fatalf("Client failed(%v)", errClient.err) + } +} + func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) @@ -2973,3 +3038,15 @@ func TestSkipHelloVerify(t *testing.T) { t.Error(err) } } + +type connWithCallback struct { + net.Conn + onWrite func([]byte) +} + +func (c *connWithCallback) Write(b []byte) (int, error) { + if c.onWrite != nil { + c.onWrite(b) + } + return c.Conn.Write(b) +} diff --git a/pkg/protocol/recordlayer/errors.go b/pkg/protocol/recordlayer/errors.go index cd4cb60a5..1c1898844 100644 --- a/pkg/protocol/recordlayer/errors.go +++ b/pkg/protocol/recordlayer/errors.go @@ -11,9 +11,11 @@ import ( ) var ( - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 - errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 - errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 - errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 + // ErrInvalidPacketLength is returned when the packet length too small or declared length do not match + ErrInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 + + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 + errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 + errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 + errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 ) diff --git a/pkg/protocol/recordlayer/recordlayer.go b/pkg/protocol/recordlayer/recordlayer.go index 02325fd2d..a570f745f 100644 --- a/pkg/protocol/recordlayer/recordlayer.go +++ b/pkg/protocol/recordlayer/recordlayer.go @@ -86,12 +86,12 @@ func UnpackDatagram(buf []byte) ([][]byte, error) { for offset := 0; len(buf) != offset; { if len(buf)-offset <= HeaderSize { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } pktLen := (HeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) if offset+pktLen > len(buf) { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } out = append(out, buf[offset:offset+pktLen]) diff --git a/pkg/protocol/recordlayer/recordlayer_test.go b/pkg/protocol/recordlayer/recordlayer_test.go index 2e16c0104..760a73040 100644 --- a/pkg/protocol/recordlayer/recordlayer_test.go +++ b/pkg/protocol/recordlayer/recordlayer_test.go @@ -39,12 +39,12 @@ func TestUDPDecode(t *testing.T) { { Name: "Invalid packet length", Data: []byte{0x14, 0xfe}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, { Name: "Packet declared invalid length", Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, } { dtlsPkts, err := UnpackDatagram(test.Data)