diff --git a/envelope.go b/envelope.go index 8f81ab03..62680ca9 100644 --- a/envelope.go +++ b/envelope.go @@ -226,7 +226,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error { func (r *envelopeReader) Read(env *envelope) *Error { prefixes := [5]byte{} - prefixBytesRead, err := r.reader.Read(prefixes[:]) + prefixBytesRead, err := io.ReadFull(r.reader, prefixes[:]) switch { case (err == nil || errors.Is(err, io.EOF)) && @@ -240,7 +240,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { // to the user so that they know that the stream has ended. We shouldn't // add any alarming text about protocol errors, though. return NewError(CodeUnknown, err) - case err != nil || prefixBytesRead < 5: + case err != nil: // Something else has gone wrong - the stream didn't end cleanly. if connectErr, ok := asError(err); ok { return connectErr @@ -249,9 +249,6 @@ func (r *envelopeReader) Read(env *envelope) *Error { // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. return maxBytesErr } - if err == nil { - err = io.ErrUnexpectedEOF - } return errorf( CodeInvalidArgument, "protocol error: incomplete envelope: %w", err, diff --git a/envelope_test.go b/envelope_test.go new file mode 100644 index 00000000..bf187934 --- /dev/null +++ b/envelope_test.go @@ -0,0 +1,74 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "bytes" + "encoding/binary" + "io" + "testing" + + "connectrpc.com/connect/internal/assert" +) + +func TestEnvelope_read(t *testing.T) { + t.Parallel() + + head := [5]byte{} + payload := []byte(`{"number": 42}`) + binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) + + buf := &bytes.Buffer{} + buf.Write(head[:]) + buf.Write(payload) + + t.Run("full", func(t *testing.T) { + t.Parallel() + env := &envelope{Data: &bytes.Buffer{}} + rdr := envelopeReader{ + reader: bytes.NewReader(buf.Bytes()), + } + assert.Nil(t, rdr.Read(env)) + assert.Equal(t, payload, env.Data.Bytes()) + }) + t.Run("byteByByte", func(t *testing.T) { + t.Parallel() + env := &envelope{Data: &bytes.Buffer{}} + rdr := envelopeReader{ + reader: byteByByteReader{ + reader: bytes.NewReader(buf.Bytes()), + }, + } + assert.Nil(t, rdr.Read(env)) + assert.Equal(t, payload, env.Data.Bytes()) + }) +} + +// byteByByteReader is test reader that reads a single byte at a time. +type byteByByteReader struct { + reader io.ByteReader +} + +func (b byteByByteReader) Read(data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } + next, err := b.reader.ReadByte() + if err != nil { + return 0, err + } + data[0] = next + return 1, nil +}