From 7c0096c86977a6345402445d1048134557c72820 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Thu, 11 Jan 2024 09:04:21 +0100 Subject: [PATCH] Turn this package into a wrapper for protobuf/encoding/protodelim Since Go Protobuf v1.30.0, the protodelim package is available upstream. The only notable API difference is that protodelim does not return the number of bytes read, which is why I added the countingReader type to pbutil/decode.go. --- go.mod | 2 +- go.sum | 4 +-- pbutil/decode.go | 78 +++++++++++++++++++------------------------ pbutil/decode_test.go | 42 +++++++++++++---------- pbutil/encode.go | 21 ++---------- 5 files changed, 63 insertions(+), 84 deletions(-) diff --git a/go.mod b/go.mod index c6a5186..f746998 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.19 require ( github.com/google/go-cmp v0.5.9 - google.golang.org/protobuf v1.28.1 + google.golang.org/protobuf v1.31.0 ) diff --git a/go.sum b/go.sum index 38fba05..4d0bc04 100644 --- a/go.sum +++ b/go.sum @@ -4,5 +4,5 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/pbutil/decode.go b/pbutil/decode.go index 7c08e56..f389f2e 100644 --- a/pbutil/decode.go +++ b/pbutil/decode.go @@ -15,15 +15,40 @@ package pbutil import ( - "encoding/binary" - "errors" "io" + "google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/proto" ) -// TODO: Give error package name prefix in next minor release. -var errInvalidVarint = errors.New("invalid varint32 encountered") +type countingReader struct { + r io.Reader + n int +} + +// implements protodelim.Reader +func (r *countingReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if n > 0 { + r.n += n + } + return n, err +} + +// implements protodelim.Reader +func (c *countingReader) ReadByte() (byte, error) { + var buf [1]byte + for { + n, err := c.Read(buf[:]) + if n == 0 && err == nil { + // io.Reader states: Callers should treat a return of 0 and nil as + // indicating that nothing happened; in particular it does not + // indicate EOF. + continue + } + return buf[0], err + } +} // ReadDelimited decodes a message from the provided length-delimited stream, // where the length is encoded as 32-bit varint prefix to the message body. @@ -37,45 +62,10 @@ var errInvalidVarint = errors.New("invalid varint32 encountered") // of the stream has been reached in doing so. In that case, any subsequent // calls return (0, io.EOF). func ReadDelimited(r io.Reader, m proto.Message) (n int, err error) { - // TODO: Consider allowing the caller to specify a decode buffer in the - // next major version. - - // TODO: Consider using error wrapping to annotate error state in pass- - // through cases in the next minor version. - - // Per AbstractParser#parsePartialDelimitedFrom with - // CodedInputStream#readRawVarint32. - var headerBuf [binary.MaxVarintLen32]byte - var bytesRead, varIntBytes int - var messageLength uint64 - for varIntBytes == 0 { // i.e. no varint has been decoded yet. - if bytesRead >= len(headerBuf) { - return bytesRead, errInvalidVarint - } - // We have to read byte by byte here to avoid reading more bytes - // than required. Each read byte is appended to what we have - // read before. - newBytesRead, err := r.Read(headerBuf[bytesRead : bytesRead+1]) - if newBytesRead == 0 { - if err != nil { - return bytesRead, err - } - // A Reader should not return (0, nil); but if it does, it should - // be treated as no-op according to the Reader contract. - continue - } - bytesRead += newBytesRead - // Now present everything read so far to the varint decoder and - // see if a varint can be decoded already. - messageLength, varIntBytes = binary.Uvarint(headerBuf[:bytesRead]) - } - - messageBuf := make([]byte, messageLength) - newBytesRead, err := io.ReadFull(r, messageBuf) - bytesRead += newBytesRead - if err != nil { - return bytesRead, err + cr := &countingReader{r: r} + opts := protodelim.UnmarshalOptions{ + MaxSize: -1, } - - return bytesRead, proto.Unmarshal(messageBuf, m) + err = opts.UnmarshalFrom(cr, m) + return cr.n, err } diff --git a/pbutil/decode_test.go b/pbutil/decode_test.go index fa2972d..7da1b64 100644 --- a/pbutil/decode_test.go +++ b/pbutil/decode_test.go @@ -16,6 +16,7 @@ package pbutil import ( "bytes" + "encoding/binary" "errors" "io" "testing" @@ -29,29 +30,34 @@ import ( func TestReadDelimitedIllegalVarint(t *testing.T) { var tests = []struct { - in []byte - n int - err error + name string + in []byte + n int }{ { - in: []byte{255, 255, 255, 255, 255}, - n: 5, - err: errInvalidVarint, + name: "all 0xFF", + in: []byte{255, 255, 255, 255, 255}, + n: 5, }, + + // Ensure ReadDelimited eventually stops parsing a varint instead of + // looping as long as the input bytes have the continuation bit set. { - in: []byte{255, 255, 255, 255, 255, 255}, - n: 5, - err: errInvalidVarint, + name: "infinite continuation bits", + in: bytes.Repeat([]byte{255}, 2*binary.MaxVarintLen64), + n: binary.MaxVarintLen64, }, } for _, test := range tests { - n, err := ReadDelimited(bytes.NewReader(test.in), nil) - if got, want := n, test.n; !cmp.Equal(got, want) { - t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", test.in, got, want) - } - if got, want := err, test.err; !errors.Is(got, want) { - t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", test.in, got, want) - } + t.Run(test.name, func(t *testing.T) { + n, err := ReadDelimited(bytes.NewReader(test.in), nil) + if got, want := n, test.n; !cmp.Equal(got, want) { + t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", test.in, got, want) + } + if err == nil { + t.Errorf("ReadDelimited(%#v) unexpectedly did not result in an error", test.in) + } + }) } } @@ -61,7 +67,7 @@ func TestReadDelimitedPrematureHeader(t *testing.T) { if got, want := n, 1; !cmp.Equal(got, want) { t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", data[0:1], got, want) } - if got, want := err, io.EOF; !errors.Is(got, want) { + if got, want := err, io.ErrUnexpectedEOF; !errors.Is(got, want) { t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", data[0:1], got, want) } } @@ -83,7 +89,7 @@ func TestReadDelimitedPrematureHeaderIncremental(t *testing.T) { if got, want := n, 1; !cmp.Equal(got, want) { t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", data[0:1], got, want) } - if got, want := err, io.EOF; !errors.Is(got, want) { + if got, want := err, io.ErrUnexpectedEOF; !errors.Is(got, want) { t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", data[0:1], got, want) } } diff --git a/pbutil/encode.go b/pbutil/encode.go index e58dd9d..7ef4678 100644 --- a/pbutil/encode.go +++ b/pbutil/encode.go @@ -15,9 +15,9 @@ package pbutil import ( - "encoding/binary" "io" + "google.golang.org/protobuf/encoding/protodelim" "google.golang.org/protobuf/proto" ) @@ -28,22 +28,5 @@ import ( // number of bytes written and any applicable error. This is roughly // equivalent to the companion Java API's MessageLite#writeDelimitedTo. func WriteDelimited(w io.Writer, m proto.Message) (n int, err error) { - // TODO: Consider allowing the caller to specify an encode buffer in the - // next major version. - - buffer, err := proto.Marshal(m) - if err != nil { - return 0, err - } - - var buf [binary.MaxVarintLen32]byte - encodedLength := binary.PutUvarint(buf[:], uint64(len(buffer))) - - sync, err := w.Write(buf[:encodedLength]) - if err != nil { - return sync, err - } - - n, err = w.Write(buffer) - return n + sync, err + return protodelim.MarshalTo(w, m) }